Skip to content

Commit

Permalink
Merge pull request #140 from opendatalab/feat/json-pre-annotation
Browse files Browse the repository at this point in the history
Feat/json pre annotation
  • Loading branch information
gary-Shen authored Nov 14, 2024
2 parents 7dc220a + 2b4a9bd commit 9e5b3ab
Show file tree
Hide file tree
Showing 9 changed files with 480 additions and 131 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""add_sample_name_and_data_to_task_pre_annotation
1. Add sample_name column to task_pre_annotation table.
2. Read jsonl file content and store it in data column of task_pre_annotation table.
Revision ID: eb9c5b98168b
Revises: bc8fcb35b66b
Create Date: 2024-11-13 14:08:09.374271
"""
import json
from typing import List
from alembic import op, context
import sqlalchemy as sa
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.automap import automap_base
from sqlalchemy import select
from labelu.internal.common.config import settings

from labelu.internal.domain.models.pre_annotation import TaskPreAnnotation
from labelu.internal.adapter.persistence import crud_attachment

Base = automap_base()


# revision identifiers, used by Alembic.
revision = 'eb9c5b98168b'
down_revision = 'bc8fcb35b66b'
branch_labels = None
depends_on = None

def index_exists(bind, table_name, index_name):
    """Report whether ``table_name`` already has an index named ``index_name``.

    Args:
        bind: Connection or engine handed to ``sa.inspect``.
        table_name: Table whose indexes are listed.
        index_name: Index name to look for.

    Returns:
        ``True`` if one of the indexes reported for the table carries that
        name, ``False`` otherwise.
    """
    reported = sa.inspect(bind).get_indexes(table_name)
    return any(entry['name'] == index_name for entry in reported)

def read_jsonl_file(db: Session, file_id: int) -> List[dict]:
    """Load the JSON Lines attachment identified by ``file_id``.

    Args:
        db: Session used to look up the attachment record.
        file_id: Identifier passed to ``crud_attachment.get``.

    Returns:
        One parsed value per non-blank line, in file order. An empty list is
        returned when the attachment record is missing, when the file is not
        present under ``settings.MEDIA_ROOT``, or when the attachment's
        filename does not end with ``.jsonl``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    attachment = crud_attachment.get(db, file_id)
    if attachment is None:
        return []

    # Leading "/" is stripped before joining; presumably so the stored path is
    # always resolved relative to MEDIA_ROOT -- confirm against how paths are saved.
    file_full_path = settings.MEDIA_ROOT.joinpath(attachment.path.lstrip("/"))

    # Only existing files whose recorded filename is *.jsonl are parsed.
    if not file_full_path.exists() or not attachment.filename.endswith('.jsonl'):
        return []

    with open(file_full_path, "r", encoding="utf-8") as f:
        # Skip whitespace-only lines (e.g. a stray blank line at the end of the
        # file): json.loads("") raises and would abort the whole migration.
        return [json.loads(line) for line in f if line.strip()]

