Skip to content

Commit 238d1a1

Browse files
cmyuitsunyoku
andauthored
Refactor scores, stats and channels repos to build queries with sqlalchemy core 1.4 (osuAkatsuki#630)
* bugfixes Co-authored-git stasby: James Wilson <[email protected]> * begin to refactor Co-authored-by: James Wilson <[email protected]> * channels, stats Co-authored-by: James Wilson <[email protected]> * resolve dependencies + add sqlalchemy type stubs * sqlalchemy-stubs * le bugfix de la typing * _POLEASE * type fixes * type fixes * remove non working test makefile stuff * fixes * re-add pymypysql/aiomysql --------- Co-authored-by: James Wilson <[email protected]>
1 parent 0a16cd6 commit 238d1a1

File tree

10 files changed

+656
-677
lines changed

10 files changed

+656
-677
lines changed

Makefile

-6
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,6 @@ test:
2222
docker-compose -f docker-compose.test.yml up -d bancho-test mysql-test redis-test
2323
docker-compose -f docker-compose.test.yml exec -T bancho-test /srv/root/scripts/run-tests.sh
2424

25-
test-local:
26-
poetry run pytest -vv tests/
27-
28-
test-dbg:
29-
poetry run pytest -vv --pdb -s tests/
30-
3125
lint:
3226
poetry run pre-commit run --all-files
3327

app/api/domains/osu.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -928,7 +928,7 @@ async def osuSubmitModularSelector(
928928
# update global & country ranking
929929
stats.rank = await score.player.update_rank(score.mode)
930930

931-
await stats_repo.update(
931+
await stats_repo.partial_update(
932932
score.player.id,
933933
score.mode.value,
934934
plays=stats_updates.get("plays", UNSET),
@@ -1388,7 +1388,11 @@ async def getScores(
13881388
return Response("\n".join(response_lines).encode())
13891389

13901390
if personal_best_score_row is not None:
1391-
user_clan = await clans_repo.fetch_one(id=player.clan_id)
1391+
user_clan = (
1392+
await clans_repo.fetch_one(id=player.clan_id)
1393+
if player.clan_id is not None
1394+
else None
1395+
)
13921396
display_name = (
13931397
f"[{user_clan['tag']}] {player.name}"
13941398
if user_clan is not None

app/commands.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,11 @@ async def user(ctx: Context) -> str | None:
867867
else "False"
868868
)
869869

870-
user_clan = await clans_repo.fetch_one(id=player.clan_id)
870+
user_clan = (
871+
await clans_repo.fetch_one(id=player.clan_id)
872+
if player.clan_id is not None
873+
else None
874+
)
871875
display_name = (
872876
f"[{user_clan['tag']}] {player.name}" if user_clan is not None else player.name
873877
)

app/repositories/__init__.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from __future__ import annotations
2+
3+
from sqlalchemy.dialects.mysql.mysqldb import MySQLDialect_mysqldb
4+
from sqlalchemy.orm import DeclarativeMeta
5+
from sqlalchemy.orm import registry
6+
7+
mapper_registry = registry()
8+
9+
10+
class Base(metaclass=DeclarativeMeta):
11+
__abstract__ = True
12+
13+
registry = mapper_registry
14+
metadata = mapper_registry.metadata
15+
16+
__init__ = mapper_registry.constructor
17+
18+
19+
class MySQLDialect(MySQLDialect_mysqldb):
20+
default_paramstyle = "named"
21+
22+
23+
DIALECT = MySQLDialect()

app/repositories/achievements.py

+89-117
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import textwrap
43
from collections.abc import Callable
54
from typing import TYPE_CHECKING
65
from typing import Any
@@ -10,24 +9,45 @@
109
import app.state.services
1110
from app._typing import UNSET
1211
from app._typing import _UnsetSentinel
12+
from app.repositories import DIALECT
13+
from app.repositories import Base
1314

1415
if TYPE_CHECKING:
1516
from app.objects.score import Score
1617

17-
# +-------+--------------+------+-----+---------+----------------+
18-
# | Field | Type | Null | Key | Default | Extra |
19-
# +-------+--------------+------+-----+---------+----------------+
20-
# | id | int | NO | PRI | NULL | auto_increment |
21-
# | file | varchar(128) | NO | UNI | NULL | |
22-
# | name | varchar(128) | NO | UNI | NULL | |
23-
# | desc | varchar(256) | NO | UNI | NULL | |
24-
# | cond | varchar(64) | NO | | NULL | |
25-
# +-------+--------------+------+-----+---------+----------------+
26-
27-
READ_PARAMS = textwrap.dedent(
28-
"""\
29-
id, file, name, `desc`, cond
30-
""",
18+
from sqlalchemy import Column
19+
from sqlalchemy import Index
20+
from sqlalchemy import Integer
21+
from sqlalchemy import String
22+
from sqlalchemy import delete
23+
from sqlalchemy import func
24+
from sqlalchemy import insert
25+
from sqlalchemy import select
26+
from sqlalchemy import update
27+
28+
29+
class AchievementsTable(Base):
30+
__tablename__ = "achievements"
31+
32+
id = Column("id", Integer, primary_key=True)
33+
file = Column("file", String(128), nullable=False)
34+
name = Column("name", String(128, collation="utf8"), nullable=False)
35+
desc = Column("desc", String(256, collation="utf8"), nullable=False)
36+
cond = Column("cond", String(64), nullable=False)
37+
38+
__table_args__ = (
39+
Index("achievements_desc_uindex", desc, unique=True),
40+
Index("achievements_file_uindex", file, unique=True),
41+
Index("achievements_name_uindex", name, unique=True),
42+
)
43+
44+
45+
READ_PARAMS = (
46+
AchievementsTable.id,
47+
AchievementsTable.file,
48+
AchievementsTable.name,
49+
AchievementsTable.desc,
50+
AchievementsTable.cond,
3151
)
3252

3353

@@ -39,41 +59,27 @@ class Achievement(TypedDict):
3959
cond: Callable[[Score, int], bool]
4060

4161

42-
class AchievementUpdateFields(TypedDict, total=False):
43-
file: str
44-
name: str
45-
desc: str
46-
cond: str
47-
48-
4962
async def create(
5063
file: str,
5164
name: str,
5265
desc: str,
5366
cond: str,
5467
) -> Achievement:
5568
"""Create a new achievement."""
56-
query = """\
57-
INSERT INTO achievements (file, name, desc, cond)
58-
VALUES (:file, :name, :desc, :cond)
59-
"""
60-
params: dict[str, Any] = {
61-
"file": file,
62-
"name": name,
63-
"desc": desc,
64-
"cond": cond,
65-
}
66-
rec_id = await app.state.services.database.execute(query, params)
67-
68-
query = f"""\
69-
SELECT {READ_PARAMS}
70-
FROM achievements
71-
WHERE id = :id
72-
"""
73-
params = {
74-
"id": rec_id,
75-
}
76-
rec = await app.state.services.database.fetch_one(query, params)
69+
insert_stmt = insert(AchievementsTable).values(
70+
file=file,
71+
name=name,
72+
desc=desc,
73+
cond=cond,
74+
)
75+
compiled = insert_stmt.compile(dialect=DIALECT)
76+
77+
rec_id = await app.state.services.database.execute(str(compiled), compiled.params)
78+
79+
select_stmt = select(READ_PARAMS).where(AchievementsTable.id == rec_id)
80+
compiled = select_stmt.compile(dialect=DIALECT)
81+
82+
rec = await app.state.services.database.fetch_one(str(compiled), compiled.params)
7783
assert rec is not None
7884

7985
achievement = dict(rec._mapping)
@@ -89,17 +95,15 @@ async def fetch_one(
8995
if id is None and name is None:
9096
raise ValueError("Must provide at least one parameter.")
9197

92-
query = f"""\
93-
SELECT {READ_PARAMS}
94-
FROM achievements
95-
WHERE id = COALESCE(:id, id)
96-
OR name = COALESCE(:name, name)
97-
"""
98-
params: dict[str, Any] = {
99-
"id": id,
100-
"name": name,
101-
}
102-
rec = await app.state.services.database.fetch_one(query, params)
98+
select_stmt = select(READ_PARAMS)
99+
100+
if id is not None:
101+
select_stmt = select_stmt.where(AchievementsTable.id == id)
102+
if name is not None:
103+
select_stmt = select_stmt.where(AchievementsTable.name == name)
104+
105+
compiled = select_stmt.compile(dialect=DIALECT)
106+
rec = await app.state.services.database.fetch_one(str(compiled), compiled.params)
103107

104108
if rec is None:
105109
return None
@@ -111,13 +115,10 @@ async def fetch_one(
111115

112116
async def fetch_count() -> int:
113117
"""Fetch the number of achievements."""
114-
query = """\
115-
SELECT COUNT(*) AS count
116-
FROM achievements
117-
"""
118-
params: dict[str, Any] = {}
118+
select_stmt = select(func.count().label("count")).select_from(AchievementsTable)
119+
compiled = select_stmt.compile(dialect=DIALECT)
119120

120-
rec = await app.state.services.database.fetch_one(query, params)
121+
rec = await app.state.services.database.fetch_one(str(compiled), compiled.params)
121122
assert rec is not None
122123
return cast(int, rec._mapping["count"])
123124

@@ -127,21 +128,16 @@ async def fetch_many(
127128
page_size: int | None = None,
128129
) -> list[Achievement]:
129130
"""Fetch a list of achievements."""
130-
query = f"""\
131-
SELECT {READ_PARAMS}
132-
FROM achievements
133-
"""
134-
params: dict[str, Any] = {}
135-
131+
select_stmt = select(READ_PARAMS)
136132
if page is not None and page_size is not None:
137-
query += """\
138-
LIMIT :limit
139-
OFFSET :offset
140-
"""
141-
params["page_size"] = page_size
142-
params["offset"] = (page - 1) * page_size
133+
select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size)
143134

144-
records = await app.state.services.database.fetch_all(query, params)
135+
compiled = select_stmt.compile(dialect=DIALECT)
136+
137+
records = await app.state.services.database.fetch_all(
138+
str(compiled),
139+
compiled.params,
140+
)
145141

146142
achievements: list[dict[str, Any]] = []
147143

@@ -153,74 +149,50 @@ async def fetch_many(
153149
return cast(list[Achievement], achievements)
154150

155151

156-
async def update(
152+
async def partial_update(
157153
id: int,
158154
file: str | _UnsetSentinel = UNSET,
159155
name: str | _UnsetSentinel = UNSET,
160156
desc: str | _UnsetSentinel = UNSET,
161157
cond: str | _UnsetSentinel = UNSET,
162158
) -> Achievement | None:
163159
"""Update an existing achievement."""
164-
update_fields: AchievementUpdateFields = {}
160+
update_stmt = update(AchievementsTable).where(AchievementsTable.id == id)
165161
if not isinstance(file, _UnsetSentinel):
166-
update_fields["file"] = file
162+
update_stmt = update_stmt.values(file=file)
167163
if not isinstance(name, _UnsetSentinel):
168-
update_fields["name"] = name
164+
update_stmt = update_stmt.values(name=name)
169165
if not isinstance(desc, _UnsetSentinel):
170-
update_fields["desc"] = desc
166+
update_stmt = update_stmt.values(desc=desc)
171167
if not isinstance(cond, _UnsetSentinel):
172-
update_fields["cond"] = cond
173-
174-
query = f"""\
175-
UPDATE achievements
176-
SET {",".join(f"{k} = COALESCE(:{k}, {k})" for k in update_fields)}
177-
WHERE id = :id
178-
"""
179-
params: dict[str, Any] = {
180-
"id": id,
181-
} | update_fields
182-
await app.state.services.database.execute(query, params)
183-
184-
query = f"""\
185-
SELECT {READ_PARAMS}
186-
FROM achievements
187-
WHERE id = :id
188-
"""
189-
params = {
190-
"id": id,
191-
}
192-
rec = await app.state.services.database.fetch_one(query, params)
168+
update_stmt = update_stmt.values(cond=cond)
169+
170+
compiled = update_stmt.compile(dialect=DIALECT)
171+
await app.state.services.database.execute(str(compiled), compiled.params)
172+
173+
select_stmt = select(READ_PARAMS).where(AchievementsTable.id == id)
174+
compiled = select_stmt.compile(dialect=DIALECT)
175+
rec = await app.state.services.database.fetch_one(str(compiled), compiled.params)
193176
assert rec is not None
194177

195178
achievement = dict(rec._mapping)
196179
achievement["cond"] = eval(f'lambda score, mode_vn: {rec["cond"]}')
197180
return cast(Achievement, achievement)
198181

199182

200-
async def delete(
183+
async def delete_one(
201184
id: int,
202185
) -> Achievement | None:
203186
"""Delete an existing achievement."""
204-
query = f"""\
205-
SELECT {READ_PARAMS}
206-
FROM achievements
207-
WHERE id = :id
208-
"""
209-
params: dict[str, Any] = {
210-
"id": id,
211-
}
212-
rec = await app.state.services.database.fetch_one(query, params)
187+
select_stmt = select(READ_PARAMS).where(AchievementsTable.id == id)
188+
compiled = select_stmt.compile(dialect=DIALECT)
189+
rec = await app.state.services.database.fetch_one(str(compiled), compiled.params)
213190
if rec is None:
214191
return None
215192

216-
query = """\
217-
DELETE FROM achievements
218-
WHERE id = :id
219-
"""
220-
params = {
221-
"id": id,
222-
}
223-
await app.state.services.database.execute(query, params)
193+
delete_stmt = delete(AchievementsTable).where(AchievementsTable.id == id)
194+
compiled = delete_stmt.compile(dialect=DIALECT)
195+
await app.state.services.database.execute(str(compiled), compiled.params)
224196

225197
achievement = dict(rec._mapping)
226198
achievement["cond"] = eval(f'lambda score, mode_vn: {rec["cond"]}')

0 commit comments

Comments
 (0)