Skip to content

Commit 7992bea

Browse files
feat: add caching
1 parent b9f04ec commit 7992bea

14 files changed

+476
-45
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ __pycache__
33
.venv
44
.vscode
55
.idea
6+
*.db

requirements.txt

+3
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
fastapi~=0.115.0
22
requests~=2.32.3
3+
mysql-connector~=2.2.9
4+
uvicorn~=0.21.1
5+
pydantic~=1.10.7

server.py

-45
This file was deleted.

server/__init__.py

Whitespace-only changes.

server/caching/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .cache_handler import get_cache_provider

server/caching/cache_handler.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
2+
import os
3+
4+
import requests
5+
from server.caching.cache_provider import CacheProvider
6+
from server.caching.cache_provider.in_memory_cache_provider import InMemoryCacheProvider
7+
from server.caching.cache_provider.mysql_cache_provider import MySQLCacheProvider
8+
from server.caching.cache_provider.no_cache_provider import NoCacheProvider
9+
from server.caching.cache_provider.sqlite_cache_provider import SQLiteCacheProvider
10+
from server.caching.cache_request import CacheRequest
11+
12+
CACHE_DB_FILE = "./cache.db"
13+
14+
def get_cache_provider() -> CacheProvider:
15+
cache_type = os.environ.get("CACHE_MODE", "memory")
16+
17+
print("Using cache provider: " + cache_type)
18+
19+
match cache_type:
20+
case "sqlite":
21+
db_file = os.environ.get("SQLITE_FILE", CACHE_DB_FILE)
22+
return SQLiteCacheProvider(db_file)
23+
case "memory":
24+
return InMemoryCacheProvider()
25+
case "mysql":
26+
host = os.environ.get("MYSQL_HOST")
27+
user = os.environ.get("MYSQL_USER")
28+
password = os.environ.get("MYSQL_PASSWORD")
29+
database = os.environ.get("MYSQL_DATABASE")
30+
port = os.environ.get("MYSQL_PORT", 3306)
31+
32+
if None in (host, user, password, database):
33+
print("WARNING: Missing environment variables for MySQL cache provider. "
34+
"Required: MYSQL_HOST, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE"
35+
"\nFalling back to NoCacheProvider.")
36+
return NoCacheProvider()
37+
38+
return MySQLCacheProvider(
39+
host=host, # type: ignore
40+
user=user, # type: ignore
41+
password=password, # type: ignore
42+
database=database, # type: ignore
43+
port=port # type: ignore
44+
)
45+
case "none":
46+
return NoCacheProvider()
47+
48+
print(f"WARNING: No cache provider found for type {cache_type}. Using NoCacheProvider.")
49+
50+
return NoCacheProvider()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .cache_provider import CacheProvider
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from abc import ABC, abstractmethod
2+
from datetime import datetime
3+
4+
from server.caching.cache_request import CacheRequest
5+
6+
type TimedCache = tuple[bytes, datetime]
7+
8+
9+
class CacheProvider(ABC):
10+
"""
11+
A cache provider that can be used to cache requests
12+
"""
13+
14+
@abstractmethod
15+
def _get(self, request: CacheRequest) -> TimedCache | None:
16+
"""
17+
Get the response and timestamp from the cache
18+
:param request: The request to get the response for
19+
:return: A tuple containing the response and timestamp
20+
"""
21+
22+
def get(self, request: CacheRequest) -> bytes | None:
23+
"""
24+
Get the response from the cache if it exists and is not stale
25+
:param request: The request to get the response for
26+
:return: The response if it exists and is not stale, None otherwise
27+
"""
28+
cached = self._get(request)
29+
if not cached:
30+
return None
31+
32+
response, timestamp = cached
33+
34+
time_diff = datetime.now() - timestamp
35+
36+
if request.max_age <= 0 or 0 < time_diff.seconds < request.max_age:
37+
return response
38+
39+
return None
40+
41+
@abstractmethod
42+
def set(self, request: CacheRequest, response: bytes) -> None:
43+
"""
44+
Save the response to the cache
45+
:param request: The request to save the response for
46+
:param response: The response to save
47+
:return:
48+
"""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import json
2+
from datetime import datetime
3+
4+
from server.caching.cache_provider import CacheProvider
5+
from server.caching.cache_provider.cache_provider import TimedCache
6+
from server.caching.cache_request import CacheRequest
7+
8+
9+
def _hash_request(request: CacheRequest) -> int:
10+
"""
11+
Hash a request
12+
:param request:
13+
:return:
14+
"""
15+
return hash((request.method, request.url, json.dumps(request.body), json.dumps(request.headers)))
16+
17+
18+
class InMemoryCacheProvider(CacheProvider):
19+
def __init__(self):
20+
self.cache: dict[int, TimedCache] = {}
21+
22+
def _get(self, request: CacheRequest) -> TimedCache | None:
23+
hashed = _hash_request(request)
24+
return self.cache.get(hashed)
25+
26+
def set(self, request: CacheRequest, response: bytes) -> None:
27+
hashed = _hash_request(request)
28+
self.cache[hashed] = (response, datetime.now())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import json
2+
from datetime import datetime
3+
4+
from server.caching.cache_provider import CacheProvider
5+
from server.caching.cache_provider.cache_provider import TimedCache
6+
from server.caching.cache_request import CacheRequest
7+
import mysql.connector
8+
import mysql.connector.cursor
9+
10+
_TABLE_SQL = """
11+
CREATE TABLE IF NOT EXISTS cache (
12+
id INTEGER PRIMARY KEY AUTO_INCREMENT,
13+
method TEXT NOT NULL,
14+
url TEXT NOT NULL,
15+
response MEDIUMBLOB NOT NULL,
16+
body TEXT NULL,
17+
headers TEXT NULL,
18+
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
19+
);"""
20+
21+
22+
def _get_current_timestamp() -> str:
23+
"""
24+
Get the current timestamp as a string
25+
:return:
26+
"""
27+
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
28+
29+
30+
class MySQLCacheProvider(CacheProvider):
31+
def __init__(self, host: str, user: str, password: str, database: str, port: int = 3306):
32+
self.host = host
33+
self.user = user
34+
self.password = password
35+
self.database = database
36+
self.port = port
37+
self.initialized = False
38+
try:
39+
self._create_table()
40+
self.initialized = True
41+
except mysql.connector.errors.Error as e:
42+
print("Error while initializing cache: " + str(e))
43+
44+
def _connect(self) -> mysql.connector.MySQLConnection:
45+
return mysql.connector.connect(
46+
host=self.host,
47+
user=self.user,
48+
password=self.password,
49+
database=self.database,
50+
port=self.port
51+
) # type: ignore
52+
53+
def _create_table(self):
54+
db = self._connect()
55+
try:
56+
c = db.cursor()
57+
c.execute(_TABLE_SQL)
58+
db.commit()
59+
except mysql.connector.errors.Error as e:
60+
print(e)
61+
62+
def _get(self, request: CacheRequest) -> TimedCache | None:
63+
"""
64+
Get the response and timestamp from the cache
65+
:param request: The request to get the response for
66+
:return: A tuple containing the response and timestamp
67+
"""
68+
69+
headers = json.dumps(request.headers) if request.headers else ""
70+
71+
db = self._connect()
72+
c = db.cursor()
73+
c.execute(
74+
"SELECT response, timestamp FROM cache WHERE method = %s AND url = %s AND body <=> %s AND headers <=> %s",
75+
(request.method, request.url, request.body, headers))
76+
77+
result = c.fetchone()
78+
if result:
79+
return result[0], result[1]
80+
return None
81+
82+
def set(self, request: CacheRequest, response: bytes) -> None:
83+
"""
84+
Save the response to the cache
85+
:param request: The request to save the response for
86+
:param response: The response to save
87+
:return:
88+
"""
89+
headers = json.dumps(request.headers) if request.headers else ""
90+
91+
db = self._connect()
92+
c = db.cursor()
93+
c.execute(
94+
"SELECT id FROM cache WHERE method = %s AND url = %s AND body <=> %s AND headers <=> %s",
95+
(request.method, request.url, request.body, headers))
96+
existing = c.fetchone()
97+
98+
if existing:
99+
c.execute("UPDATE cache SET response = %s, timestamp = %s WHERE id = %s",
100+
(response, _get_current_timestamp(), existing[0]))
101+
else:
102+
c.execute(
103+
"INSERT INTO cache (method, url, response, body, headers, timestamp) VALUES (%s, %s, %s, %s, %s, %s, %s)",
104+
(request.method, request.url, response, request.body, headers, _get_current_timestamp()))
105+
106+
db.commit()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from server.caching.cache_provider import CacheProvider
2+
from server.caching.cache_provider.cache_provider import TimedCache
3+
from server.caching.cache_request import CacheRequest
4+
5+
6+
class NoCacheProvider(CacheProvider):
7+
def _get(self, request: CacheRequest) -> TimedCache | None:
8+
return None
9+
10+
def set(self, request: CacheRequest, response: bytes) -> None:
11+
pass

0 commit comments

Comments
 (0)