1
1
from __future__ import annotations
2
2
3
- import textwrap
4
3
from collections .abc import Callable
5
4
from typing import TYPE_CHECKING
6
5
from typing import Any
10
9
import app .state .services
11
10
from app ._typing import UNSET
12
11
from app ._typing import _UnsetSentinel
12
+ from app .repositories import DIALECT
13
+ from app .repositories import Base
13
14
14
15
if TYPE_CHECKING :
15
16
from app .objects .score import Score
16
17
17
- # +-------+--------------+------+-----+---------+----------------+
18
- # | Field | Type | Null | Key | Default | Extra |
19
- # +-------+--------------+------+-----+---------+----------------+
20
- # | id | int | NO | PRI | NULL | auto_increment |
21
- # | file | varchar(128) | NO | UNI | NULL | |
22
- # | name | varchar(128) | NO | UNI | NULL | |
23
- # | desc | varchar(256) | NO | UNI | NULL | |
24
- # | cond | varchar(64) | NO | | NULL | |
25
- # +-------+--------------+------+-----+---------+----------------+
26
-
27
- READ_PARAMS = textwrap .dedent (
28
- """\
29
- id, file, name, `desc`, cond
30
- """ ,
18
+ from sqlalchemy import Column
19
+ from sqlalchemy import Index
20
+ from sqlalchemy import Integer
21
+ from sqlalchemy import String
22
+ from sqlalchemy import delete
23
+ from sqlalchemy import func
24
+ from sqlalchemy import insert
25
+ from sqlalchemy import select
26
+ from sqlalchemy import update
27
+
28
+
29
+ class AchievementsTable (Base ):
30
+ __tablename__ = "achievements"
31
+
32
+ id = Column ("id" , Integer , primary_key = True )
33
+ file = Column ("file" , String (128 ), nullable = False )
34
+ name = Column ("name" , String (128 , collation = "utf8" ), nullable = False )
35
+ desc = Column ("desc" , String (256 , collation = "utf8" ), nullable = False )
36
+ cond = Column ("cond" , String (64 ), nullable = False )
37
+
38
+ __table_args__ = (
39
+ Index ("achievements_desc_uindex" , desc , unique = True ),
40
+ Index ("achievements_file_uindex" , file , unique = True ),
41
+ Index ("achievements_name_uindex" , name , unique = True ),
42
+ )
43
+
44
+
45
+ READ_PARAMS = (
46
+ AchievementsTable .id ,
47
+ AchievementsTable .file ,
48
+ AchievementsTable .name ,
49
+ AchievementsTable .desc ,
50
+ AchievementsTable .cond ,
31
51
)
32
52
33
53
@@ -39,41 +59,27 @@ class Achievement(TypedDict):
39
59
cond : Callable [[Score , int ], bool ]
40
60
41
61
42
- class AchievementUpdateFields (TypedDict , total = False ):
43
- file : str
44
- name : str
45
- desc : str
46
- cond : str
47
-
48
-
49
62
async def create (
50
63
file : str ,
51
64
name : str ,
52
65
desc : str ,
53
66
cond : str ,
54
67
) -> Achievement :
55
68
"""Create a new achievement."""
56
- query = """\
57
- INSERT INTO achievements (file, name, desc, cond)
58
- VALUES (:file, :name, :desc, :cond)
59
- """
60
- params : dict [str , Any ] = {
61
- "file" : file ,
62
- "name" : name ,
63
- "desc" : desc ,
64
- "cond" : cond ,
65
- }
66
- rec_id = await app .state .services .database .execute (query , params )
67
-
68
- query = f"""\
69
- SELECT { READ_PARAMS }
70
- FROM achievements
71
- WHERE id = :id
72
- """
73
- params = {
74
- "id" : rec_id ,
75
- }
76
- rec = await app .state .services .database .fetch_one (query , params )
69
+ insert_stmt = insert (AchievementsTable ).values (
70
+ file = file ,
71
+ name = name ,
72
+ desc = desc ,
73
+ cond = cond ,
74
+ )
75
+ compiled = insert_stmt .compile (dialect = DIALECT )
76
+
77
+ rec_id = await app .state .services .database .execute (str (compiled ), compiled .params )
78
+
79
+ select_stmt = select (READ_PARAMS ).where (AchievementsTable .id == rec_id )
80
+ compiled = select_stmt .compile (dialect = DIALECT )
81
+
82
+ rec = await app .state .services .database .fetch_one (str (compiled ), compiled .params )
77
83
assert rec is not None
78
84
79
85
achievement = dict (rec ._mapping )
@@ -89,17 +95,15 @@ async def fetch_one(
89
95
if id is None and name is None :
90
96
raise ValueError ("Must provide at least one parameter." )
91
97
92
- query = f"""\
93
- SELECT { READ_PARAMS }
94
- FROM achievements
95
- WHERE id = COALESCE(:id, id)
96
- OR name = COALESCE(:name, name)
97
- """
98
- params : dict [str , Any ] = {
99
- "id" : id ,
100
- "name" : name ,
101
- }
102
- rec = await app .state .services .database .fetch_one (query , params )
98
+ select_stmt = select (READ_PARAMS )
99
+
100
+ if id is not None :
101
+ select_stmt = select_stmt .where (AchievementsTable .id == id )
102
+ if name is not None :
103
+ select_stmt = select_stmt .where (AchievementsTable .name == name )
104
+
105
+ compiled = select_stmt .compile (dialect = DIALECT )
106
+ rec = await app .state .services .database .fetch_one (str (compiled ), compiled .params )
103
107
104
108
if rec is None :
105
109
return None
@@ -111,13 +115,10 @@ async def fetch_one(
111
115
112
116
async def fetch_count () -> int :
113
117
"""Fetch the number of achievements."""
114
- query = """\
115
- SELECT COUNT(*) AS count
116
- FROM achievements
117
- """
118
- params : dict [str , Any ] = {}
118
+ select_stmt = select (func .count ().label ("count" )).select_from (AchievementsTable )
119
+ compiled = select_stmt .compile (dialect = DIALECT )
119
120
120
- rec = await app .state .services .database .fetch_one (query , params )
121
+ rec = await app .state .services .database .fetch_one (str ( compiled ), compiled . params )
121
122
assert rec is not None
122
123
return cast (int , rec ._mapping ["count" ])
123
124
@@ -127,21 +128,16 @@ async def fetch_many(
127
128
page_size : int | None = None ,
128
129
) -> list [Achievement ]:
129
130
"""Fetch a list of achievements."""
130
- query = f"""\
131
- SELECT { READ_PARAMS }
132
- FROM achievements
133
- """
134
- params : dict [str , Any ] = {}
135
-
131
+ select_stmt = select (READ_PARAMS )
136
132
if page is not None and page_size is not None :
137
- query += """\
138
- LIMIT :limit
139
- OFFSET :offset
140
- """
141
- params ["page_size" ] = page_size
142
- params ["offset" ] = (page - 1 ) * page_size
133
+ select_stmt = select_stmt .limit (page_size ).offset ((page - 1 ) * page_size )
143
134
144
- records = await app .state .services .database .fetch_all (query , params )
135
+ compiled = select_stmt .compile (dialect = DIALECT )
136
+
137
+ records = await app .state .services .database .fetch_all (
138
+ str (compiled ),
139
+ compiled .params ,
140
+ )
145
141
146
142
achievements : list [dict [str , Any ]] = []
147
143
@@ -153,74 +149,50 @@ async def fetch_many(
153
149
return cast (list [Achievement ], achievements )
154
150
155
151
156
- async def update (
152
+ async def partial_update (
157
153
id : int ,
158
154
file : str | _UnsetSentinel = UNSET ,
159
155
name : str | _UnsetSentinel = UNSET ,
160
156
desc : str | _UnsetSentinel = UNSET ,
161
157
cond : str | _UnsetSentinel = UNSET ,
162
158
) -> Achievement | None :
163
159
"""Update an existing achievement."""
164
- update_fields : AchievementUpdateFields = {}
160
+ update_stmt = update ( AchievementsTable ). where ( AchievementsTable . id == id )
165
161
if not isinstance (file , _UnsetSentinel ):
166
- update_fields [ "file" ] = file
162
+ update_stmt = update_stmt . values ( file = file )
167
163
if not isinstance (name , _UnsetSentinel ):
168
- update_fields [ "name" ] = name
164
+ update_stmt = update_stmt . values ( name = name )
169
165
if not isinstance (desc , _UnsetSentinel ):
170
- update_fields [ "desc" ] = desc
166
+ update_stmt = update_stmt . values ( desc = desc )
171
167
if not isinstance (cond , _UnsetSentinel ):
172
- update_fields ["cond" ] = cond
173
-
174
- query = f"""\
175
- UPDATE achievements
176
- SET { "," .join (f"{ k } = COALESCE(:{ k } , { k } )" for k in update_fields )}
177
- WHERE id = :id
178
- """
179
- params : dict [str , Any ] = {
180
- "id" : id ,
181
- } | update_fields
182
- await app .state .services .database .execute (query , params )
183
-
184
- query = f"""\
185
- SELECT { READ_PARAMS }
186
- FROM achievements
187
- WHERE id = :id
188
- """
189
- params = {
190
- "id" : id ,
191
- }
192
- rec = await app .state .services .database .fetch_one (query , params )
168
+ update_stmt = update_stmt .values (cond = cond )
169
+
170
+ compiled = update_stmt .compile (dialect = DIALECT )
171
+ await app .state .services .database .execute (str (compiled ), compiled .params )
172
+
173
+ select_stmt = select (READ_PARAMS ).where (AchievementsTable .id == id )
174
+ compiled = select_stmt .compile (dialect = DIALECT )
175
+ rec = await app .state .services .database .fetch_one (str (compiled ), compiled .params )
193
176
assert rec is not None
194
177
195
178
achievement = dict (rec ._mapping )
196
179
achievement ["cond" ] = eval (f'lambda score, mode_vn: { rec ["cond" ]} ' )
197
180
return cast (Achievement , achievement )
198
181
199
182
200
- async def delete (
183
+ async def delete_one (
201
184
id : int ,
202
185
) -> Achievement | None :
203
186
"""Delete an existing achievement."""
204
- query = f"""\
205
- SELECT { READ_PARAMS }
206
- FROM achievements
207
- WHERE id = :id
208
- """
209
- params : dict [str , Any ] = {
210
- "id" : id ,
211
- }
212
- rec = await app .state .services .database .fetch_one (query , params )
187
+ select_stmt = select (READ_PARAMS ).where (AchievementsTable .id == id )
188
+ compiled = select_stmt .compile (dialect = DIALECT )
189
+ rec = await app .state .services .database .fetch_one (str (compiled ), compiled .params )
213
190
if rec is None :
214
191
return None
215
192
216
- query = """\
217
- DELETE FROM achievements
218
- WHERE id = :id
219
- """
220
- params = {
221
- "id" : id ,
222
- }
223
- await app .state .services .database .execute (query , params )
193
+ delete_stmt = delete (AchievementsTable ).where (AchievementsTable .id == id )
194
+ compiled = delete_stmt .compile (dialect = DIALECT )
195
+ await app .state .services .database .execute (str (compiled ), compiled .params )
224
196
225
197
achievement = dict (rec ._mapping )
226
198
achievement ["cond" ] = eval (f'lambda score, mode_vn: { rec ["cond" ]} ' )
0 commit comments