def upgrade() -> None:
    """Add ``sample_name`` to ``task_pre_annotation`` and split rows per JSONL record.

    Steps, all performed inside one migration transaction:

    1. Add the ``sample_name`` column and the
       ``idx_pre_annotation_sample_name`` index when they are not present yet.
    2. For every existing row, read its JSONL attachment and create one new
       row per record, storing the record as JSON text in ``data``.
    3. Delete the rows that existed before the migration ran.

    Any exception rolls the session back and is re-raised unchanged.
    """
    bind = op.get_bind()
    Base.prepare(autoload_with=bind, reflect=True)
    session = sessionmaker(bind=bind)()

    try:
        with context.begin_transaction():
            # The reflected class only exposes the attribute when the column
            # already exists, which keeps this step re-runnable.
            if not hasattr(Base.classes.task_pre_annotation, "sample_name"):
                op.add_column(
                    "task_pre_annotation",
                    sa.Column("sample_name", sa.String(255), index=True, comment="One of the sample names of the task"),
                )

            # Create the index only when it is not there yet.
            if not index_exists(bind, "task_pre_annotation", "idx_pre_annotation_sample_name"):
                op.create_index("idx_pre_annotation_sample_name", "task_pre_annotation", ["sample_name"])

            # Snapshot the rows that exist before any new ones are added.
            # select() takes the entity directly, matching the attachment
            # query below; the list form select([...]) is the legacy style.
            existing_rows = session.execute(
                select(Base.classes.task_pre_annotation)
            ).scalars().all()

            for existing in existing_rows:
                file_id = existing.file_id

                # One new row per record of the JSONL attachment.
                for record in read_jsonl_file(session, file_id):
                    sample_name = record.get("sample_name")

                    sample_file = None
                    if sample_name:
                        # The stored filename includes sample_name: the full
                        # name looks like xxxxxxxxx-sample_name.png, the short
                        # name is sample_name.png.
                        query = select(Base.classes.task_attachment).where(
                            Base.classes.task_attachment.task_id == existing.task_id,
                            Base.classes.task_attachment.filename.contains(sample_name),
                        )
                        sample_file = session.execute(query).scalars().first()

                    session.add(
                        TaskPreAnnotation(
                            task_id=existing.task_id,
                            # Prefer the full attachment filename; when no
                            # attachment matches, keep the name from the record
                            # instead of failing on a missing attachment.
                            sample_name=sample_file.filename if sample_file is not None else sample_name,
                            file_id=file_id,
                            created_by=existing.created_by,
                            updated_by=existing.updated_by,
                            data=json.dumps(record),
                        )
                    )

            # Remove the pre-migration rows now that they have been expanded.
            for existing in existing_rows:
                session.delete(existing)

            session.commit()

    except Exception:
        session.rollback()
        # Bare raise keeps the original traceback intact.
        raise

    finally:
        session.close()


def downgrade() -> None:
    """Revert the ``sample_name`` migration on ``task_pre_annotation``.

    Every row that has a ``sample_name`` is replaced by a row carrying the
    same ``task_id``, ``file_id``, ``created_by``, ``updated_by`` and ``data``
    but no ``sample_name``; then the ``idx_pre_annotation_sample_name`` index
    and the ``sample_name`` column are dropped if present.

    Any exception rolls the session back and is re-raised unchanged.
    """
    bind = op.get_bind()
    Base.prepare(autoload_with=bind, reflect=True)
    session = sessionmaker(bind=bind)()

    try:
        with context.begin_transaction():
            reflected = Base.classes.task_pre_annotation

            # Read the rows ONCE, before anything is marked for deletion.
            # Querying again after session.delete() would autoflush the
            # deletes first and find nothing left to restore.
            expanded_rows = session.query(reflected).filter(
                reflected.sample_name.isnot(None)
            ).all()

            for row in expanded_rows:
                # Restore a row without sample_name ...
                session.add(
                    reflected(
                        task_id=row.task_id,
                        file_id=row.file_id,
                        created_by=row.created_by,
                        updated_by=row.updated_by,
                        data=row.data,
                    )
                )
                # ... and remove the row created by the upgrade.
                session.delete(row)

            # Write the pending inserts/deletes while the table still matches
            # the reflected mapping (which includes sample_name), i.e. before
            # the column is dropped below.
            session.flush()

            # Drop the index if it exists.
            if index_exists(bind, "task_pre_annotation", "idx_pre_annotation_sample_name"):
                op.drop_index("idx_pre_annotation_sample_name", table_name="task_pre_annotation")

            # Drop the column if the reflected table has it.
            if hasattr(reflected, "sample_name"):
                op.drop_column("task_pre_annotation", "sample_name")

            session.commit()

    except Exception:
        session.rollback()
        # Bare raise keeps the original traceback intact.
        raise

    finally:
        session.close()
52 changes: 51 additions & 1 deletion labelu/internal/adapter/persistence/crud_attachment.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,60 @@
from datetime import datetime
from typing import List
from typing import List, Optional, Tuple

from sqlalchemy.orm import Session

from labelu.internal.domain.models.attachment import TaskAttachment

