|
| 1 | +from typing import Optional, TypeVar, Generic |
| 2 | +from pydantic import BaseModel |
| 3 | + |
| 4 | +T = TypeVar('T', bound=BaseModel) |
| 5 | + |
| 6 | + |
| 7 | +class AbstractRepository(Generic[T]): |
| 8 | + def __init__(self, database): |
| 9 | + super().__init__() |
| 10 | + self.__database = database |
| 11 | + self.__document_class = self.__orig_bases__[0].__args__[0] |
| 12 | + self.__collection_name = self.Meta.collection_name |
| 13 | + self.__validate() |
| 14 | + |
| 15 | + def get_collection(self): |
| 16 | + return self.__database[self.__collection_name] |
| 17 | + |
| 18 | + def __validate(self): |
| 19 | + if not issubclass(self.__document_class, BaseModel): |
| 20 | + raise Exception('Document class should inherit BaseModel') |
| 21 | + if 'id' not in self.__document_class.__fields__: |
| 22 | + raise Exception('Document class should have id field') |
| 23 | + if not self.__collection_name: |
| 24 | + raise Exception('Meta should contain collection name') |
| 25 | + |
| 26 | + def to_document(self, model: T) -> dict: |
| 27 | + result = model.dict() |
| 28 | + result.pop('id') |
| 29 | + if model.id: |
| 30 | + result['_id'] = model.id |
| 31 | + return result |
| 32 | + |
| 33 | + def __to_query(self, data: dict): |
| 34 | + query = data.copy() |
| 35 | + if 'id' in data: |
| 36 | + query['_id'] = query.pop('id') |
| 37 | + return query |
| 38 | + |
| 39 | + def to_model(self, data: dict) -> T: |
| 40 | + data_copy = data.copy() |
| 41 | + if '_id' in data_copy: |
| 42 | + data_copy['id'] = data_copy.pop('_id') |
| 43 | + return self.__document_class.parse_obj(data_copy) |
| 44 | + |
| 45 | + def save(self, model: T): |
| 46 | + document = self.to_document(model) |
| 47 | + |
| 48 | + if model.id: |
| 49 | + mongo_id = document.pop('_id') |
| 50 | + self.get_collection().update_one({'_id': mongo_id}, {'$set': document}) |
| 51 | + return |
| 52 | + |
| 53 | + result = self.get_collection().insert_one(document) |
| 54 | + model.id = result.inserted_id |
| 55 | + return result |
| 56 | + |
| 57 | + def delete(self, model: T): |
| 58 | + return self.get_collection().delete_one({'_id': model.id}) |
| 59 | + |
| 60 | + def find_one_by_id(self, id: str) -> Optional[T]: |
| 61 | + return self.find_one_by({'id': id}) |
| 62 | + |
| 63 | + def find_one_by(self, query: dict) -> Optional[T]: |
| 64 | + result = self.get_collection().find_one(self.__to_query(query)) |
| 65 | + return self.to_model(result) if result else None |
| 66 | + |
| 67 | + def find_by( |
| 68 | + self, |
| 69 | + query: dict, |
| 70 | + skip: Optional[int] = None, |
| 71 | + limit: Optional[int] = None, |
| 72 | + sort=None |
| 73 | + ): |
| 74 | + cursor = self.get_collection().find(self.__to_query(query)) |
| 75 | + if limit: |
| 76 | + cursor.limit(limit) |
| 77 | + if skip: |
| 78 | + cursor.skip(skip) |
| 79 | + if sort: |
| 80 | + cursor.sort(sort) |
| 81 | + return map(self.to_model, cursor) |
0 commit comments