Skip to content

Commit

Permalink
DRY models
Browse files Browse the repository at this point in the history
  • Loading branch information
yedpodtrzitko committed Sep 8, 2024
1 parent deeaef7 commit 704237f
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 65 deletions.
94 changes: 42 additions & 52 deletions tagstudio/src/core/library/alchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,69 @@

from dataclasses import dataclass
from enum import Enum
from typing import Union, Any, TYPE_CHECKING
from typing import Any, TYPE_CHECKING

from sqlalchemy import ForeignKey, ForeignKeyConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, mapped_column, relationship, declared_attr

from .db import Base
from .enums import FieldTypeEnum

if TYPE_CHECKING:
from .models import Entry, Tag, LibraryField

Field = Union["TextField", "TagBoxField", "DatetimeField"]

class BaseField(Base):
__abstract__ = True

class BooleanField(Base):
__tablename__ = "boolean_fields"
@declared_attr
def id(cls) -> Mapped[int]:
return mapped_column(primary_key=True, autoincrement=True)

id: Mapped[int] = mapped_column(primary_key=True)
type_key: Mapped[str] = mapped_column(ForeignKey("library_fields.key"))
type: Mapped[LibraryField] = relationship(foreign_keys=[type_key], lazy=False)
@declared_attr
def type_key(cls) -> Mapped[str]:
return mapped_column(ForeignKey("library_fields.key"))

entry_id: Mapped[int] = mapped_column(ForeignKey("entries.id"))
entry: Mapped[Entry] = relationship()
@declared_attr
def type(cls) -> Mapped[LibraryField]:
return relationship(foreign_keys=[cls.type_key], lazy=False) # type: ignore

value: Mapped[bool]
position: Mapped[int]
@declared_attr
def entry_id(cls) -> Mapped[int]:
return mapped_column(ForeignKey("entries.id"))

def __key(self):
return (self.type, self.value)
@declared_attr
def entry(cls) -> Mapped[Entry]:
return relationship(foreign_keys=[cls.entry_id]) # type: ignore

@declared_attr
def position(cls) -> Mapped[int]:
return mapped_column()

def __hash__(self):
return hash(self.__key())

def __key(self):
raise NotImplementedError

value: Any


class BooleanField(BaseField):
__tablename__ = "boolean_fields"

value: Mapped[bool]

def __key(self):
return (self.type, self.value)

def __eq__(self, value) -> bool:
if isinstance(value, BooleanField):
return self.__key() == value.__key()
raise NotImplementedError


class TextField(Base):
class TextField(BaseField):
__tablename__ = "text_fields"
# constrain for combination of: entry_id, type_key and position
__table_args__ = (
Expand All @@ -51,21 +74,10 @@ class TextField(Base):
),
)

id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
type_key: Mapped[str] = mapped_column(ForeignKey("library_fields.key"))
type: Mapped[LibraryField] = relationship(foreign_keys=[type_key], lazy=False)

entry_id: Mapped[int] = mapped_column(ForeignKey("entries.id"))
entry: Mapped[Entry] = relationship(foreign_keys=[entry_id])

value: Mapped[str | None]
position: Mapped[int]

def __key(self):
return (self.type, self.value)

def __hash__(self):
return hash(self.__key())
def __key(self) -> tuple:
return self.type, self.value

def __eq__(self, value) -> bool:
if isinstance(value, TextField):
Expand All @@ -75,18 +87,10 @@ def __eq__(self, value) -> bool:
raise NotImplementedError


class TagBoxField(Base):
class TagBoxField(BaseField):
__tablename__ = "tag_box_fields"

id: Mapped[int] = mapped_column(primary_key=True)
type_key: Mapped[str] = mapped_column(ForeignKey("library_fields.key"))
type: Mapped[LibraryField] = relationship(foreign_keys=[type_key], lazy=False)

entry_id: Mapped[int] = mapped_column(ForeignKey("entries.id"))
entry: Mapped[Entry] = relationship(foreign_keys=[entry_id])

tags: Mapped[set[Tag]] = relationship(secondary="tag_fields")
position: Mapped[int]

