Skip to content

Commit 9e6b98d

Browse files
authored
Add commit option to CRUD operations (#8)
1 parent ddd24c4 commit 9e6b98d

File tree

3 files changed

+72
-42
lines changed

3 files changed

+72
-42
lines changed

README.md

+4-9
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,17 @@
11
# sqlalchemy-crud-plus
22

3-
基于 SQLAlChemy2 模型的异步 CRUD 操作
3+
Asynchronous CRUD operations based on SQLAlChemy 2.0
44

5-
## 下载
5+
## Download
66

77
```shell
88
pip install sqlalchemy-crud-plus
99
```
1010

11-
## TODO
12-
13-
- [ ] ...
14-
1511
## Use
1612

17-
以下仅为简易示例
18-
1913
```python
14+
# example:
2015
from sqlalchemy.orm import declarative_base
2116
from sqlalchemy_crud_plus import CRUDPlus
2217

@@ -34,7 +29,7 @@ class CRUDIns(CRUDPlus[ModelIns]):
3429

3530

3631
# singleton
37-
ins_dao = CRUDIns(ModelIns)
32+
ins_dao: CRUDIns = CRUDIns(ModelIns)
3833
```
3934

4035
## 互动

pyproject.toml

-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ select = [
4949
"I"
5050
]
5151
preview = true
52-
ignore-init-module-imports = true
5352

5453
[tool.ruff.lint.isort]
5554
lines-between-types = 1

sqlalchemy_crud_plus/crud.py

+68-32
Original file line numberDiff line numberDiff line change
@@ -19,33 +19,43 @@ class CRUDPlus(Generic[_Model]):
1919
def __init__(self, model: Type[_Model]):
2020
self.model = model
2121

22-
async def create_model(self, session: AsyncSession, obj: _CreateSchema, **kwargs) -> None:
22+
async def create_model(self, session: AsyncSession, obj: _CreateSchema, commit: bool = False, **kwargs) -> _Model:
2323
"""
2424
Create a new instance of a model
2525
2626
:param session:
2727
:param obj:
28+
:param commit:
2829
:param kwargs:
2930
:return:
3031
"""
3132
if kwargs:
32-
instance = self.model(**obj.model_dump(), **kwargs)
33+
ins = self.model(**obj.model_dump(), **kwargs)
3334
else:
34-
instance = self.model(**obj.model_dump())
35-
session.add(instance)
35+
ins = self.model(**obj.model_dump())
36+
session.add(ins)
37+
if commit:
38+
await session.commit()
39+
return ins
3640

37-
async def create_models(self, session: AsyncSession, obj: Iterable[_CreateSchema]) -> None:
41+
async def create_models(
42+
self, session: AsyncSession, obj: Iterable[_CreateSchema], commit: bool = False
43+
) -> list[_Model]:
3844
"""
3945
Create new instances of a model
4046
4147
:param session:
4248
:param obj:
49+
:param commit:
4350
:return:
4451
"""
45-
instance_list = []
52+
ins_list = []
4653
for i in obj:
47-
instance_list.append(self.model(**i.model_dump()))
48-
session.add_all(instance_list)
54+
ins_list.append(self.model(**i.model_dump()))
55+
session.add_all(ins_list)
56+
if commit:
57+
await session.commit()
58+
return ins_list
4959

5060
async def select_model_by_id(self, session: AsyncSession, pk: int) -> _Model | None:
5161
"""
@@ -55,7 +65,8 @@ async def select_model_by_id(self, session: AsyncSession, pk: int) -> _Model | N
5565
:param pk:
5666
:return:
5767
"""
58-
query = await session.execute(select(self.model).where(self.model.id == pk))
68+
stmt = select(self.model).where(self.model.id == pk)
69+
query = await session.execute(stmt)
5970
return query.scalars().first()
6071

6172
async def select_model_by_column(self, session: AsyncSession, column: str, column_value: Any) -> _Model | None:
@@ -69,10 +80,11 @@ async def select_model_by_column(self, session: AsyncSession, column: str, colum
6980
"""
7081
if hasattr(self.model, column):
7182
model_column = getattr(self.model, column)
72-
query = await session.execute(select(self.model).where(model_column == column_value)) # type: ignore
83+
stmt = select(self.model).where(model_column == column_value) # type: ignore
84+
query = await session.execute(stmt)
7385
return query.scalars().first()
7486
else:
75-
raise ModelColumnError(f'Model column {column} is not found')
87+
raise ModelColumnError(f'Column {column} is not found in {self.model}')
7688

7789
async def select_model_by_columns(
7890
self, session: AsyncSession, expression: Literal['and', 'or'] = 'and', **conditions
@@ -91,31 +103,36 @@ async def select_model_by_columns(
91103
model_column = getattr(self.model, column)
92104
where_list.append(model_column == value)
93105
else:
94-
raise ModelColumnError(f'Model column {column} is not found')
106+
raise ModelColumnError(f'Column {column} is not found in {self.model}')
95107
match expression:
96108
case 'and':
97-
query = await session.execute(select(self.model).where(and_(*where_list)))
109+
stmt = select(self.model).where(and_(*where_list))
110+
query = await session.execute(stmt)
98111
case 'or':
99-
query = await session.execute(select(self.model).where(or_(*where_list)))
112+
stmt = select(self.model).where(or_(*where_list))
113+
query = await session.execute(stmt)
100114
case _:
101-
raise SelectExpressionError(f'select expression {expression} is not supported')
115+
raise SelectExpressionError(
116+
f'Select expression {expression} is not supported, only supports `and`, `or`'
117+
)
102118
return query.scalars().first()
103119

104-
async def select_models(self, session: AsyncSession) -> Sequence[Row | RowMapping | Any] | None:
120+
async def select_models(self, session: AsyncSession) -> Sequence[Row[Any] | RowMapping | Any]:
105121
"""
106122
Query all rows
107123
108124
:param session:
109125
:return:
110126
"""
111-
query = await session.execute(select(self.model))
127+
stmt = select(self.model)
128+
query = await session.execute(stmt)
112129
return query.scalars().all()
113130

114131
async def select_models_order(
115132
self,
116133
session: AsyncSession,
117134
*columns,
118-
model_sort: Literal['default', 'asc', 'desc'] = 'default',
135+
model_sort: Literal['asc', 'desc'] = 'desc',
119136
) -> Sequence[Row | RowMapping | Any] | None:
120137
"""
121138
Query all rows asc or desc
@@ -131,25 +148,28 @@ async def select_models_order(
131148
model_column = getattr(self.model, column)
132149
sort_list.append(model_column)
133150
else:
134-
raise ModelColumnError(f'Model column {column} is not found')
151+
raise ModelColumnError(f'Column {column} is not found in {self.model}')
135152
match model_sort:
136-
case 'default':
137-
query = await session.execute(select(self.model).order_by(*sort_list))
138153
case 'asc':
139154
query = await session.execute(select(self.model).order_by(asc(*sort_list)))
140155
case 'desc':
141156
query = await session.execute(select(self.model).order_by(desc(*sort_list)))
142157
case _:
143-
raise SelectExpressionError(f'select sort expression {model_sort} is not supported')
158+
raise SelectExpressionError(
159+
f'Select sort expression {model_sort} is not supported, only supports `asc`, `desc`'
160+
)
144161
return query.scalars().all()
145162

146-
async def update_model(self, session: AsyncSession, pk: int, obj: _UpdateSchema | dict[str, Any], **kwargs) -> int:
163+
async def update_model(
164+
self, session: AsyncSession, pk: int, obj: _UpdateSchema | dict[str, Any], commit: bool = False, **kwargs
165+
) -> int:
147166
"""
148167
Update an instance of model's primary key
149168
150169
:param session:
151170
:param pk:
152171
:param obj:
172+
:param commit:
153173
:param kwargs:
154174
:return:
155175
"""
@@ -159,11 +179,20 @@ async def update_model(self, session: AsyncSession, pk: int, obj: _UpdateSchema
159179
instance_data = obj.model_dump(exclude_unset=True)
160180
if kwargs:
161181
instance_data.update(kwargs)
162-
result = await session.execute(sa_update(self.model).where(self.model.id == pk).values(**instance_data))
182+
stmt = sa_update(self.model).where(self.model.id == pk).values(**instance_data)
183+
result = await session.execute(stmt)
184+
if commit:
185+
await session.commit()
163186
return result.rowcount # type: ignore
164187

165188
async def update_model_by_column(
166-
self, session: AsyncSession, column: str, column_value: Any, obj: _UpdateSchema | dict[str, Any], **kwargs
189+
self,
190+
session: AsyncSession,
191+
column: str,
192+
column_value: Any,
193+
obj: _UpdateSchema | dict[str, Any],
194+
commit: bool = False,
195+
**kwargs,
167196
) -> int:
168197
"""
169198
Update an instance of model column
@@ -172,6 +201,7 @@ async def update_model_by_column(
172201
:param column:
173202
:param column_value:
174203
:param obj:
204+
:param commit:
175205
:param kwargs:
176206
:return:
177207
"""
@@ -184,23 +214,29 @@ async def update_model_by_column(
184214
if hasattr(self.model, column):
185215
model_column = getattr(self.model, column)
186216
else:
187-
raise ModelColumnError(f'Model column {column} is not found')
188-
result = await session.execute(
189-
sa_update(self.model).where(model_column == column_value).values(**instance_data)
190-
)
217+
raise ModelColumnError(f'Column {column} is not found in {self.model}')
218+
stmt = sa_update(self.model).where(model_column == column_value).values(**instance_data) # type: ignore
219+
result = await session.execute(stmt)
220+
if commit:
221+
await session.commit()
191222
return result.rowcount # type: ignore
192223

193-
async def delete_model(self, session: AsyncSession, pk: int, **kwargs) -> int:
224+
async def delete_model(self, session: AsyncSession, pk: int, commit: bool = False, **kwargs) -> int:
194225
"""
195226
Delete an instance of a model
196227
197228
:param session:
198229
:param pk:
230+
:param commit:
199231
:param kwargs: for soft deletion only
200232
:return:
201233
"""
202234
if not kwargs:
203-
result = await session.execute(sa_delete(self.model).where(self.model.id == pk))
235+
stmt = sa_delete(self.model).where(self.model.id == pk)
236+
result = await session.execute(stmt)
204237
else:
205-
result = await session.execute(sa_update(self.model).where(self.model.id == pk).values(**kwargs))
238+
stmt = sa_update(self.model).where(self.model.id == pk).values(**kwargs)
239+
result = await session.execute(stmt)
240+
if commit:
241+
await session.commit()
206242
return result.rowcount # type: ignore

0 commit comments

Comments
 (0)