@@ -19,33 +19,43 @@ class CRUDPlus(Generic[_Model]):
19
19
def __init__ (self , model : Type [_Model ]):
20
20
self .model = model
21
21
22
- async def create_model (self , session : AsyncSession , obj : _CreateSchema , ** kwargs ) -> None :
22
+ async def create_model (self , session : AsyncSession , obj : _CreateSchema , commit : bool = False , ** kwargs ) -> _Model :
23
23
"""
24
24
Create a new instance of a model
25
25
26
26
:param session:
27
27
:param obj:
28
+ :param commit:
28
29
:param kwargs:
29
30
:return:
30
31
"""
31
32
if kwargs :
32
- instance = self .model (** obj .model_dump (), ** kwargs )
33
+ ins = self .model (** obj .model_dump (), ** kwargs )
33
34
else :
34
- instance = self .model (** obj .model_dump ())
35
- session .add (instance )
35
+ ins = self .model (** obj .model_dump ())
36
+ session .add (ins )
37
+ if commit :
38
+ await session .commit ()
39
+ return ins
36
40
37
- async def create_models (self , session : AsyncSession , obj : Iterable [_CreateSchema ]) -> None :
41
+ async def create_models (
42
+ self , session : AsyncSession , obj : Iterable [_CreateSchema ], commit : bool = False
43
+ ) -> list [_Model ]:
38
44
"""
39
45
Create new instances of a model
40
46
41
47
:param session:
42
48
:param obj:
49
+ :param commit:
43
50
:return:
44
51
"""
45
- instance_list = []
52
+ ins_list = []
46
53
for i in obj :
47
- instance_list .append (self .model (** i .model_dump ()))
48
- session .add_all (instance_list )
54
+ ins_list .append (self .model (** i .model_dump ()))
55
+ session .add_all (ins_list )
56
+ if commit :
57
+ await session .commit ()
58
+ return ins_list
49
59
50
60
async def select_model_by_id (self , session : AsyncSession , pk : int ) -> _Model | None :
51
61
"""
@@ -55,7 +65,8 @@ async def select_model_by_id(self, session: AsyncSession, pk: int) -> _Model | N
55
65
:param pk:
56
66
:return:
57
67
"""
58
- query = await session .execute (select (self .model ).where (self .model .id == pk ))
68
+ stmt = select (self .model ).where (self .model .id == pk )
69
+ query = await session .execute (stmt )
59
70
return query .scalars ().first ()
60
71
61
72
async def select_model_by_column (self , session : AsyncSession , column : str , column_value : Any ) -> _Model | None :
@@ -69,10 +80,11 @@ async def select_model_by_column(self, session: AsyncSession, column: str, colum
69
80
"""
70
81
if hasattr (self .model , column ):
71
82
model_column = getattr (self .model , column )
72
- query = await session .execute (select (self .model ).where (model_column == column_value )) # type: ignore
83
+ stmt = select (self .model ).where (model_column == column_value ) # type: ignore
84
+ query = await session .execute (stmt )
73
85
return query .scalars ().first ()
74
86
else :
75
- raise ModelColumnError (f'Model column { column } is not found' )
87
+ raise ModelColumnError (f'Column { column } is not found in { self . model } ' )
76
88
77
89
async def select_model_by_columns (
78
90
self , session : AsyncSession , expression : Literal ['and' , 'or' ] = 'and' , ** conditions
@@ -91,31 +103,36 @@ async def select_model_by_columns(
91
103
model_column = getattr (self .model , column )
92
104
where_list .append (model_column == value )
93
105
else :
94
- raise ModelColumnError (f'Model column { column } is not found' )
106
+ raise ModelColumnError (f'Column { column } is not found in { self . model } ' )
95
107
match expression :
96
108
case 'and' :
97
- query = await session .execute (select (self .model ).where (and_ (* where_list )))
109
+ stmt = select (self .model ).where (and_ (* where_list ))
110
+ query = await session .execute (stmt )
98
111
case 'or' :
99
- query = await session .execute (select (self .model ).where (or_ (* where_list )))
112
+ stmt = select (self .model ).where (or_ (* where_list ))
113
+ query = await session .execute (stmt )
100
114
case _:
101
- raise SelectExpressionError (f'select expression { expression } is not supported' )
115
+ raise SelectExpressionError (
116
+ f'Select expression { expression } is not supported, only supports `and`, `or`'
117
+ )
102
118
return query .scalars ().first ()
103
119
104
- async def select_models (self , session : AsyncSession ) -> Sequence [Row | RowMapping | Any ] | None :
120
+ async def select_models (self , session : AsyncSession ) -> Sequence [Row [ Any ] | RowMapping | Any ]:
105
121
"""
106
122
Query all rows
107
123
108
124
:param session:
109
125
:return:
110
126
"""
111
- query = await session .execute (select (self .model ))
127
+ stmt = select (self .model )
128
+ query = await session .execute (stmt )
112
129
return query .scalars ().all ()
113
130
114
131
async def select_models_order (
115
132
self ,
116
133
session : AsyncSession ,
117
134
* columns ,
118
- model_sort : Literal ['default' , ' asc' , 'desc' ] = 'default ' ,
135
+ model_sort : Literal ['asc' , 'desc' ] = 'desc ' ,
119
136
) -> Sequence [Row | RowMapping | Any ] | None :
120
137
"""
121
138
Query all rows asc or desc
@@ -131,25 +148,28 @@ async def select_models_order(
131
148
model_column = getattr (self .model , column )
132
149
sort_list .append (model_column )
133
150
else :
134
- raise ModelColumnError (f'Model column { column } is not found' )
151
+ raise ModelColumnError (f'Column { column } is not found in { self . model } ' )
135
152
match model_sort :
136
- case 'default' :
137
- query = await session .execute (select (self .model ).order_by (* sort_list ))
138
153
case 'asc' :
139
154
query = await session .execute (select (self .model ).order_by (asc (* sort_list )))
140
155
case 'desc' :
141
156
query = await session .execute (select (self .model ).order_by (desc (* sort_list )))
142
157
case _:
143
- raise SelectExpressionError (f'select sort expression { model_sort } is not supported' )
158
+ raise SelectExpressionError (
159
+ f'Select sort expression { model_sort } is not supported, only supports `asc`, `desc`'
160
+ )
144
161
return query .scalars ().all ()
145
162
146
- async def update_model (self , session : AsyncSession , pk : int , obj : _UpdateSchema | dict [str , Any ], ** kwargs ) -> int :
163
+ async def update_model (
164
+ self , session : AsyncSession , pk : int , obj : _UpdateSchema | dict [str , Any ], commit : bool = False , ** kwargs
165
+ ) -> int :
147
166
"""
148
167
Update an instance of model's primary key
149
168
150
169
:param session:
151
170
:param pk:
152
171
:param obj:
172
+ :param commit:
153
173
:param kwargs:
154
174
:return:
155
175
"""
@@ -159,11 +179,20 @@ async def update_model(self, session: AsyncSession, pk: int, obj: _UpdateSchema
159
179
instance_data = obj .model_dump (exclude_unset = True )
160
180
if kwargs :
161
181
instance_data .update (kwargs )
162
- result = await session .execute (sa_update (self .model ).where (self .model .id == pk ).values (** instance_data ))
182
+ stmt = sa_update (self .model ).where (self .model .id == pk ).values (** instance_data )
183
+ result = await session .execute (stmt )
184
+ if commit :
185
+ await session .commit ()
163
186
return result .rowcount # type: ignore
164
187
165
188
async def update_model_by_column (
166
- self , session : AsyncSession , column : str , column_value : Any , obj : _UpdateSchema | dict [str , Any ], ** kwargs
189
+ self ,
190
+ session : AsyncSession ,
191
+ column : str ,
192
+ column_value : Any ,
193
+ obj : _UpdateSchema | dict [str , Any ],
194
+ commit : bool = False ,
195
+ ** kwargs ,
167
196
) -> int :
168
197
"""
169
198
Update an instance of model column
@@ -172,6 +201,7 @@ async def update_model_by_column(
172
201
:param column:
173
202
:param column_value:
174
203
:param obj:
204
+ :param commit:
175
205
:param kwargs:
176
206
:return:
177
207
"""
@@ -184,23 +214,29 @@ async def update_model_by_column(
184
214
if hasattr (self .model , column ):
185
215
model_column = getattr (self .model , column )
186
216
else :
187
- raise ModelColumnError (f'Model column { column } is not found' )
188
- result = await session .execute (
189
- sa_update (self .model ).where (model_column == column_value ).values (** instance_data )
190
- )
217
+ raise ModelColumnError (f'Column { column } is not found in { self .model } ' )
218
+ stmt = sa_update (self .model ).where (model_column == column_value ).values (** instance_data ) # type: ignore
219
+ result = await session .execute (stmt )
220
+ if commit :
221
+ await session .commit ()
191
222
return result .rowcount # type: ignore
192
223
193
- async def delete_model (self , session : AsyncSession , pk : int , ** kwargs ) -> int :
224
+ async def delete_model (self , session : AsyncSession , pk : int , commit : bool = False , ** kwargs ) -> int :
194
225
"""
195
226
Delete an instance of a model
196
227
197
228
:param session:
198
229
:param pk:
230
+ :param commit:
199
231
:param kwargs: for soft deletion only
200
232
:return:
201
233
"""
202
234
if not kwargs :
203
- result = await session .execute (sa_delete (self .model ).where (self .model .id == pk ))
235
+ stmt = sa_delete (self .model ).where (self .model .id == pk )
236
+ result = await session .execute (stmt )
204
237
else :
205
- result = await session .execute (sa_update (self .model ).where (self .model .id == pk ).values (** kwargs ))
238
+ stmt = sa_update (self .model ).where (self .model .id == pk ).values (** kwargs )
239
+ result = await session .execute (stmt )
240
+ if commit :
241
+ await session .commit ()
206
242
return result .rowcount # type: ignore
0 commit comments