feat: support LabelU JSON file as the pre-annotation file
gary-Shen committed Nov 14, 2024
1 parent 57eaae0 commit 2b4a9bd
Showing 3 changed files with 67 additions and 35 deletions.
11 changes: 10 additions & 1 deletion labelu/internal/adapter/persistence/crud_pre_annotation.py
@@ -2,6 +2,7 @@
from typing import Any, Dict, List, Union, Tuple

from sqlalchemy.orm import Session
+from sqlalchemy import or_
from fastapi.encoders import jsonable_encoder

from labelu.internal.domain.models.pre_annotation import TaskPreAnnotation
@@ -34,7 +35,7 @@ def list_by(
    query_filter.append(TaskPreAnnotation.task_id == task_id)

    if sample_name:
-        query_filter.append(TaskPreAnnotation.sample_name == sample_name)
+        query_filter.append(or_(TaskPreAnnotation.sample_name == sample_name, TaskPreAnnotation.sample_name == sample_name[9:]))

    query = db.query(TaskPreAnnotation).filter(*query_filter)

@@ -73,6 +74,14 @@ def list_by_task_id_and_owner_id(db: Session, task_id: int, owner_id: int) -> Di

    return pre_annotations

+def list_by_task_id_and_file_id(db: Session, task_id: int, file_id: int, owner_id: int) -> List[TaskPreAnnotation]:
+    return db.query(TaskPreAnnotation).filter(
+        TaskPreAnnotation.task_id == task_id,
+        TaskPreAnnotation.created_by == owner_id,
+        TaskPreAnnotation.deleted_at == None,
+        TaskPreAnnotation.file_id == file_id
+    ).all()
+
def list_by_task_id_and_owner_id_and_sample_name(db: Session, task_id: int, owner_id: int, sample_name: str) -> List[TaskPreAnnotation]:
    """list pre annotations by task_id, owner_id and sample_name without pagination
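A note on the lookup change above: stored sample names appear to carry a fixed nine-character prefix (for example an eight-digit ordinal plus a hyphen, such as "00000001-"), so the or_ filter matches a name both with and without that prefix, while the new list_by_task_id_and_file_id helper scopes pre-annotations to one uploaded file and respects the soft-delete column. A minimal runnable sketch of the or_ condition, using a stand-in model; the prefix width and example names are assumptions for illustration, not taken from the repository:

    from sqlalchemy import Column, Integer, String, or_, select
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class PreAnnotation(Base):  # stand-in for the real TaskPreAnnotation model
        __tablename__ = "pre_annotation"
        id = Column(Integer, primary_key=True)
        sample_name = Column(String)

    name = "00000001-image.jpg"  # hypothetical prefixed sample name
    stmt = select(PreAnnotation).where(
        or_(
            PreAnnotation.sample_name == name,
            PreAnnotation.sample_name == name[9:],  # bare name: "image.jpg"
        )
    )
    print(stmt)  # renders: ... WHERE sample_name = :p1 OR sample_name = :p2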
73 changes: 48 additions & 25 deletions labelu/internal/application/service/pre_annotation.py
@@ -21,26 +21,32 @@
from labelu.internal.application.response.pre_annotation import PreAnnotationResponse
from labelu.internal.application.response.attachment import AttachmentResponse

-def read_jsonl_file(attachment: TaskAttachment) -> List[dict]:
+def read_pre_annotation_file(attachment: TaskAttachment) -> List[dict]:
    if attachment is None:
        return []

    attachment_path = attachment.path
    file_full_path = settings.MEDIA_ROOT.joinpath(attachment_path.lstrip("/"))

    # check if the file exists
-    if not file_full_path.exists() or not attachment.filename.endswith('.jsonl'):
+    if not file_full_path.exists() or (not attachment.filename.endswith('.jsonl') and not attachment.filename.endswith('.json')):
        return []

    try:
-        with open(file_full_path, "r", encoding="utf-8") as f:
-            data = f.readlines()
+        if attachment.filename.endswith('.jsonl'):
+            with open(file_full_path, "r", encoding="utf-8") as f:
+                data = f.readlines()
+            return [json.loads(line) for line in data]
+        else:
+            with open(file_full_path, "r", encoding="utf-8") as f:
+                # parse result
+                parsed_data = json.load(f)
+
+            return [{**item, "result": json.loads(item["result"])} for item in parsed_data]
+
    except FileNotFoundError:
        raise LabelUException(status_code=404, code=ErrorCode.CODE_51001_TASK_ATTACHMENT_NOT_FOUND)
-
-    parsed_data = [json.loads(line) for line in data]
-    return parsed_data

async def create(
    db: Session, task_id: int, cmd: List[CreatePreAnnotationCommand], current_user: User
) -> CreatePreAnnotationResponse:
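The renamed reader now accepts two pre-annotation formats: a .jsonl file with one JSON object per line, and a LabelU .json export, a list of objects whose "result" field is itself a JSON-encoded string that gets decoded in place. A small sketch of the two parse paths; the sample contents are hypothetical:

    import json

    jsonl_text = '{"sample_name": "00000001-image.jpg", "result": {}}\n'
    json_text = '[{"fileName": "image.jpg", "result": "{\\"width\\": 800}"}]'

    # .jsonl: parse each line independently
    jsonl_items = [json.loads(line) for line in jsonl_text.splitlines()]

    # .json (LabelU export): parse the list, then decode each item's
    # "result" string so both formats yield the same shape downstream
    json_items = [{**item, "result": json.loads(item["result"])}
                  for item in json.loads(json_text)]

    print(jsonl_items[0]["result"])  # {}
    print(json_items[0]["result"])   # {'width': 800}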
@@ -65,15 +71,15 @@ async def create(
                status_code=status.HTTP_400_BAD_REQUEST,
            )

-        jsonl_contents = read_jsonl_file(jsonl_file)
+        pre_annotation_contents = read_pre_annotation_file(jsonl_file)

-        for jsonl_content in jsonl_contents:
+        for _item in pre_annotation_contents:
            pre_annotations.append(
                TaskPreAnnotation(
                    task_id=task_id,
                    file_id=pre_annotation.file_id,
-                    sample_name=jsonl_content.get("sample_name"),
-                    data=json.dumps(jsonl_content, ensure_ascii=False),
+                    sample_name= _item.get("sample_name") if jsonl_file.filename.endswith(".jsonl") else _item.get("fileName"),
+                    data=json.dumps(_item, ensure_ascii=False),
                    created_by=current_user.id,
                    updated_by=current_user.id,
                )
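Because the two formats label their sample differently, create() now picks the key by file extension: "sample_name" for .jsonl rows, "fileName" for LabelU .json items. A hypothetical helper, pick_sample_name, mirrors that conditional:

    def pick_sample_name(item: dict, source_filename: str):
        # .jsonl rows carry "sample_name"; LabelU .json items carry "fileName"
        if source_filename.endswith(".jsonl"):
            return item.get("sample_name")
        return item.get("fileName")

    print(pick_sample_name({"sample_name": "00000001-image.jpg"}, "pre.jsonl"))  # 00000001-image.jpg
    print(pick_sample_name({"fileName": "image.jpg"}, "pre.json"))               # image.jpg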
@@ -138,20 +144,37 @@ async def list_pre_annotation_files(
    sorting: Optional[str],
    current_user: User,
) -> Tuple[List[TaskAttachment], int]:
-    pre_annotations = crud_pre_annotation.list_by_task_id_and_owner_id(db=db, task_id=task_id, owner_id=current_user.id)
-    file_ids = [pre_annotation.file_id for pre_annotation in pre_annotations]
-
-    attachments, total = crud_attachment.list_by(db=db, ids=file_ids, after=after, before=before, pageNo=pageNo, pageSize=pageSize, sorting=sorting)
-
-    return [
-        PreAnnotationFileResponse(
-            id=attachment.id,
-            url=attachment.url,
-            filename=attachment.filename,
-            sample_names=[pre_annotation.sample_name for pre_annotation in pre_annotations if pre_annotation.file_id == attachment.id]
+    try:
+        pre_annotations = crud_pre_annotation.list_by_task_id_and_owner_id(db=db, task_id=task_id, owner_id=current_user.id)
+        file_ids = [pre_annotation.file_id for pre_annotation in pre_annotations]
+
+        attachments, total = crud_attachment.list_by(db=db, ids=file_ids, after=after, before=before, pageNo=pageNo, pageSize=pageSize, sorting=sorting)
+
+        _attachment_ids = [attachment.id for attachment in attachments]
+        def get_sample_names():
+            _names = []
+            for pre_annotation in pre_annotations:
+                if pre_annotation.file_id in _attachment_ids and pre_annotation.sample_name is not None:
+                    _names.append(pre_annotation.sample_name)
+
+            return _names
+
+        return [
+            PreAnnotationFileResponse(
+                id=attachment.id,
+                url=attachment.url,
+                filename=attachment.filename,
+                sample_names=get_sample_names(),
+            )
+            for attachment in attachments
+        ], total
+
+    except Exception as e:
+        logger.error("list pre annotation files error: {}", e)
+        raise LabelUException(
+            code=ErrorCode.CODE_51001_TASK_ATTACHMENT_NOT_FOUND,
+            status_code=status.HTTP_404_NOT_FOUND,
+        )
-        for attachment in attachments
-    ], total
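In the rewritten listing above, get_sample_names() collects every non-null sample name whose file_id is among the attachments on the current page, and any failure is logged and surfaced as a 404. A plain-dict sketch of that filter; the rows and ids are made up for illustration:

    pre_annotations = [
        {"file_id": 1, "sample_name": "00000001-a.jpg"},
        {"file_id": 1, "sample_name": None},
        {"file_id": 2, "sample_name": "00000002-b.jpg"},
        {"file_id": 9, "sample_name": "00000003-c.jpg"},  # not on this page
    ]
    attachment_ids = [1, 2]

    sample_names = [
        row["sample_name"]
        for row in pre_annotations
        if row["file_id"] in attachment_ids and row["sample_name"] is not None
    ]
    print(sample_names)  # ['00000001-a.jpg', '00000002-b.jpg']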


async def get(
@@ -189,7 +212,7 @@ async def delete_pre_annotation_file(
    db: Session, task_id: int, file_id: int, current_user: User
) -> CommonDataResp:
    with db.begin():
-        pre_annotations = crud_pre_annotation.list_by_task_id_and_owner_id_and_sample_name(db=db, task_id=task_id, owner_id=current_user.id, sample_name=crud_attachment.get(db, file_id).filename)
+        pre_annotations = crud_pre_annotation.list_by_task_id_and_file_id(db=db, task_id=task_id, owner_id=current_user.id, file_id=file_id)
        pre_annotation_ids = [pre_annotation.id for pre_annotation in pre_annotations]
        crud_pre_annotation.delete(db=db, pre_annotation_ids=pre_annotation_ids)
        crud_attachment.delete(db=db, attachment_ids=[file_id])
18 changes: 9 additions & 9 deletions labelu/internal/common/converter.py
@@ -149,7 +149,7 @@ def convert_to_json(
"id": sample.get("id"),
"result": annotated_result_str,
"url": file.get("url"),
"fileName": file.get("filename", "")[9:],
"fileName": file.get("filename", ""),
}
)

@@ -222,7 +222,7 @@ def convert_to_coco(
        # coco image
        image = {
            "id": sample.get("id"),
-            "fileName": file.get("filename", "")[9:],
+            "fileName": file.get("filename", ""),
            "width": annotation_result.get("width", 0),
            "height": annotation_result.get("height", 0),
            "valid": False
@@ -319,7 +319,7 @@ def convert_to_mask(
        logger.info("data is: {}", sample)
        filename = file.get("filename")
        if filename and filename.split("/")[-1]:
-            file_relative_path_base_name = filename.split("/")[-1].split(".")[0][9:]
+            file_relative_path_base_name = filename.split("/")[-1].split(".")[0]
        else:
            file_relative_path_base_name = "result"

@@ -444,7 +444,7 @@ def image_to_base64(file_path: str):
        if sample.get("state") == "SKIPPED":
            continue

-        labelme_item["imagePath"] = file.get("filename", "")[9:]
+        labelme_item["imagePath"] = file.get("filename", "")
        labelme_item["imageData"] = image_to_base64(file.get("path"))

        if annotated_result:
@@ -520,7 +520,7 @@ def image_to_base64(file_path: str):
                }
                labelme_item["shapes"].append(shape)
        result.append(labelme_item)
-        file_basename = os.path.splitext(file.get("filename", "")[9:])[0]
+        file_basename = os.path.splitext(file.get("filename", ""))[0]
        file_name = out_data_dir.joinpath(f"{file_basename}.json")
        with file_name.open("w") as outfile:
            # Format the JSON with two-space indentation
@@ -567,7 +567,7 @@ def convert_to_yolo(self, config: dict, input_data: List[dict], out_data_file_na
            continue

        image_path = settings.MEDIA_ROOT.joinpath(file.get("path").lstrip("/"))
-        file_basename = os.path.splitext(file.get("filename", "")[9:])[0]
+        file_basename = os.path.splitext(file.get("filename", ""))[0]
        file_name = out_data_dir.joinpath(f"{file_basename}.txt")
        image_width = annotated_result.get("width", 0)
        image_height = annotated_result.get("height", 0)
@@ -730,7 +730,7 @@ def get_points(direction: dict):
                label_text = get_label(tool, label)
                rows.append([tool, label, label_text, direction, front, back, get_attributes(tool_result.get('attributes', {})), order])

-        file_basename = os.path.splitext(file.get("filename", "")[9:])[0]
+        file_basename = os.path.splitext(file.get("filename", ""))[0]
        file_name = out_data_dir.joinpath(f"{file_basename}.csv")
        with file_name.open("w") as outfile:
            writer = csv.writer(outfile)
@@ -801,7 +801,7 @@ def convert_to_tf_record(self, config: dict, input_data: List[dict], out_data_fi
        for sample in input_data:
            file = sample.get("file", {})
            example = tf_record_examples.pop(0)
-            file_basename = os.path.splitext(file.get("filename", "")[9:])[0]
+            file_basename = os.path.splitext(file.get("filename", ""))[0]
            tf_record = f"{file_basename}.tfrecord"
            file_full_path = out_data_dir.joinpath(tf_record)

@@ -834,7 +834,7 @@ def convert_to_pascal_voc(self, config: dict, input_data: List[dict], out_data_f
                continue

            voc_xml = xml_converter.create_pascal_voc_xml(config, file, annotated_result)
-            file_basename = os.path.splitext(file.get("filename", "")[9:])[0]
+            file_basename = os.path.splitext(file.get("filename", ""))[0]
            file_name = out_data_dir.joinpath(f"{file_basename}.xml")

            tree = ET.ElementTree(voc_xml)
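All nine converter hunks make the same change: the exporters no longer slice nine characters off the stored filename, so exported fileName values and derived output basenames keep the full stored name, while prefix handling moves to the pre-annotation lookup (the sample_name[9:] fallback in crud_pre_annotation.py). Assuming the nine-character prefix convention described earlier, the before/after effect looks like this:

    import os

    stored = "00000001-street.jpg"  # hypothetical stored filename with prefix

    before = stored[9:]                    # "street.jpg" (old export name)
    after = stored                         # "00000001-street.jpg" (new export name)
    basename = os.path.splitext(after)[0]  # "00000001-street", used for .json/.txt/.csv/.xml outputs
    print(before, after, basename)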
