diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a29abf81..9044cb9f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,6 +44,6 @@ repos: - id: add-trailing-comma # Flake8 to check style is OK - repo: https://gitlab.com/pycqa/flake8 - rev: 3.8.4 + rev: 3.9.0 hooks: - id: flake8 diff --git a/app/database/crud/__init__.py b/app/database/crud/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/database/crud/crud.py b/app/database/crud/crud.py new file mode 100644 index 00000000..b6cc45d1 --- /dev/null +++ b/app/database/crud/crud.py @@ -0,0 +1,242 @@ +"""Low level CRUD functions, generalised for wider usage. + +Functions listed here should be accessed only from other CRUD modules, +and not directly from the app. +""" +from typing import Any, List, Optional, Type, Union + +from pydantic import BaseModel +from sqlalchemy.exc import ( + IntegrityError, + InvalidRequestError, + OperationalError, + SQLAlchemyError, + StatementError, +) +from sqlalchemy.orm import Session +from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.orm.exc import UnmappedInstanceError + +from app.database.models_v2 import Base +from app.dependencies import logger + + +def insert(session: Session, instance: Base) -> bool: + """Inserts a new row into the database. + + Args: + session: The database connection. + instance: The object to save. + + Returns: + True if successful, otherwise returns False. + + Raises: + SQLAlchemyError: If the database tables were not created. + + """ + if issubclass(instance.__class__, Base): + try: + session.add(instance) + session.commit() + session.refresh(instance) + return True + except IntegrityError as e: + logger.exception(e) + return False + except OperationalError as e: + logger.exception(e) + raise SQLAlchemyError("Database tables were not created yet.") + return False + + +def delete(session: Session, instance: Base) -> bool: + """Deletes a row from the database using the database model. + + Args: + session: The database connection. + instance: The object to delete. + + Returns: + True if successful, otherwise returns False. + """ + return delete_multiple(session, [instance]) + + +def delete_multiple(session: Session, instances: List[Base]) -> bool: + """Deletes a multiple rows from the database using the database models. + + Args: + session: The database connection. + instances: A list of objects to delete. + + Returns: + True if successful, otherwise returns False. + """ + try: + for instance in instances: + session.delete(instance) + session.commit() + return True + except InvalidRequestError: + return False + except UnmappedInstanceError: + return False + + +def get_by_id( + session: Session, + entity_id: int, + orm_class: Type[Base], +) -> Optional[Union[BaseModel, Base]]: + """Returns a schema or database model by an ID. + + Args: + session: The database connection. + entity_id: The entity's ID. + orm_class: The database mapped model class. + + Returns: + A BaseModel or Base model, as requested, if successful, + otherwise returns None. + """ + keywords = {orm_class.id.key: entity_id} + return get_database_model_by_parameter(session, orm_class, **keywords) + + +def get_database_model_by_parameter( + session: Session, + orm_class: Type[Base], + **kwargs: Any, +) -> Optional[Union[BaseModel, Base]]: + """Returns a schema or database model by a parameter. + + Args: + session: The database connection. + orm_class: The database mapped model class. + **kwargs: The parameter to filter by. + Must be in the format of: key=value. + + Returns: + A BaseModel or Base model, as requested, if successful, + otherwise returns None. + """ + try: + return session.query(orm_class).filter_by(**kwargs).first() + except OperationalError as e: + logger.exception(e) + return None + + +def get_all_database_models( + session: Session, + orm_class: Type[Base], + skip: int = 0, + limit: int = 100, +) -> List[Base]: + """Returns all models from the database. + + Args: + session: The database connection. + orm_class: The database mapped model class. + skip: The starting index. + Defaults to 0. + limit: The amount of returned items. + Defaults to 100. + + Returns: + A list database models. + """ + try: + return session.query(orm_class).offset(skip).limit(limit).all() + except OperationalError as e: + logger.exception(e) + return [] + + +def get_property( + session: Session, + entity_id: int, + column: InstrumentedAttribute, +) -> Optional[Any]: + """Returns the value of an entity's property. + + Args: + session: The database connection. + entity_id: The entity's ID. + column: The database column from where to query the data. + + Returns: + The value of the entity's database column. + """ + orm_model = get_by_id(session, entity_id, column.class_) + if not orm_model: + return None + + return getattr(orm_model, column.key) + + +def set_property( + session: Session, + entity_id: int, + column: InstrumentedAttribute, + value: Any, +) -> bool: + """Sets a new value for an entity's property. + + Args: + session: The database connection. + entity_id: The entity's ID. + column: The database column to where the data is saved. + value: The new value to set. + + Returns: + True if successful, otherwise returns False. + """ + orm_model = get_by_id(session, entity_id, column.class_) + if not orm_model: + return False + + setattr(orm_model, column.key, value) + try: + session.commit() + except IntegrityError: + session.rollback() + return False + except StatementError: + session.rollback() + return False + return True + + +def update_database_by_schema_model( + session: Session, + entity_id: int, + schema_instance: BaseModel, + orm_class: Type[Base], +) -> bool: + """Updates the database model by extracting data from the schema object. + + ID is passed as a separate parameter for instances where an ID is named + something other than "id". + + Args: + session: The database connection. + entity_id: The entity's ID. + schema_instance: The schema model whose data is used for the update. + orm_class: The database mapped model class. + + Returns: + True if successful, otherwise returns False. + """ + id_filter = {orm_class.id.key: entity_id} + try: + ( + session.query(orm_class) + .filter_by(**id_filter) + .update(schema_instance.dict()) + ) + session.commit() + return True + except InvalidRequestError: + return False diff --git a/app/database/crud/event.py b/app/database/crud/event.py new file mode 100644 index 00000000..25546052 --- /dev/null +++ b/app/database/crud/event.py @@ -0,0 +1,452 @@ +"""CRUD functions for the Event model.""" +from typing import List, Optional, Union + +from pydantic import parse_obj_as +from sqlalchemy.orm import Session + +from app.database.crud import crud +from app.database.crud import user as user_crud +from app.database.models_v2 import Event as EventOrm +from app.database.models_v2 import User as UserOrm +from app.database.schemas_v2 import Event, EventAll, EventCreate, User +from app.internal.privacy import PrivacyKinds + + +def create(session: Session, event: EventCreate) -> Optional[Event]: + """Returns an Event after creating and saving a model in the database. + + Args: + session: The database connection. + event: The created event data. + + Returns: + The created Event if successful, otherwise returns None. + """ + event_orm = EventOrm(**event.dict()) + event_created = crud.insert(session, event_orm) + if not event_created: + return None + + event_orm.members.append(event_orm.owner) + session.commit() + return Event.from_orm(event_orm) + + +def delete(session: Session, event: Event) -> bool: + """Deletes an Event from the database. + + Args: + session: The database connection. + event: The Event to delete. + + Returns: + True if successful, otherwise returns False. + """ + event_orm = _get_by_id(session, event.id, False) + if event_orm: + return crud.delete(session, event_orm) + return False + + +def get_database_event_by_id( + session: Session, + event_id: int, +) -> Optional[EventOrm]: + """Returns an Event database model by an ID. + + Args: + session: The database connection. + event_id: The Event's ID. + + Returns: + An Event database model if successful, otherwise returns None. + """ + event = _get_by_id(session, event_id, False) + if isinstance(event, EventOrm): + return event + return None + + +def get_by_id(session: Session, event_id: int) -> Optional[Event]: + """Returns an Event by an ID. + + Args: + session: The database connection. + event_id: The Event's ID. + + Returns: + An Event if successful, otherwise returns None. + """ + return _get_by_id(session, event_id, True) + + +def get_all(session: Session, skip: int = 0, limit: int = 100) -> List[Event]: + """Returns all Events. + + TODO: should the return value be limited to public events? + + Args: + session: The database connection. + skip: The starting index. + Defaults to 0. + limit: The amount of returned items. + Defaults to 100. + + Returns: + A list of Events. + """ + events = crud.get_all_database_models(session, EventOrm, skip, limit) + return parse_obj_as(List[Event], events) + + +def update_event(session: Session, event: EventAll) -> bool: + """Updates an Event's data in a complete update. + + Args: + session: The database connection. + event: The Event to update. + + Returns: + True if successful, otherwise returns False. + """ + return crud.update_database_by_schema_model( + session, + event.id, + event, + EventOrm, + ) + + +def get_members(session: Session, event: Event) -> List[User]: + """Returns a list of the events User members. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + A list of Users. + """ + event_orm = _get_by_id(session, event.id, False) + if isinstance(event_orm, EventOrm): + return parse_obj_as(List[User], event_orm.members) + return [] + + +def add_member(session: Session, event: Event, user: User) -> bool: + """Adds a User to an event. + + Args: + session: The database connection. + event: The Event to update. + user: The User being added to the Event. + + Returns: + True if successful, otherwise returns False. + """ + user_orm = user_crud.get_database_user_by_id(session, user.id) + event_orm = _get_by_id(session, event.id, False) + if isinstance(event_orm, EventOrm): + event_orm.members.append(user_orm) + session.commit() + return True + return False + + +def remove_member(session: Session, event: Event, user: User) -> bool: + """Removes a User from an event. + + Args: + session: The database connection. + event: The Event to update. + user: The User being added to the Event. + + Returns: + True if successful, otherwise returns False. + """ + user_orm = user_crud.get_database_user_by_id(session, user.id) + event_orm = _get_by_id(session, event.id, False) + if isinstance(event_orm, EventOrm): + event_orm.members.remove(user_orm) + session.commit() + return True + return False + + +# TODO: function +# def get_category(session: Session, event: Event) -> Optional[Category]: +# """Returns the Category the Event belongs to. +# +# Args: +# session: The database connection. +# event: The Event whose data is retrieved. +# +# Returns: +# The Category the Event belongs to if found. +# """ +# event_orm = _get_by_id(session, event.id, False) +# if event_orm: +# return Category.from_orm(event_orm.category) +# return None + + +def get_color(session: Session, event: Event) -> Optional[str]: + """Returns the Event's color. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + The Event's color, if found. + """ + return crud.get_property(session, event.id, EventOrm.color) + + +def get_content(session: Session, event: Event) -> Optional[str]: + """Returns the Event's content. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + The Event's content, if found. + """ + return crud.get_property(session, event.id, EventOrm.content) + + +def get_emotion(session: Session, event: Event) -> Optional[str]: + """Returns the Event's emotion. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + The Event's emotion, if found. + """ + return crud.get_property(session, event.id, EventOrm.emotion) + + +def get_image(session: Session, event: Event) -> Optional[str]: + """Returns the Event's image. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + The Event's image, if found. + """ + return crud.get_property(session, event.id, EventOrm.image) + + +def get_invited_emails(session: Session, event: Event) -> Optional[str]: + """Returns the a list of emails that have been sent invites to the event. + + TODO: current format returns a single list, should probably be refactored. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + A list of emails in a single string, if found. + """ + return crud.get_property(session, event.id, EventOrm.invited_emails) + + +def is_all_day(session: Session, event: Event) -> Optional[bool]: + """Returns whether the Event is a full day event or not. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + The Event's all day status, if found. + """ + return crud.get_property(session, event.id, EventOrm.is_all_day) + + +def is_available(session: Session, event: Event) -> Optional[bool]: + """TODO: This is not an event data but a user data that shows on the event. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + The Event's availability status, if found. + """ + return crud.get_property(session, event.id, EventOrm.is_available) + + +def is_google_event(session: Session, event: Event) -> Optional[bool]: + """Returns whether the Event was imported from Google Calendar or not. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + The Event's calendar status, if found. + """ + return crud.get_property(session, event.id, EventOrm.is_google_event) + + +def get_latitude(session: Session, event: Event) -> Optional[float]: + """Returns the Event's location's latitude. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + The Event's location's latitude, if found. + """ + return crud.get_property(session, event.id, EventOrm.latitude) + + +def get_location(session: Session, event: Event) -> Optional[str]: + """Returns the Event's location's address. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + The Event's location's address, if found. + """ + return crud.get_property(session, event.id, EventOrm.location) + + +def get_longitude(session: Session, event: Event) -> Optional[float]: + """Returns the Event's location's longitude. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + The Event's location's longitude, if found. + """ + return crud.get_property(session, event.id, EventOrm.longitude) + + +def get_owner(session: Session, event: Event) -> Optional[User]: + """Returns the Event's owner. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + The User who created the event, if found. + """ + event_orm = _get_by_id(session, event.id, False) + if isinstance(event_orm, EventOrm): + return User.from_orm(event_orm.owner) + return None + + +def change_owner(session: Session, event: Event, user_id: int) -> bool: + """Sets a new owner for an Event. + + Args: + session: The database connection. + event: The Event to update. + user_id: The new owner's user ID. + + Returns: + True if successful, otherwise returns False. + """ + event_orm = _get_by_id(session, event.id, False) + new_owner = user_crud.get_database_user_by_id(session, user_id) + if not isinstance(event_orm, EventOrm) or not isinstance( + new_owner, + UserOrm, + ): + return False + + event_orm.owner = new_owner + + # TODO: The flow of changing owners isn't clear. + # If the owner must be a member of a group than this might not be needed. + event_orm.members.append(new_owner) + session.commit() + return True + + +def get_privacy(session: Session, event: Event) -> Optional[PrivacyKinds]: + """Returns the Event's privacy setting. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + The Event's privacy setting, if found. + """ + return crud.get_property(session, event.id, EventOrm.privacy) + + +def get_video_chat_link(session: Session, event: Event) -> Optional[str]: + """Returns the Event's video chat link. + + Args: + session: The database connection. + event: The Event whose data is retrieved. + + Returns: + The Event's video chat link, if found. + """ + return crud.get_property(session, event.id, EventOrm.video_chat_link) + + +def _get_by_id( + session: Session, + event_id: int, + to_schema: bool, +) -> Optional[Union[Event, EventOrm]]: + """Returns an Event schema or database model by an ID. + + Args: + session: The database connection. + event_id: The event's ID. + to_schema: Whether to convert to schema. + Defaults to True. + + Returns: + An Event schema or database model, as requested, if successful, + otherwise returns None. + """ + keywords = {EventOrm.id.key: event_id} + return _get_by_parameter(session, to_schema, **keywords) + + +def _get_by_parameter( + session: Session, to_schema: bool = True, **kwargs +) -> Optional[Union[Event, EventOrm]]: + """Returns an Event by a parameter. + + Args: + session: The database connection. + to_schema: Whether to convert to schema. + Defaults to True. + **kwargs: The parameter to filter by. + Must be in the format of: key=value. + + Returns: + An Event schema or database model, as requested, if successful, + otherwise returns None. + """ + event = crud.get_database_model_by_parameter(session, EventOrm, **kwargs) + if isinstance(event, EventOrm): + if to_schema: + return Event.from_orm(event) + else: + return event + else: + return None diff --git a/app/database/crud/language.py b/app/database/crud/language.py new file mode 100644 index 00000000..702ce408 --- /dev/null +++ b/app/database/crud/language.py @@ -0,0 +1,49 @@ +"""CRUD functions for the Language model.""" +from typing import List, Optional + +from pydantic import parse_obj_as +from sqlalchemy.orm import Session + +from app.database.crud import crud +from app.database.models_v2 import Language as LanguageOrm +from app.database.schemas_v2 import Language + + +def get_by_id(session: Session, language_id: int) -> Optional[Language]: + """Returns a Language by an ID. + + Args: + session: The database connection. + language_id: The language's ID. + + Returns: + A Language if successful, otherwise returns None. + """ + keywords = {LanguageOrm.id.key: language_id} + language = crud.get_database_model_by_parameter( + session, LanguageOrm, **keywords + ) + if isinstance(language, LanguageOrm): + return Language.from_orm(language) + return None + + +def get_all( + session: Session, + skip: int = 0, + limit: int = 100, +) -> List[Language]: + """Returns all Languages. + + Args: + session: The database connection. + skip: The starting index. + Defaults to 0. + limit: The amount of returned items. + Defaults to 100. + + Returns: + A list of Languages. + """ + languages = crud.get_all_database_models(session, LanguageOrm, skip, limit) + return parse_obj_as(List[Language], languages) diff --git a/app/database/crud/temp_utils.py b/app/database/crud/temp_utils.py new file mode 100644 index 00000000..7e4022f7 --- /dev/null +++ b/app/database/crud/temp_utils.py @@ -0,0 +1,17 @@ +"""Temp file to prevent file conflicts while WiP.""" +from pydantic import SecretStr + +from app.internal.security.ouath2 import pwd_context + + +def get_hashed_password_v2(password: SecretStr) -> str: + """Hashes the user's password. + + Args: + password: An unhashed password. + + Returns: + A hashed password. + """ + unhashed_password = password.get_secret_value().encode("utf-8") + return pwd_context.hash(unhashed_password) diff --git a/app/database/crud/user.py b/app/database/crud/user.py new file mode 100644 index 00000000..35be7a38 --- /dev/null +++ b/app/database/crud/user.py @@ -0,0 +1,583 @@ +"""CRUD functions for the User model.""" +from typing import Any, List, Optional, Union + +from pydantic import ( + EmailStr, + SecretStr, + ValidationError, + parse_obj_as, + validate_email, +) +from pydantic.errors import EmailError +from sqlalchemy.orm import Session + +from app.database.crud import crud +from app.database.crud.temp_utils import get_hashed_password_v2 +from app.database.models_v2 import Event as EventOrm +from app.database.models_v2 import User as UserOrm +from app.database.schemas_v2 import Event, Language, User, UserCreate +from app.dependencies import logger +from app.internal.privacy import PrivacyKinds + + +def create(session: Session, user: UserCreate) -> Optional[User]: + """Returns a User after creating and saving a model in the database. + + Args: + session: The database connection. + user: The created user's data. + + Returns: + The created User if successful, otherwise returns None. + """ + hashed_password = get_hashed_password_v2(user.password) + user_orm = UserOrm( + **user.dict(exclude={"confirm_password", "password"}), + password=hashed_password, + ) + if crud.insert(session, user_orm): + return User.from_orm(user_orm) + return None + + +def delete(session: Session, user: User) -> bool: + """Deletes a User from the database. + + Args: + session: The database connection. + user: The User to delete. + + Returns: + True if successful, otherwise returns False. + """ + user_orm = get_database_user_by_id(session, user.id) + if user_orm: + return crud.delete(session, user_orm) + return False + + +def get_database_user_by_id( + session: Session, + user_id: int, +) -> Optional[UserOrm]: + """Returns a User database model by an ID. + + Args: + session: The database connection. + user_id: The User's ID. + + Returns: + A User database model if successful, otherwise returns None. + """ + user = _get_by_id(session, user_id, False) + if isinstance(user, UserOrm): + return user + return None + + +def get_by_id(session: Session, user_id: int) -> Optional[User]: + """Returns a User by an ID. + + Args: + session: The database connection. + user_id: The User's ID. + + Returns: + A User if successful, otherwise returns None. + """ + return _get_by_id(session, user_id, True) + + +def get_by_username(session: Session, username: str) -> Optional[User]: + """Returns a User by a username. + + Args: + session: The database connection. + username: The User's username. + + Returns: + A User if successful, otherwise returns None. + """ + keywords = {UserOrm.username.key: username} + return _get_by_parameter(session, True, **keywords) + + +def get_by_email(session: Session, email: str) -> Optional[User]: + """Returns a User by an email address. + + Args: + session: The database connection. + email: The User's email address. + + Returns: + A User if successful, otherwise returns None. + """ + keywords = {UserOrm.email.key: email} + return _get_by_parameter(session, True, **keywords) + + +def get_all(session: Session, skip: int = 0, limit: int = 100) -> List[User]: + """Returns all registered Users. + + Args: + session: The database connection. + skip: The starting index. + Defaults to 0. + limit: The amount of returned items. + Defaults to 100. + + Returns: + A list of registered Users. + """ + users = crud.get_all_database_models(session, UserOrm, skip, limit) + return parse_obj_as(List[User], users) + + +def get_avatar(session: Session, user: User) -> Optional[str]: + """Returns the User's avatar image. + + Args: + session: The database connection. + user: The User whose data is retrieved. + + Returns: + The User's avatar, if found. + """ + return crud.get_property(session, user.id, UserOrm.avatar) + + +def set_avatar(session: Session, user: User, file_name: str) -> bool: + """Sets a new avatar image for the User. + + Args: + session: The database connection. + user: The User to update. + file_name: The new avatar image file. + + Returns: + True if successful, otherwise returns False. + """ + return crud.set_property(session, user.id, UserOrm.avatar, file_name) + + +def get_description(session: Session, user: User) -> Optional[str]: + """Returns the User's description. + + Args: + session: The database connection. + user: The User whose data is retrieved. + + Returns: + The User's description, if found. + """ + return crud.get_property(session, user.id, UserOrm.description) + + +def set_description(session: Session, user: User, description: str) -> bool: + """Sets a new description for the User. + + Args: + session: The database connection. + user: The User to update. + description: The new description. + + Returns: + True if successful, otherwise returns False. + """ + return crud.set_property( + session, + user.id, + UserOrm.description, + description, + ) + + +def get_email(session: Session, user: User) -> Optional[str]: + """Returns the User's email address. + + Args: + session: The database connection. + user: The User whose data is retrieved. + + Returns: + The User's email address, if found. + """ + return crud.get_property(session, user.id, UserOrm.email) + + +def set_email(session: Session, user: User, email: str) -> bool: + """Sets a new email address for the User. + + Args: + session: The database connection. + user: The User to update. + email: The new email address. + + Returns: + True if successful, otherwise returns False. + """ + try: + validate_email(email) + except (TypeError, EmailError): + return False + return crud.set_property(session, user.id, UserOrm.email, email) + + +def set_full_name(session: Session, user: User, full_name: str) -> bool: + """Sets a new full name for the User. + + Args: + session: The database connection. + user: The User to update. + full_name: The new full name. + + Returns: + True if successful, otherwise returns False. + """ + try: + user.full_name = full_name + except ValidationError: + return False + return crud.set_property(session, user.id, UserOrm.full_name, full_name) + + +def is_active(session: Session, user: User) -> Optional[bool]: + """Returns whether the User's account is active or not. + + Args: + session: The database connection. + user: The User whose data is retrieved. + + Returns: + The User's account status, if found. + """ + return crud.get_property(session, user.id, UserOrm.is_active) + + +def set_active(session: Session, user: User, active: bool) -> bool: + """Sets a new account status for the User. + + Args: + session: The database connection. + user: The User to update. + active: The new status. + + Returns: + True if successful, otherwise returns False. + """ + return crud.set_property(session, user.id, UserOrm.is_active, active) + + +def is_admin(session: Session, user: User) -> Optional[bool]: + """Returns whether the User is an admin or not. + + Args: + session: The database connection. + user: The User whose data is retrieved. + + Returns: + The User's admin status, if found. + """ + return crud.get_property(session, user.id, UserOrm.is_admin) + + +def set_admin(session: Session, user: User, is_user_admin: bool) -> bool: + """Sets a new admin status for the User. + + Args: + session: The database connection. + user: The User to update. + is_user_admin: The new admin status. + + Returns: + True if successful, otherwise returns False. + """ + return crud.set_property(session, user.id, UserOrm.is_admin, is_user_admin) + + +def get_language(session: Session, user: User) -> Optional[Language]: + """Returns the User's Language. + + Args: + session: The database connection. + user: The User whose data is retrieved. + + Returns: + The User's Language, if found. + """ + user_orm = get_database_user_by_id(session, user.id) + if user_orm: + return Language.from_orm(user_orm.language) + return None + + +def set_language(session: Session, user: User, language_id: int) -> bool: + """Sets a new language ID for the User. + + Args: + session: The database connection. + user: The User to update. + language_id: The new language ID. + + Returns: + True if successful, otherwise returns False. + """ + return crud.set_property( + session, + user.id, + UserOrm.language_id, + language_id, + ) + + +def set_password(session: Session, user: User, password: SecretStr) -> bool: + """Sets a new password for the User. + + Args: + session: The database connection. + user: The User to update. + password: An unhashed password. + + Returns: + True if successful, otherwise returns False. + """ + try: + UserCreate( + **user.dict(), + # This is needed as SecretStr fails here. + password=password.get_secret_value(), + confirm_password=password, + email=EmailStr("email@gmail.com"), + ) + except (AttributeError, ValidationError): + return False + + try: + hashed_password = get_hashed_password_v2(password) + except TypeError as e: + logger.exception(e) + return False + return crud.set_property( + session, + user.id, + UserOrm.password, + hashed_password, + ) + + +def get_privacy(session: Session, user: User) -> Optional[PrivacyKinds]: + """Returns the User's privacy setting. + + Args: + session: The database connection. + user: The User whose data is retrieved. + + Returns: + The User's privacy setting, if found. + """ + return crud.get_property(session, user.id, UserOrm.privacy) + + +def set_privacy(session: Session, user: User, privacy: PrivacyKinds) -> bool: + """Sets a new privacy setting for the User. + + Args: + session: The database connection. + user: The User to update. + privacy: The new privacy setting. + + Returns: + True if successful, otherwise returns False. + """ + return crud.set_property(session, user.id, UserOrm.privacy, privacy) + + +def get_target_weight(session: Session, user: User) -> Optional[float]: + """Returns the User's target weight. + + Args: + session: The database connection. + user: The User whose data is retrieved. + + Returns: + The User's target weight, if found. + """ + return crud.get_property(session, user.id, UserOrm.target_weight) + + +def set_target_weight(session: Session, user: User, weight: float) -> bool: + """Sets a new target weight for the User. + + Args: + session: The database connection. + user: The User to update. + weight: The target weight. + + Returns: + True if successful, otherwise returns False. + """ + return crud.set_property(session, user.id, UserOrm.target_weight, weight) + + +def get_telegram_id(session: Session, user: User) -> Optional[str]: + """Returns the User's Telegram ID. + + Args: + session: The database connection. + user: The User whose data is retrieved. + + Returns: + The User's Telegram ID, if found. + """ + return crud.get_property(session, user.id, UserOrm.telegram_id) + + +def set_telegram_id(session: Session, user: User, telegram_id: str) -> bool: + """Sets a new Telegram ID for the User. + + Args: + session: The database connection. + user: The User to update. + telegram_id: The User's Telegram ID. + + Returns: + True if successful, otherwise returns False. + """ + return crud.set_property( + session, + user.id, + UserOrm.telegram_id, + telegram_id, + ) + + +def set_username(session: Session, user: User, username: str) -> bool: + """Sets a new username for the User. + + Args: + session: The database connection. + user: The User to update. + username: The new username. + + Returns: + True if successful, otherwise returns False. + """ + original_username = user.username + try: + user.username = username + except ValidationError: + return False + result = crud.set_property(session, user.id, UserOrm.username, username) + if result is False: + user.username = original_username + return False + return True + + +def get_events(session: Session, user: User) -> List[Event]: + """Returns a list of Events the User belongs to. + + Args: + session: The database connection. + user: The User whose data is retrieved. + + Returns: + A list of Events. + """ + user_orm = get_database_user_by_id(session, user.id) + if user_orm: + return parse_obj_as(List[Event], user_orm.events) + return [] + + +def get_owned_events(session: Session, user: User) -> List[Event]: + """Returns a list of Events the User has created. + + Args: + session: The database connection. + user: The User whose data is retrieved. + + Returns: + A list of Events. + """ + user_orm = get_database_user_by_id(session, user.id) + if user_orm: + return parse_obj_as(List[Event], user_orm.owned_events) + return [] + + +def get_google_calender_events(session: Session, user: User) -> List[EventOrm]: + """Returns a list of Events imported from Google Calendar. + + Args: + session: The database connection. + user: The User whose data is retrieved. + + Returns: + A list of Events. + """ + user_orm = get_database_user_by_id(session, user.id) + if not user_orm: + return [] + return [event for event in user_orm.events if event.is_google_event] + + +def delete_all_google_calendar_events(session: Session, user: User) -> bool: + """Deletes all of a User's imported Google Calendar Events. + + Args: + session: The database connection. + user: The User to delete data from. + + Returns: + True if successful, otherwise returns False. + """ + events = get_google_calender_events(session, user) + return crud.delete_multiple(session, events) + + +def _get_by_id( + session: Session, + user_id: int, + to_schema: bool, +) -> Optional[Union[User, UserOrm]]: + """Returns a User schema or database model by an ID. + + Args: + session: The database connection. + user_id: The user's ID. + to_schema: Whether to convert to schema. + Defaults to True. + + Returns: + A User schema or database model, as requested, if successful, + otherwise returns None. + """ + keywords = {UserOrm.id.key: user_id} + return _get_by_parameter(session, to_schema, **keywords) + + +def _get_by_parameter( + session: Session, to_schema: bool = True, **kwargs: Any +) -> Optional[Union[User, UserOrm]]: + """Returns a User by a parameter. + + Args: + session: The database connection. + to_schema: Whether to convert to schema. + Defaults to True. + **kwargs: The parameter to filter by. + Must be in the format of: key=value. + + Returns: + A User schema or database model, as requested, if successful, + otherwise returns None. + """ + user = crud.get_database_model_by_parameter(session, UserOrm, **kwargs) + if isinstance(user, UserOrm): + if to_schema: + return User.from_orm(user) + else: + return user + else: + return None diff --git a/app/database/models_v2.py b/app/database/models_v2.py new file mode 100644 index 00000000..e55fb9ac --- /dev/null +++ b/app/database/models_v2.py @@ -0,0 +1,304 @@ +"""SQLAlchemy database models.""" +from sqlalchemy import ( + Boolean, + Column, + DateTime, + Enum, + Float, + ForeignKey, + Index, + Integer, + String, + Table, + event, +) +from sqlalchemy.dialects.postgresql import TSVECTOR +from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base +from sqlalchemy.orm import Session, relationship + +from app.config import PSQL_ENVIRONMENT +from app.internal.privacy import PrivacyKinds + +Base: DeclarativeMeta = declarative_base() + +PRIMARY_KEY_DOC = "An auto increment unique primary key." + + +class Event(Base): + """A database model of an Event entity.""" + + __tablename__ = "event" + + id = Column( + Integer, + primary_key=True, + index=True, + nullable=False, + doc=PRIMARY_KEY_DOC, + ) + color = Column( + String, + nullable=True, + doc="", # TODO: doc + ) + content = Column( + String, + doc="", # TODO: doc + ) + datetime_end = Column( + DateTime, + nullable=False, + doc="The event's end datetime.", + ) + datetime_start = Column( + DateTime, + nullable=False, + doc="The event's start datetime.", + ) + emotion = Column( + String, + nullable=True, + doc="", # TODO: doc + ) + image = Column( + String, + nullable=True, + doc="The event's image file name.", + ) + invited_emails = Column( + String, + doc="A list of emails separated by a comma.", # TODO: Refactor? + ) + is_all_day = Column( + Boolean, + default=False, + doc="Whether the event spans the whole day.", + ) + # TODO: This is not an event data but a user data that shows on the event. + is_available = Column( + Boolean, + default=True, + nullable=False, + doc="", # TODO: doc + ) + is_google_event = Column( + Boolean, + default=False, + doc="Whether the event was imported from a Google Calendar", + ) + latitude = Column( + Float, + nullable=True, + doc="The latitude of the event's location.", + ) + location = Column( + String, + nullable=True, + doc="The location of the event.", + ) + longitude = Column( + Float, + nullable=True, + doc="The longitude of the event's location.", + ) + owner_id = Column( + Integer, + ForeignKey("user.id"), + nullable=False, + doc="The ID of the user who created the event.", + ) + privacy = Column( + Enum(PrivacyKinds), + default=PrivacyKinds.Public, + nullable=False, + doc="The event's privacy setting. Defaults to PrivacyKinds.Public.", + ) + title = Column( + String, + nullable=False, + doc="The title of the event.", + ) + video_chat_link = Column( + String, + nullable=True, + doc="The event's video chat link.", + ) + members = relationship( + "User", + secondary="user_event", + back_populates="events", + doc="A list of users who are attending the event.", + ) + owner = relationship( + "User", + back_populates="owned_events", + doc="The user who created the event.", + ) + + # PostgreSQL + if PSQL_ENVIRONMENT: + events_tsv = Column(TSVECTOR) + __table_args__ = ( + Index("events_tsv_idx", "events_tsv", postgresql_using="gin"), + ) + + def __repr__(self): + return f"" + + +class Language(Base): + """A database model of a Language entity. + + Languages in the database are the ones which are supported by the website. + """ + + __tablename__ = "language" + + id = Column( + Integer, + primary_key=True, + index=True, + nullable=False, + doc=PRIMARY_KEY_DOC, + ) + name = Column( + String, + unique=True, + nullable=False, + doc="The name of the language, in the language's script.", + ) + code = Column( + String, + unique=True, + nullable=False, + doc="The ISO code of the language.", + ) + + +class User(Base): + """A database model of a User entity.""" + + __tablename__ = "user" + + id = Column( # TODO: Should id be changed to a UUID? + Integer, + primary_key=True, + index=True, + nullable=False, + doc=PRIMARY_KEY_DOC, + ) + avatar = Column( + String, + default="profile.png", + nullable=False, + doc="The user's avatar image file name. Defaults to 'profile.png'.", + ) + description = Column( + String, + nullable=True, + doc="A freeform description field.", + ) + email = Column( + String, + unique=True, + nullable=False, + doc="A unique email. Email must be valid.", + ) + full_name = Column( + String, + nullable=False, + doc="The user's full name. A full name must be at least 2 characters" + " long.", + ) + is_active = Column( + Boolean, + default=True, + nullable=False, + doc="Whether the user account is active or temporarily disabled." + " Defaults to True.", + ) + is_admin = Column( + Boolean, + default=False, + nullable=False, + doc="Whether the user is a site admin. Defaults to False.", + ) + language_id = Column( + Integer, + ForeignKey("language.id"), + default=1, + nullable=False, + doc="The user's UI language's ID. Defaults to 1.", + ) + password = Column( + String, + nullable=False, + doc="The user's password. A password must be a minimum length of 3" + " characters and a maximum length of 20 characters.", + ) + privacy = Column( + Enum(PrivacyKinds), + default=PrivacyKinds.Private, + nullable=False, + doc="The user's privacy setting. Defaults to PrivacyKinds.Private.", + ) + target_weight = Column( + Float, + nullable=True, + doc="The user's target weight goal.", + ) + telegram_id = Column( + String, + unique=True, + nullable=True, + doc="A unique Telegram ID.", + ) + username = Column( + String, + unique=True, + nullable=False, + doc="A unique username. A valid username must be a minimum length of 3" + " characters, a maximum length of 20 characters, and use only the" + " following characters: a-zA-Z0-9_.", + ) + events = relationship( + "Event", + secondary="user_event", + cascade="all, delete", + back_populates="members", + doc="A list of events the user is attending.", + ) + language = relationship( + "Language", + doc="The user's UI language.", + ) + owned_events = relationship( + "Event", + cascade="all, delete", + back_populates="owner", + doc="A list of events the user created.", + ) + + def __repr__(self): + return f"" + + +user_event = Table( + "user_event", + Base.metadata, + Column("user.id", Integer, ForeignKey("user.id"), primary_key=True), + Column("event.id", Integer, ForeignKey("event.id"), primary_key=True), +) + + +# TODO: move this into a json file and load it from the json loader. +def insert_data(target, session: Session, **kw): + """Inserts the supported languages into the Language table.""" + session.execute( + target.insert(), + {"id": 1, "name": "English", "code": "en"}, + {"id": 2, "name": "עברית", "code": "he"}, + ) + + +event.listen(Language.__table__, "after_create", insert_data) diff --git a/app/database/schemas_v2.py b/app/database/schemas_v2.py new file mode 100644 index 00000000..f44b611a --- /dev/null +++ b/app/database/schemas_v2.py @@ -0,0 +1,237 @@ +"""Pydantic schema models.""" +from datetime import datetime +from typing import Any, Dict + +from pydantic import BaseModel, EmailStr, Field, HttpUrl, SecretStr, validator + +from app.database import models_v2 +from app.internal.privacy import PrivacyKinds + + +class EventBase(BaseModel): + """Base Event schema model. Should not be used directly.""" + + title: str = Field( + description=models_v2.Event.title.__doc__, + example="A title of an event", + ) + + owner_id: int = Field( + description=models_v2.Event.owner_id.__doc__, + example=1, + ) + + datetime_start: datetime = Field( + description=models_v2.Event.datetime_start.__doc__, + example="2021-03-18 13:00:00", + ) + + datetime_end: datetime = Field( + description=models_v2.Event.datetime_end.__doc__, + example="2021-03-18 14:00:00", + ) + + is_all_day: bool = Field( + description=models_v2.Event.is_all_day.__doc__, + example=False, + ) + + class Config: # noqa + orm_mode = True + validate_assignment = True + + +class EventCreate(EventBase): + """Event schema used for entity creation. + + Extends :class:`EventBase` with event creation information. + """ + + color: str = Field( + description=models_v2.Event.color.__doc__, + example="Red", + ) + + content: str = Field( + description=models_v2.Event.content.__doc__, + example="The event's content", + ) + + emotion: str = Field( + description=models_v2.Event.emotion.__doc__, + example="", + ) + + image: str = Field( + description=models_v2.Event.image.__doc__, + example="event_image.png", + ) + + invited_emails: str = Field( + description=models_v2.Event.invited_emails.__doc__, + example="invite1@gmail.com, invite2@gmail.com", + ) + + is_available: bool = Field( + description=models_v2.Event.is_available.__doc__, + example=False, + ) + + is_google_event: bool = Field( + description=models_v2.Event.is_google_event.__doc__, + example=False, + ) + + latitude: float = Field( + description=models_v2.Event.latitude.__doc__, + example="32.0853", + ) + + location: str = Field( + description=models_v2.Event.location.__doc__, + example="Tel Aviv", + ) + + longitude: float = Field( + description=models_v2.Event.longitude.__doc__, + example="34.7818", + ) + + privacy: PrivacyKinds = Field( + description=models_v2.Event.privacy.__doc__, + example="PrivacyKinds.Public", + ) + + video_chat_link: HttpUrl = Field( + description=models_v2.Event.video_chat_link.__doc__, + example="https://www.link.com", + ) + + +class Event(EventBase): + """Event schema used for general use. + + Extends :class:`EventBase` with standard-use event information. + """ + + id: int = Field( + allow_mutation=False, + description=models_v2.Event.id.__doc__, + example=1, + ) + + +class EventAll(EventCreate, Event): + """Event schema used for full updates. + + Extends :class:`Event` and :class:`EventCreate. + """ + + pass + + +class Language(BaseModel): + """Language schema model.""" + + id: int = Field( + description=models_v2.Language.id.__doc__, + example=1, + ) + + name: str = Field( + description=models_v2.Language.name.__doc__, + example="English", + ) + + code: str = Field( + description=models_v2.Language.code.__doc__, + example="en", + ) + + class Config: # noqa + orm_mode = True + + +class UserBase(BaseModel): + """Base User schema model. Should not be used directly.""" + + username: str = Field( + min_length=3, + max_length=20, + regex="^[a-zA-Z0-9_]+$", + description=models_v2.User.username.__doc__, + example="user4", + ) + + full_name: str = Field( + min_length=2, + description=models_v2.User.full_name.__doc__, + example="John Locke", + ) + + class Config: # noqa + validate_assignment = True + + +class UserCreate(UserBase): + """User schema used for entity creation. + + Extends :class:`UserBase` with user registration information.""" + + password: SecretStr = Field( + min_length=3, + max_length=20, + description="The user's password.", + example="L108JL!", + ) + + confirm_password: SecretStr = Field( + description="Re-entry of the user's password for validation.", + example="L108JL!", + ) + + email: EmailStr = Field( + description=models_v2.User.email.__doc__, + example="user@gmail.com", + ) + + @validator("confirm_password") + def passwords_match( + cls, + confirm_password: SecretStr, + values: Dict[str, Any], + ) -> SecretStr: + """Validates the user correctly re-entered their password. + + Args: + confirm_password: The re-entered password. + values: All class field values. + + Returns: + The re-entered password. + + Raises: + ValueError if the passwords do not match. + """ + if "password" in values and confirm_password != values["password"]: + raise ValueError("Passwords do not match") + return confirm_password + + class Config: # noqa + orm_mode = True + + +class User(UserBase): + """User schema used for general use. + + Extends :class:`UserBase` with standard-use user information. + """ + + id: int = Field( + allow_mutation=False, + description=models_v2.User.id.__doc__, + example=1, + ) + + class Config: # noqa + orm_mode = True diff --git a/app/dependencies.py b/app/dependencies.py index 01cdcf56..0d4cd61f 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -14,6 +14,7 @@ STATIC_PATH = os.path.join(APP_PATH, "static") TEMPLATES_PATH = os.path.join(APP_PATH, "templates") SOUNDS_PATH = os.path.join(STATIC_PATH, "tracks") +LOCALES_PATH = os.path.join(APP_PATH, "locales") templates = Jinja2Templates(directory=TEMPLATES_PATH) templates.env.add_extension("jinja2.ext.i18n") diff --git a/app/internal/languages.py b/app/internal/languages.py index 420c2dab..e1abdf62 100644 --- a/app/internal/languages.py +++ b/app/internal/languages.py @@ -4,10 +4,8 @@ from typing import Iterator from app import config -from app.dependencies import templates +from app.dependencies import LOCALES_PATH, templates -LANGUAGE_DIR = "app/locales" -LANGUAGE_DIR_TEST = "../app/locales" TRANSLATION_FILE = "base" @@ -20,8 +18,8 @@ def set_ui_language(language: str = None) -> None: Args: language: Optional; A valid language code that follows RFC 1766. Defaults to None. - See also the Language Code Identifier (LCID) Reference for a list of - valid language codes. + See also the Language Code Identifier (LCID) Reference for a list + of valid language codes. .. _RFC 1766: https://tools.ietf.org/html/rfc1766.html @@ -34,14 +32,15 @@ def set_ui_language(language: str = None) -> None: # if not language: # language = _get_display_language(user_id: int) - language_dir = _get_language_directory() - - if language not in list(_get_supported_languages(language_dir)): + try: + if language not in set(_get_supported_languages()): + language = config.WEBSITE_LANGUAGE + except TypeError: language = config.WEBSITE_LANGUAGE translations = gettext.translation( TRANSLATION_FILE, - localedir=language_dir, + localedir=LOCALES_PATH, languages=[language], ) translations.install() @@ -59,27 +58,7 @@ def set_ui_language(language: str = None) -> None: # return config.WEBSITE_LANGUAGE -def _get_language_directory() -> str: - """Returns the language directory relative path.""" - language_dir = LANGUAGE_DIR - if Path(LANGUAGE_DIR_TEST).is_dir(): - # If running from test, change dir path. - language_dir = LANGUAGE_DIR_TEST - return language_dir - - -def _get_supported_languages( - language_dir: str = _get_language_directory() -) -> Iterator[str]: - """Returns a generator of supported translation languages codes. - - Args: - language_dir: Optional; The path of the language directory. - Defaults to the return value of _get_language_directory(). - - Returns: - A generator expression of the supported translation languages codes. - """ - - paths = [Path(f.path) for f in os.scandir(language_dir) if f.is_dir()] +def _get_supported_languages() -> Iterator[str]: + """Returns a generator of supported translation languages codes.""" + paths = (Path(f.path) for f in os.scandir(LOCALES_PATH) if f.is_dir()) return (language.name for language in paths) diff --git a/requirements.txt b/requirements.txt index 9612912e..836c2072 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,7 +44,7 @@ fastapi==0.63.0 fastapi-login==1.5.2 fastapi-mail==0.3.3.1 filelock==3.0.12 -flake8==3.8.4 +flake8==3.9.0 frozendict==1.2 geographiclib==1.50 geopy==2.1.0 @@ -94,7 +94,7 @@ passlib==1.7.4 pathspec==0.8.1 Pillow==8.1.0 pluggy==0.13.1 -pre-commit==2.10.0 +pre-commit==2.11.1 priority==1.3.0 protobuf==3.14.0 psycopg2==2.8.6 @@ -102,10 +102,10 @@ psycopg2-binary==2.8.6 py==1.10.0 pyasn1==0.4.8 pyasn1-modules==0.2.8 -pycodestyle==2.6.0 +pycodestyle==2.7.0 pycparser==2.20 -pydantic==1.7.3 -pyflakes==2.2.0 +pydantic==1.8.1 +pyflakes==2.3.0 PyJWT==2.0.0 pyparsing==2.4.7 pytest==6.2.1 diff --git a/tests/crud/__init__.py b/tests/crud/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/crud/conftest.py b/tests/crud/conftest.py new file mode 100644 index 00000000..ac755858 --- /dev/null +++ b/tests/crud/conftest.py @@ -0,0 +1,226 @@ +"""Fixtures for schemas and models v.2""" +from datetime import datetime +from typing import Callable, Generator, Optional + +import pytest +from pydantic import EmailStr, SecretStr +from sqlalchemy.orm import Session + +from app.database.crud import event as event_crud +from app.database.crud import user as user_crud +from app.database.models_v2 import Base +from app.database.schemas_v2 import ( + Event, + EventCreate, + Language, + User, + UserCreate, +) +from app.internal.privacy import PrivacyKinds +from tests.conftest import get_test_db, test_engine + + +@pytest.fixture +def event_create(user_v2: User) -> EventCreate: + """Returns an EventCreate entity fixture. + + Args: + user_v2: A User entity. + + Returns: + An EventCreate object. + """ + return EventCreate( + datetime_start=datetime.today(), + datetime_end=datetime.today(), + color="red", + content="content", + emotion="emotion", + image="event_image.png", + invited_emails="invite1@gmail.com, invite2@gmail.com", + is_available=False, + is_all_day=True, + is_google_event=False, + latitude=32.0853, + location="Tel Aviv", + longitude=34.7818, + owner_id=user_v2.id, + privacy=PrivacyKinds.Public, + title="title", + video_chat_link="https://www.link.com", + ) + + +@pytest.fixture +def event_v2( + event_create: EventCreate, + create_events: Callable[ + [int, datetime, datetime], + Optional[Event], + ], +) -> Generator[Optional[Event], None, None]: + """Yields a single Event entity fixture. + + This is a convenience wrapper when there is only one Event entity needed + for tests. + + Args: + event_create: The EventCreate fixture. + create_events: A generator which creates database Event entities. + + Yields: + A User entity. + """ + event = create_events( + event_create.owner_id, + event_create.datetime_start, + event_create.datetime_end, + ) + yield event + + +@pytest.fixture +def create_events( + session_v2: Session, + event_create: EventCreate, +) -> Generator[ + Callable[[int, datetime, datetime], Optional[Event]], + None, + None, +]: + """Yields a function fixture which creates Event entities. + + Args: + session_v2: The database connection fixture. + event_create: The EventCreate fixture. + + Returns: + An Event creation function fixture. + """ + created_events = [] + + def _create_events( + owner_id: int, + start_datetime: datetime, + end_datetime: datetime, + ) -> Optional[Event]: + """Returns an Event which was saved in the database. + + Args: + owner_id: The ID of the User who created the Event. + start_datetime: The start date of the Event. + end_datetime: The end date of the Event. + + Returns: + An Event entity. + """ + event_create.owner_id = owner_id + event_create.datetime_start = start_datetime + event_create.datetime_end = end_datetime + + event = event_crud.create(session_v2, event_create) + created_events.append(event) + return event + + yield _create_events + + for created_event in created_events: + if created_event: + event_crud.delete(session_v2, created_event) + + +@pytest.fixture +def language(): + """Returns a Language entity fixture.""" + return Language(id=1, name="English", code="en") + + +@pytest.fixture +def session_v2() -> Generator[Session, None, None]: + """Yields a Session entity fixture to connect to the database. + + Additionally, the database tables are created before usage and destroyed + when tests are completed. + + Yields: + A Session object. + """ + Base.metadata.create_all(bind=test_engine) + session = get_test_db() + yield session + session.rollback() + session.close() + Base.metadata.drop_all(bind=test_engine) + + +@pytest.fixture +def user_create() -> UserCreate: + """Returns a UserCreate entity fixture.""" + return UserCreate( + username="username", + full_name="full_name", + password="12345", + confirm_password=SecretStr("12345"), + email=EmailStr("email@gmail.com"), + ) + + +@pytest.fixture +def user_v2( + user_create, + create_users: Callable[[str, str], Optional[User]], +) -> Generator[Optional[User], None, None]: + """Yields a single User entity fixture. + + This is a convenience wrapper when there is only one User entity needed + for tests. + + Args: + user_create: The UserCreate fixture. + create_users: A generator which creates database User entities. + + Yields: + A User entity. + """ + user = create_users(user_create.username, user_create.email) + yield user + + +@pytest.fixture +def create_users( + session_v2: Session, + user_create: UserCreate, +) -> Generator[Callable[[str, str], Optional[User]], None, None]: + """Yields a function fixture which creates User entities. + + Args: + session_v2: The database connection fixture. + user_create: The UserCreate fixture. + + Returns: + A User creation function fixture. + """ + created_users = [] + + def _create_users(username: str, email: str) -> Optional[User]: + """Returns a User which was saved in the database. + + Args: + username: The User's unique username. + email: The User's unique email address. + + Returns: + A User entity. + """ + user_create.username = username + user_create.email = EmailStr(email) + + user = user_crud.create(session_v2, user_create) + created_users.append(user) + return user + + yield _create_users + + for created_user in created_users: + if created_user: + user_crud.delete(session_v2, created_user) diff --git a/tests/crud/test_crud.py b/tests/crud/test_crud.py new file mode 100644 index 00000000..685c1d22 --- /dev/null +++ b/tests/crud/test_crud.py @@ -0,0 +1,20 @@ +"""Tests for CRUD general low-level functions.""" +import pytest +from sqlalchemy.exc import SQLAlchemyError + +from app.database.crud import crud +from app.database.models_v2 import User +from app.database.models_v2 import User as UserOrm +from tests.conftest import get_test_db + + +def test_create_fail_no_tables_error(): + user = User() + session = get_test_db() + with pytest.raises(SQLAlchemyError): + crud.insert(session, user) + + +def test_get_all_database_models_no_tables_error(): + session = get_test_db() + assert crud.get_all_database_models(session, UserOrm) == [] diff --git a/tests/crud/test_event.py b/tests/crud/test_event.py new file mode 100644 index 00000000..b6950997 --- /dev/null +++ b/tests/crud/test_event.py @@ -0,0 +1,234 @@ +"""Tests for CRUD functions of the Event model.""" +from datetime import datetime +from typing import Callable, Optional + +import pytest +from sqlalchemy.orm import Session + +from app.database.crud import event as crud +from app.database.crud import user as user_crud +from app.database.schemas_v2 import Event, EventAll, EventCreate, User +from tests.crud.test_util import ( + get_attribute_value, + get_boolean_getter_function, + get_getter_function, +) + + +def test_create_event(session_v2: Session, event_create: EventCreate): + assert crud.create(session_v2, event_create) + + +def test_delete(session_v2: Session, event_v2: Event): + # Deletion of existing event - successful. + assert crud.delete(session_v2, event_v2) + + # Deletion of non-existing event - unsuccessful. + assert crud.delete(session_v2, event_v2) is False + + +def test_get_by_id(session_v2: Session, event_v2: Event): + database_event = crud.get_by_id(session_v2, event_v2.id) + assert database_event + assert database_event.owner_id == event_v2.owner_id + assert database_event.datetime_end == event_v2.datetime_end + assert crud.get_by_id(session_v2, 2) is None + + +@pytest.mark.parametrize("number_of_events", [0, 1, 2]) +def test_get_all_events( + session_v2: Session, + create_events: Callable[[int, datetime, datetime], Optional[Event]], + number_of_events: int, +): + for i in range(number_of_events): + create_events(1, datetime.today(), datetime.today()) + assert len(crud.get_all(session_v2)) == number_of_events + + +def test_update_event( + session_v2: Session, + event_create: EventCreate, + event_v2: Event, +): + event_all = EventAll(**event_create.dict(), id=event_v2.id) + event_all.title = "updated title" + event_all.location = "updated location" + + event_from_database = crud.get_by_id(session_v2, event_v2.id) + assert event_from_database + current_title = event_from_database.title + current_location = crud.get_location(session_v2, event_v2) + assert crud.update_event(session_v2, event_all) + + event_from_database = crud.get_by_id(session_v2, event_v2.id) + assert event_from_database + updated_title = event_from_database.title + updated_location = crud.get_location(session_v2, event_v2) + + assert current_title != updated_title and updated_title == event_all.title + + assert ( + current_location != updated_location + and updated_location == event_all.location + ) + + +COLUMNS_BASE_TESTS = [ + "title", + "owner_id", + "datetime_start", + "datetime_end", + "is_all_day", +] + + +@pytest.mark.parametrize("column_name", COLUMNS_BASE_TESTS) +def test_get_value_base_fields( + session_v2: Session, + event_create: EventCreate, + event_v2: Event, + column_name: str, +): + value_from_create = get_attribute_value(event_create, column_name) + value_from_database = get_attribute_value(event_v2, column_name) + assert value_from_create == value_from_database + + +COLUMNS_STANDARD_TESTS = [ + "color", + "content", + "emotion", + "image", + "invited_emails", + "latitude", + "location", + "longitude", + "privacy", + "video_chat_link", +] + + +@pytest.mark.parametrize("column_name", COLUMNS_STANDARD_TESTS) +def test_get_value_standard_fields( + session_v2: Session, + event_create: EventCreate, + event_v2: Event, + column_name: str, +): + getter_function = get_getter_function(crud, column_name) + value = getattr(event_create, column_name) + assert getter_function + assert getter_function(session_v2, event_v2) == value + + +COLUMNS_BOOLEAN_TESTS = [ + "available", + "google_event", +] + + +@pytest.mark.parametrize("column_name", COLUMNS_BOOLEAN_TESTS) +def test_get_value_boolean_fields( + session_v2: Session, + event_create: EventCreate, + event_v2: Event, + column_name: str, +): + getter_function = get_boolean_getter_function(crud, column_name) + value = getattr(event_create, f"is_{column_name}") + assert getter_function + assert getter_function(session_v2, event_v2) == value + + +def test_get_change_owner( + session_v2: Session, + event_v2: Event, + create_users: Callable[[str, str], Optional[User]], +): + original_owner = crud.get_owner(session_v2, event_v2) + assert original_owner + original_owner_events = user_crud.get_events( + session_v2, + original_owner, + ) + + original_owner_owned_events = user_crud.get_owned_events( + session_v2, + original_owner, + ) + + # Verify original values are 1. + assert len(original_owner_events) == 1 + assert len(original_owner_owned_events) == 1 + + user2 = create_users("user2", "email2@gmail.com") + assert user2 + + new_user_events = user_crud.get_events(session_v2, user2) + new_user_owned_events = user_crud.get_owned_events(session_v2, user2) + + # Verify default values are 0. + assert len(new_user_events) == 0 + assert len(new_user_owned_events) == 0 + + assert crud.change_owner(session_v2, event_v2, user2.id) + new_owner = crud.get_owner(session_v2, event_v2) + new_owner_events = user_crud.get_events(session_v2, user2) + new_owner_owned_events = user_crud.get_owned_events(session_v2, user2) + + # Verify new owner changed and values are updated to 1. + assert new_owner == user2 + assert len(new_owner_events) == 1 + assert len(new_owner_owned_events) == 1 + + original_owner_events = user_crud.get_events( + session_v2, + original_owner, + ) + original_owner_owned_events = user_crud.get_owned_events( + session_v2, + original_owner, + ) + + # Verify original_owner values are updated. + assert len(original_owner_events) == 1 + assert len(original_owner_owned_events) == 0 + + +def test_get_add_delete_members( + session_v2: Session, + create_users: Callable[[str, str], Optional[User]], + create_events: Callable[[int, datetime, datetime], Optional[Event]], +): + user1 = create_users("username1", "email1@gmail.com") + assert user1 + + event = create_events(user1.id, datetime.today(), datetime.today()) + assert event + + # Owner should be the only member. + members = crud.get_members(session_v2, event) + assert len(members) == 1 + + user2 = create_users("username2", "email2@gmail.com") + assert user2 + + # Add a new user. + assert crud.add_member(session_v2, event, user2) + members = crud.get_members(session_v2, event) + assert len(members) == 2 + + # Try and re-add the same users. It should fail. + crud.add_member(session_v2, event, user1) + members = crud.get_members(session_v2, event) + assert len(members) == 2 + + crud.add_member(session_v2, event, user2) + members = crud.get_members(session_v2, event) + assert len(members) == 2 + + # Remove user. Only owner left. + assert crud.remove_member(session_v2, event, user2) + members = crud.get_members(session_v2, event) + assert len(members) == 1 diff --git a/tests/crud/test_language.py b/tests/crud/test_language.py new file mode 100644 index 00000000..65119162 --- /dev/null +++ b/tests/crud/test_language.py @@ -0,0 +1,25 @@ +"""Tests for CRUD functions of the Language model.""" +import pytest +from sqlalchemy.orm import Session + +from app.database.crud import language as crud + +LANGUAGE_ID_TESTS = [ + (1, True), + (2, True), + (50, False), +] + + +@pytest.mark.parametrize("language_id, is_valid", LANGUAGE_ID_TESTS) +def test_get_by_id(session_v2: Session, language_id: int, is_valid: bool): + language = crud.get_by_id(session_v2, language_id) + if is_valid: + assert language + else: + assert language is None + + +def test_get_all_users(session_v2: Session): + languages = crud.get_all(session_v2) + assert len(languages) > 0 diff --git a/tests/crud/test_schemas.py b/tests/crud/test_schemas.py new file mode 100644 index 00000000..899b53fa --- /dev/null +++ b/tests/crud/test_schemas.py @@ -0,0 +1,68 @@ +"""Tests for Pydantic schema models.""" +import pytest +from pydantic import SecretStr, ValidationError + +from app.database.schemas_v2 import EventAll, EventCreate, Language, UserCreate + + +class TestUser: + @staticmethod + def test_user_create(user_create: UserCreate): + assert user_create + + USERNAME_ERRORS = [ + "u", + "usernameusernameusername", + "username!#$", + ] + + @staticmethod + @pytest.mark.parametrize("username", USERNAME_ERRORS) + def test_user_create_username_errors( + user_create: UserCreate, + username: str, + ): + with pytest.raises(ValidationError): + user_create.username = username + UserCreate(**user_create.dict()) + + @staticmethod + def test_user_create_full_name_errors(user_create: UserCreate): + with pytest.raises(ValidationError): + user_create.full_name = "a" + UserCreate(**user_create.dict()) + + PASSWORD_ERRORS = [ + ("p", "p"), + ("password is very very long", "password is very very long"), + ("password!#$", "password"), + ] + + @staticmethod + @pytest.mark.parametrize("password, confirm", PASSWORD_ERRORS) + def test_user_create_password_errors( + user_create: UserCreate, + password: str, + confirm: str, + ): + with pytest.raises(ValidationError): + user_create.password = password + user_create.confirm_password = SecretStr(confirm) + UserCreate(**user_create.dict()) + + +class TestEvent: + @staticmethod + def test_event_create(event_create: EventCreate): + assert event_create + + @staticmethod + def test_event_all(event_create: EventCreate): + event_all = EventAll(**event_create.dict(), id=1) + assert event_all + + +class TestLanguage: + @staticmethod + def test_language(language: Language): + assert language diff --git a/tests/crud/test_user.py b/tests/crud/test_user.py new file mode 100644 index 00000000..0df4561d --- /dev/null +++ b/tests/crud/test_user.py @@ -0,0 +1,499 @@ +"""Tests for CRUD functions of the User model.""" +from datetime import datetime +from typing import Any, Callable, Optional + +import pytest +from pydantic import SecretStr +from sqlalchemy.orm import Session + +from app.database.crud import event as event_crud +from app.database.crud import user as crud +from app.database.models_v2 import User as UserOrm +from app.database.schemas_v2 import Event, Language, User, UserCreate +from app.internal.privacy import PrivacyKinds +from tests.crud.test_util import ( + get_boolean_getter_function, + get_getter_function, + get_setter_function, +) + + +def test_create(session_v2: Session, user_create: UserCreate): + # Unique creation - successful. + assert crud.create(session_v2, user_create) + + # Creation with duplicate unique fields - unsuccessful. + assert crud.create(session_v2, user_create) is None + + +def test_delete(session_v2: Session, user_v2: User, event_v2: Event): + events = crud.get_events(session_v2, user_v2) + assert len(events) == 1 + + # Deletion of existing user - successful. + assert crud.delete(session_v2, user_v2) + + # User event's should be deleted automatically. + events = crud.get_events(session_v2, user_v2) + assert len(events) == 0 + + # Verify event is not in the database. + event_from_database = event_crud.get_by_id(session_v2, event_v2.id) + assert event_from_database is None + + # Deletion of non-existing user - unsuccessful. + assert crud.delete(session_v2, user_v2) is False + + +def test_get_database_user_by_id(session_v2: Session, user_v2: User): + database_user = crud.get_database_user_by_id(session_v2, user_v2.id) + assert isinstance(database_user, UserOrm) + assert database_user.id == user_v2.id + assert database_user.username == user_v2.username + assert crud.get_by_id(session_v2, 2) is None + + +def test_get_by_id(session_v2: Session, user_v2: User): + database_user = crud.get_by_id(session_v2, user_v2.id) + assert isinstance(database_user, User) + assert database_user.id == user_v2.id + assert database_user.username == user_v2.username + assert crud.get_by_id(session_v2, 2) is None + + +def test_get_by_username(session_v2: Session, user_v2: User): + schema_model = crud.get_by_username(session_v2, user_v2.username) + assert isinstance(schema_model, User) + assert schema_model.id == user_v2.id + assert schema_model.username == user_v2.username + assert crud.get_by_username(session_v2, "bad username") is None + + +def test_get_by_email(session_v2: Session, user_v2: User): + schema_model = crud.get_by_email(session_v2, "email@gmail.com") + assert isinstance(schema_model, User) + assert schema_model.id == user_v2.id + assert schema_model.username == user_v2.username + assert crud.get_by_email(session_v2, "bad email") is None + + +@pytest.mark.parametrize("number_of_users", [0, 1, 2]) +def test_get_all_users( + session_v2: Session, + create_users: Callable[[str, str], Optional[User]], + number_of_users: int, +): + for i in range(number_of_users): + create_users(f"username{i}", f"email{i}@gmail.com") + assert len(crud.get_all(session_v2)) == number_of_users + + +COLUMNS_STANDARD_TESTS = [ + ("avatar", "file_path"), + ("description", "description"), + ("target_weight", 55.5), +] + + +@pytest.mark.parametrize("column_name, value", COLUMNS_STANDARD_TESTS) +def test_get_set_value_standard_fields( + session_v2: Session, + user_v2: User, + column_name: str, + value: Any, +): + setter_function = get_setter_function(crud, column_name) + assert setter_function(session_v2, user_v2, value) + getter_function = get_getter_function(crud, column_name) + assert getter_function + assert getter_function(session_v2, user_v2) == value + + +COLUMNS_BOOLEAN_TESTS = [ + "active", + "admin", +] + + +@pytest.mark.parametrize("column_name", COLUMNS_BOOLEAN_TESTS) +def test_get_set_value_boolean_fields( + session_v2: Session, + user_v2: User, + column_name: str, +): + for state in [True, False]: + setter_function = get_setter_function(crud, column_name) + assert setter_function(session_v2, user_v2, state) + getter_function = get_boolean_getter_function(crud, column_name) + assert getter_function(session_v2, user_v2) == state + + +COLUMN_UNIQUE_TESTS = [ + ("telegram_id", "TEST1234567TEST"), + ("email", "test_email@gmail.com"), + ("username", "test_username"), +] + + +@pytest.mark.parametrize("column_name, value", COLUMN_UNIQUE_TESTS) +def test_get_set_unique_columns( + session_v2: Session, + create_users: Callable[[str, str], Optional[User]], + column_name: str, + value: str, +): + users = [] + for i in range(2): + users.append(create_users(f"username{i}", f"email{i}@gmail.com")) + + user1 = users[0] + user2 = users[1] + assert user1 + assert user2 + + getter_function = get_getter_function(crud, column_name) + setter_function = get_setter_function(crud, column_name) + + # Set a unique value for user1. + change_one = True + assert setter_function(session_v2, user1, value) is change_one + + database_result_user1 = _get_getter_result( + session_v2, + user1, + getter_function, + column_name, + ) + + _validate_getter_result(database_result_user1, None, value, change_one) + + original_result_user2 = _get_getter_result( + session_v2, + user2, + getter_function, + column_name, + ) + + # User2 cannot have the same unique value. + change_two = False + assert setter_function(session_v2, user2, value) is change_two + + database_result_user2 = _get_getter_result( + session_v2, + user2, + getter_function, + column_name, + ) + + _validate_getter_result( + database_result_user2, + original_result_user2, + value, + change_two, + ) + + # Change the unique value for user1. + change_three = True + new_value = f"new_{value}" + assert setter_function(session_v2, user1, new_value) is change_three + + database_result_user1 = _get_getter_result( + session_v2, + user1, + getter_function, + column_name, + ) + + _validate_getter_result( + database_result_user1, + None, + new_value, + change_three, + ) + + # User2 can now have the previously unavailable unique value. + change_four = True + assert setter_function(session_v2, user2, value) is change_four + database_result_user2 = _get_getter_result( + session_v2, + user2, + getter_function, + column_name, + ) + _validate_getter_result(database_result_user2, None, value, change_four) + + +EMAIL_TESTS = [ + (None, False), + ("", False), + ("b", False), + ("b@", False), + ("b@com", False), + ("b@.com", False), + ("@", False), + ("@.com", False), + ("b.com", False), + ("b@c.com", True), +] + + +@pytest.mark.parametrize("email, is_valid", EMAIL_TESTS) +def test_get_set_email( + session_v2: Session, + user_v2: User, + email: str, + is_valid: bool, +): + original_email = "email@gmail.com" + assert crud.set_email(session_v2, user_v2, email) == is_valid + database_email = crud.get_email(session_v2, user_v2) + _validate_getter_result(database_email, original_email, email, is_valid) + + +FULL_NAME_TESTS = [ + (None, False), + ("", False), + ("b", False), + ("ba", True), + ("b" * 21, True), +] + + +@pytest.mark.parametrize("name, is_valid", FULL_NAME_TESTS) +def test_get_set_full_name( + session_v2: Session, + user_v2: User, + name: str, + is_valid: bool, +): + original_full_name = user_v2.full_name + assert crud.set_full_name(session_v2, user_v2, name) is is_valid + _validate_getter_result( + user_v2.full_name, + original_full_name, + name, + is_valid, + ) + + +def test_get_set_language(session_v2: Session, user_v2: User): + language = crud.get_language(session_v2, user_v2) + assert language == Language(id=1, name="English", code="en") + assert crud.set_language(session_v2, user_v2, 2) + language = crud.get_language(session_v2, user_v2) + assert language == Language(id=2, name="עברית", code="he") + + +PASSWORD_TESTS = [ + (None, False), + ("", False), + ("a", False), + ("ab", False), + ("a" * 21, False), + (1, False), + ("abc", True), + ("abc123!@#$_", True), +] + + +@pytest.mark.parametrize("password, is_valid", PASSWORD_TESTS) +def test_set_password( + session_v2: Session, + user_v2: User, + password: str, + is_valid: bool, +): + secret_password: Any + try: + secret_password = SecretStr(password) + except TypeError: + secret_password = password + + result = crud.set_password(session_v2, user_v2, secret_password) + assert result is is_valid + + +PRIVACY_TESTS = [ + (None, False), + ("", False), + ("bad_key", False), + (1, False), + (PrivacyKinds.Public, True), + (PrivacyKinds.Private, True), + (PrivacyKinds.Hidden, True), +] + + +@pytest.mark.parametrize("privacy, is_valid", PRIVACY_TESTS) +def test_get_set_privacy( + session_v2: Session, + user_v2: User, + privacy: Any, + is_valid: bool, +): + original_privacy = crud.get_privacy(session_v2, user_v2) + assert crud.set_privacy(session_v2, user_v2, privacy) is is_valid + database_privacy = crud.get_privacy(session_v2, user_v2) + _validate_getter_result( + database_privacy, + original_privacy, + privacy, + is_valid, + ) + + +USERNAME_TESTS = [ + (None, False), + ("", False), + ("a", False), + ("ab", False), + ("a" * 21, False), + ("abc%@", False), + ("abc", True), + ("abc12309_", True), +] + + +@pytest.mark.parametrize("username, is_valid", USERNAME_TESTS) +def test_set_username( + session_v2: Session, + user_v2: User, + username: str, + is_valid: bool, +): + original_username = user_v2.username + assert crud.set_username(session_v2, user_v2, username) is is_valid + _validate_getter_result( + user_v2.username, + original_username, + username, + is_valid, + ) + + +@pytest.mark.parametrize("number_of_events", [0, 1, 2]) +def test_get_owned_events( + session_v2: Session, + create_users: Callable[[str, str], Optional[User]], + create_events: Callable[[int, datetime, datetime], Optional[Event]], + number_of_events: int, +): + user1 = create_users("username1", "email1@gmail.com") + assert user1 + for i in range(number_of_events): + create_events(user1.id, datetime.today(), datetime.today()) + events = crud.get_owned_events(session_v2, user1) + assert len(events) == number_of_events + + user2 = create_users("username2", "email2@gmail.com") + assert user2 + create_events(user2.id, datetime.today(), datetime.today()) + events2 = crud.get_owned_events(session_v2, user2) + assert len(events2) == 1 + + events = crud.get_owned_events(session_v2, user1) + assert len(events) == number_of_events + + +def test_get_events( + session_v2: Session, + create_users: Callable[[str, str], Optional[User]], + create_events: Callable[[int, datetime, datetime], Optional[Event]], +): + user1 = create_users("username1", "email1@gmail.com") + user2 = create_users("username2", "email2@gmail.com") + user3 = create_users("username3", "email3@gmail.com") + assert user1 + assert user2 + assert user3 + + create_events(user1.id, datetime.today(), datetime.today()) + + event2 = create_events(user1.id, datetime.today(), datetime.today()) + assert event2 + event_crud.add_member(session_v2, event2, user2) + + event3 = create_events(user2.id, datetime.today(), datetime.today()) + assert event3 + event_crud.add_member(session_v2, event3, user1) + + user1_events = crud.get_events(session_v2, user1) + assert len(user1_events) == 3 + + user2_events = crud.get_events(session_v2, user2) + assert len(user2_events) == 2 + + user3_events = crud.get_events(session_v2, user3) + assert len(user3_events) == 0 + + +def test_get_delete_google_calendar_events( + session_v2: Session, + user_v2: User, + create_events: Callable[[int, datetime, datetime], Optional[Event]], +): + number_of_events = 4 + number_of_google_events = 2 + events = [] + + # Create number_of_events events. + for i in range(number_of_events): + event = create_events(user_v2.id, datetime.today(), datetime.today()) + assert event + events.append(event) + + # Mark number_of_google_events events as Google Calendar events. + for event in events[::2]: + event_orm = event_crud.get_database_event_by_id(session_v2, event.id) + assert event_orm + event_orm.is_google_event = True + session_v2.commit() + + # Verify number of events is number_of_events. + all_events = crud.get_events(session_v2, user_v2) + assert len(all_events) == number_of_events + + # Verify number of Google events is number_of_google_events. + google_events = crud.get_google_calender_events(session_v2, user_v2) + assert len(google_events) == number_of_google_events + + assert crud.delete_all_google_calendar_events(session_v2, user_v2) + + # Verify there are no more Google events. + google_events = crud.get_google_calender_events(session_v2, user_v2) + assert len(google_events) == 0 + + # Verify current number of events is + # number_of_events - number_of_google_events. + remaining_events = number_of_events - number_of_google_events + all_events = crud.get_events(session_v2, user_v2) + owned_events = crud.get_owned_events(session_v2, user_v2) + assert ( + len(all_events) == remaining_events + and len(owned_events) == remaining_events + ) + + +def _get_getter_result( + session_v2: Session, + user: User, + getter: Optional[Callable], + column_name: str, +) -> Any: + if column_name == "username": + return user.username + else: + assert getter + return getter(session_v2, user) + + +def _validate_getter_result( + current_value: Any, + original_value: Any, + new_value: str, + is_valid: bool, +): + if is_valid: + assert current_value == new_value + else: + assert current_value == original_value diff --git a/tests/crud/test_util.py b/tests/crud/test_util.py new file mode 100644 index 00000000..1e9a10cd --- /dev/null +++ b/tests/crud/test_util.py @@ -0,0 +1,70 @@ +"""Shared code for CRUD tests.""" +from collections import Callable +from types import ModuleType +from typing import Any, Optional + +from pydantic import BaseModel + + +def get_attribute_value(model: BaseModel, attribute: str) -> Any: + """Returns the value for a model's attribute. + + Args: + model: The model to get the value from. + attribute: The attribute to get the value for. + + Returns: + The attribute value. + """ + try: + return getattr(model, attribute) + except AttributeError: + return None + + +def get_getter_function( + crud: ModuleType, + column_name: str, +) -> Optional[Callable]: + """Returns a get function from a CRUD module. + + Args: + crud: A CRUD module. + column_name: The column to get the function for. + + Returns: + A get function for an object's column. + """ + try: + return getattr(crud, f"get_{column_name}") + except AttributeError: + return None + + +def get_boolean_getter_function( + crud: ModuleType, + column_name: str, +) -> Callable: + """Returns a get function from a CRUD module for a boolean field. + + Args: + crud: A CRUD module. + column_name: The column to get the function for. + + Returns: + A get function for an object's boolean column. + """ + return getattr(crud, f"is_{column_name}") + + +def get_setter_function(crud: ModuleType, column_name: str) -> Callable: + """Returns a set function from a CRUD module. + + Args: + crud: A CRUD module. + column_name: The column to get the function for. + + Returns: + A set function for an object's column. + """ + return getattr(crud, f"set_{column_name}") diff --git a/tests/test_language.py b/tests/test_language.py index 4d0746ad..b84b5833 100644 --- a/tests/test_language.py +++ b/tests/test_language.py @@ -1,5 +1,3 @@ -from pathlib import Path - import pytest from app.dependencies import templates @@ -11,50 +9,52 @@ class TestLanguage: # (currently 'en' and 'he') are set to the default language setting # at config.WEBSITE_LANGUAGE, which is currently set to 'en' (English). LANGUAGE_TESTS = [ - ('en', 'test python translation', True), - ('he', 'בדיקת תרגום בפייתון', True), - (None, 'test python translation', False), - ('', 'test python translation', False), - ('de', 'test python translation', False), - (["en"], 'test python translation', False), - (3, 'test python translation', False), + ("en", "test python translation", True), + ("he", "בדיקת תרגום בפייתון", True), + (None, "test python translation", False), + ("", "test python translation", False), + ("de", "test python translation", False), + (["en"], "test python translation", False), + (3, "test python translation", False), ] NUMBER_OF_LANGUAGES = 2 @staticmethod @pytest.mark.parametrize( - "language_code, translation, is_valid", LANGUAGE_TESTS) + "language_code, translation, is_valid", + LANGUAGE_TESTS, + ) def test_gettext_python(language_code, translation, is_valid): languages.set_ui_language(language_code) # i18n: String used in testing. Do not change. gettext_translation = _("test python translation") - assert ((is_valid and gettext_translation == translation) - or (not is_valid and gettext_translation == translation)) + assert (is_valid and gettext_translation == translation) or ( + not is_valid and gettext_translation == translation + ) @staticmethod @pytest.mark.parametrize( - "language_code, translation, is_valid", LANGUAGE_TESTS) + "language_code, translation, is_valid", + LANGUAGE_TESTS, + ) def test_gettext_html(language_code, translation, is_valid): languages.set_ui_language(language_code) template = templates.env.from_string( - '{{ gettext("test python translation") }}') + "{{ gettext('test python translation') }}", + ) text = template.render() - assert ((is_valid and translation in text) - or (not is_valid and translation in text)) + assert (is_valid and translation in text) or ( + not is_valid and translation in text + ) @staticmethod def test_get_supported_languages(): number_of_languages = len(list(languages._get_supported_languages())) assert number_of_languages == TestLanguage.NUMBER_OF_LANGUAGES - @staticmethod - def test_get_language_directory(): - pytest.MonkeyPatch().setattr(Path, 'is_dir', lambda x: True) - assert languages._get_language_directory() - @staticmethod def test_get_display_language(): # TODO: Waiting for user registration.