def __key(self):
return (
Expand All @@ -99,34 +103,20 @@ def value(self) -> None:
"""For interface compatibility with other field types."""
return None

def __hash__(self):
return hash(self.__key())

def __eq__(self, value) -> bool:
if isinstance(value, TagBoxField):
return self.__key() == value.__key()
raise NotImplementedError


class DatetimeField(Base):
class DatetimeField(BaseField):
__tablename__ = "datetime_fields"

id: Mapped[int] = mapped_column(primary_key=True)
type_key: Mapped[str] = mapped_column(ForeignKey("library_fields.key"))
type: Mapped[LibraryField] = relationship(foreign_keys=[type_key], lazy=False)

entry_id: Mapped[int] = mapped_column(ForeignKey("entries.id"))
entry: Mapped[Entry] = relationship(foreign_keys=[entry_id])

value: Mapped[str | None]
position: Mapped[int]

def __key(self):
return (self.type, self.value)

def __hash__(self):
return hash(self.__key())

def __eq__(self, value) -> bool:
if isinstance(value, DatetimeField):
return self.__key() == value.__key()
Expand Down
10 changes: 5 additions & 5 deletions tagstudio/src/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
TagBoxField,
TextField,
_FieldID,
Field,
BaseField,
)
from .joins import TagSubtag, TagField
from .models import Entry, Preferences, Tag, TagAlias, LibraryField, Folder
Expand Down Expand Up @@ -488,7 +488,7 @@ def remove_tag_from_field(self, tag: Tag, field: TagBoxField) -> None:

def update_field_position(
self,
field_class: Type[Field],
field_class: Type[BaseField],
field_type: str,
entry_ids: list[int] | int,
):
Expand All @@ -512,15 +512,15 @@ def update_field_position(

# Reassign `order` starting from 0
for index, row in enumerate(rows):
row.position = index # type: ignore
row.position = index
session.add(row)
session.flush()
if rows:
session.commit()

def remove_entry_field(
self,
field: Field,
field: BaseField,
entry_ids: list[int],
) -> None:
FieldClass = type(field)
Expand Down Expand Up @@ -554,7 +554,7 @@ def remove_entry_field(
def update_entry_field(
self,
entry_ids: list[int] | int,
field: Field,
field: BaseField,
content: str | datetime | set[Tag],
):
if isinstance(entry_ids, int):
Expand Down
8 changes: 4 additions & 4 deletions tagstudio/src/core/library/alchemy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from .enums import TagColor
from .fields import (
DatetimeField,
Field,
TagBoxField,
TextField,
FieldTypeEnum,
_FieldID,
BaseField,
)
from .joins import TagSubtag
from ...constants import TAG_FAVORITE, TAG_ARCHIVED
Expand Down Expand Up @@ -134,8 +134,8 @@ class Entry(Base):
)

@property
def fields(self) -> list[Field]:
fields: list[Field] = []
def fields(self) -> list[BaseField]:
fields: list[BaseField] = []
fields.extend(self.tag_box_fields)
fields.extend(self.text_fields)
fields.extend(self.datetime_fields)
Expand Down Expand Up @@ -171,7 +171,7 @@ def __init__(
self,
path: Path,
folder: Folder,
fields: list[Field] | None = None,
fields: list[BaseField] | None = None,
) -> None:
self.path = path
self.folder = folder
Expand Down
8 changes: 4 additions & 4 deletions tagstudio/src/qt/widgets/preview_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
TagBoxField,
DatetimeField,
FieldTypeEnum,
Field,
_FieldID,
TextField,
BaseField,
)
from src.qt.helpers.file_opener import FileOpenerLabel, FileOpenerHelper, open_file
from src.qt.modals.add_field import AddFieldModal
Expand Down Expand Up @@ -719,7 +719,7 @@ def set_tags_updated_slot(self, slot: object):
self.tags_updated.connect(slot)
self.is_connected = True

def write_container(self, index: int, field: Field, is_mixed: bool = False):
def write_container(self, index: int, field: BaseField, is_mixed: bool = False):
"""Update/Create data for a FieldContainer.
:param is_mixed: Relevant when multiple items are selected. If True, field is not present in all selected items
Expand Down Expand Up @@ -930,7 +930,7 @@ def write_container(self, index: int, field: Field, is_mixed: bool = False):
container.setHidden(False)
self.place_add_field_button()

def remove_field(self, field: Field):
def remove_field(self, field: BaseField):
"""Remove a field from all selected Entries."""
logger.info("removing field", field=field, selected=self.selected)
entry_ids = []
Expand All @@ -945,7 +945,7 @@ def remove_field(self, field: Field):
if field.type_key == _FieldID.TAGS_META.value:
self.driver.update_badges(self.selected)

def update_field(self, field: Field, content: str) -> None:
def update_field(self, field: BaseField, content: str) -> None:
"""Remove a field from all selected Entries, given a field object."""
assert isinstance(
field, (TextField, DatetimeField, TagBoxField)
Expand Down

0 comments on commit 704237f

Please sign in to comment.