def list_by(
    db: Session,
    pageSize: int,
    ids: List[int] | None = [],
    task_id: Optional[int] = None,
    owner_id: Optional[int] = None,
    after: Optional[int] = None,
    before: Optional[int] = None,
    pageNo: Optional[int] = None,
    sorting: Optional[str] = None,
) -> Tuple[List[TaskAttachment], int]:
    """List non-deleted attachments with optional filters and pagination.

    Args:
        db: Session used to run the query.
        pageSize: Maximum number of rows to fetch.
        ids: Restrict results to these attachment ids. An empty list matches
            nothing; ``None`` disables the id filter.
        task_id: When truthy, restrict to this task.
        owner_id: When truthy, restrict to attachments created by this user.
        after: When truthy, only ids greater than this value.
        before: When truthy, only ids smaller than this value; rows are then
            fetched in descending id order and reversed before returning.
        pageNo: Page index; the offset is ``pageNo * pageSize`` (0 when falsy).
        sorting: Optional ``"field:order"`` string. It is applied in Python to
            the fetched page only; ``order == "desc"`` sorts descending, any
            other value ascending.

    Returns:
        A tuple ``(results, count)`` where ``count`` is the number of rows
        matching the filters before offset/limit are applied.
    """
    # `== None` is intentional: SQLAlchemy turns it into an IS NULL test.
    query_filter = [TaskAttachment.deleted_at == None]

    # The annotation allows None, but in_() needs a list of values, so only
    # add the id filter when a list was actually given. The default ([]) is
    # never mutated and still yields an empty result, as before.
    if ids is not None:
        query_filter.append(TaskAttachment.id.in_(ids))

    if owner_id:
        query_filter.append(TaskAttachment.created_by == owner_id)
    if before:
        query_filter.append(TaskAttachment.id < before)
    if after:
        query_filter.append(TaskAttachment.id > after)
    if task_id:
        query_filter.append(TaskAttachment.task_id == task_id)

    query = db.query(TaskAttachment).filter(*query_filter)

    # Default order is by id; with `before` the last items are needed, so
    # fetch in descending order.
    if before:
        query = query.order_by(TaskAttachment.id.desc())
    else:
        query = query.order_by(TaskAttachment.id.asc())

    count = query.count()

    results = (
        query.offset(offset=pageNo * pageSize if pageNo else 0)
        .limit(limit=pageSize)
        .all()
    )

    if sorting:
        field, order = sorting.split(":")
        results = sorted(
            results,
            key=lambda item: getattr(item, field),
            reverse=(order == "desc"),
        )

    if before:
        results.reverse()

    return results, count

def create(db: Session, attachment: TaskAttachment) -> TaskAttachment:
db.add(attachment)
Expand Down
59 changes: 50 additions & 9 deletions labelu/internal/adapter/persistence/crud_pre_annotation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from datetime import datetime
import json
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Union, Tuple


from sqlalchemy import case, text
from sqlalchemy.orm import Session
from sqlalchemy import or_
from fastapi.encoders import jsonable_encoder

from labelu.internal.domain.models.pre_annotation import TaskPreAnnotation
from labelu.internal.adapter.persistence import crud_attachment


def batch(db: Session, pre_annotations: List[TaskPreAnnotation]) -> List[TaskPreAnnotation]:
Expand All @@ -20,12 +17,13 @@ def list_by(
db: Session,
task_id: Union[int, None],
owner_id: int,
sample_name: str | None,
after: Union[int, None],
before: Union[int, None],
pageNo: Union[int, None],
pageSize: int,
sorting: Union[str, None],
) -> List[TaskPreAnnotation]:
) -> Tuple[List[TaskPreAnnotation], int]:

# query filter
query_filter = [TaskPreAnnotation.created_by == owner_id, TaskPreAnnotation.deleted_at == None]
Expand All @@ -35,25 +33,37 @@ def list_by(
query_filter.append(TaskPreAnnotation.id > after)
if task_id:
query_filter.append(TaskPreAnnotation.task_id == task_id)

if sample_name:
query_filter.append(or_(TaskPreAnnotation.sample_name == sample_name, TaskPreAnnotation.sample_name == sample_name[9:]))

query = db.query(TaskPreAnnotation).filter(*query_filter)

# default order by id, before need select last items
if before:
query = query.order_by(TaskPreAnnotation.id.desc())
else:
query = query.order_by(TaskPreAnnotation.id.asc())

count = query.count()

results = (
query.offset(offset=pageNo * pageSize if pageNo else 0)
.limit(limit=pageSize)
.all()
)

# No sorting
if sorting:
field, order = sorting.split(":")
if order == "desc":
results = sorted(results, key=lambda x: getattr(x, field), reverse=True)
else:
results = sorted(results, key=lambda x: getattr(x, field))

if before:
results.reverse()

return results
return results, count

def list_by_task_id_and_owner_id(db: Session, task_id: int, owner_id: int) -> Dict[str, List[TaskPreAnnotation]]:
pre_annotations = db.query(TaskPreAnnotation).filter(
Expand All @@ -64,6 +74,33 @@ def list_by_task_id_and_owner_id(db: Session, task_id: int, owner_id: int) -> Di

return pre_annotations

def list_by_task_id_and_file_id(db: Session, task_id: int, file_id: int, owner_id: int) -> List[TaskPreAnnotation]:
    """Return every non-deleted pre-annotation of a task that came from one file.

    Args:
        db: Session used to run the query.
        task_id: Task the pre-annotations belong to.
        file_id: Value matched against the ``file_id`` column.
        owner_id: Value matched against the ``created_by`` column.

    Returns:
        All matching rows, without pagination.
    """
    conditions = (
        TaskPreAnnotation.task_id == task_id,
        TaskPreAnnotation.created_by == owner_id,
        # `== None` is intentional: SQLAlchemy turns it into an IS NULL test.
        TaskPreAnnotation.deleted_at == None,
        TaskPreAnnotation.file_id == file_id,
    )
    matching = db.query(TaskPreAnnotation).filter(*conditions)
    return matching.all()

def list_by_task_id_and_owner_id_and_sample_name(db: Session, task_id: int, owner_id: int, sample_name: str) -> List[TaskPreAnnotation]:
    """List pre-annotations by task, owner and exact sample name, without pagination.

    Args:
        db: Session used to run the query.
        task_id: Task the pre-annotations belong to.
        owner_id: Value matched against the ``created_by`` column.
        sample_name: Value compared for equality with the ``sample_name`` column.

    Returns:
        All non-deleted rows matching the three criteria.
    """
    conditions = (
        TaskPreAnnotation.task_id == task_id,
        # `== None` is intentional: SQLAlchemy turns it into an IS NULL test.
        TaskPreAnnotation.deleted_at == None,
        TaskPreAnnotation.created_by == owner_id,
        TaskPreAnnotation.sample_name == sample_name,
    )
    matching = db.query(TaskPreAnnotation).filter(*conditions)
    return matching.all()

def get(db: Session, pre_annotation_id: int) -> TaskPreAnnotation:
return (
db.query(TaskPreAnnotation)
Expand Down Expand Up @@ -97,8 +134,12 @@ def delete(db: Session, pre_annotation_ids: List[int]) -> None:
)


def count(db: Session, task_id: int, owner_id: int) -> int:
def count(db: Session, task_id: int, owner_id: int, sample_name: str | None) -> int:
    """Count the non-deleted pre-annotations created by ``owner_id``.

    Args:
        db: Session used to run the query.
        task_id: When truthy, only rows of this task are counted.
        owner_id: Value matched against the ``created_by`` column.
        sample_name: When truthy, only rows whose ``sample_name`` equals it
            are counted.

    Returns:
        Number of matching rows.
    """
    # Mandatory criteria first; `== None` is intentional (SQL IS NULL test).
    query = db.query(TaskPreAnnotation).filter(
        TaskPreAnnotation.created_by == owner_id,
        TaskPreAnnotation.deleted_at == None,
    )

    # Optional criteria are chained on; successive filter() calls are ANDed.
    if task_id:
        query = query.filter(TaskPreAnnotation.task_id == task_id)
    if sample_name:
        query = query.filter(TaskPreAnnotation.sample_name == sample_name)

    return query.count()
Loading

0 comments on commit 9e5b3ab

Please sign in to comment.