Skip to content

Commit 2b4a9bd

Browse files
committed
feat: support LabelU JSON files as pre-annotation files
1 parent 57eaae0 commit 2b4a9bd

File tree

3 files changed

+67
-35
lines changed

3 files changed

+67
-35
lines changed

labelu/internal/adapter/persistence/crud_pre_annotation.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, Dict, List, Union, Tuple
33

44
from sqlalchemy.orm import Session
5+
from sqlalchemy import or_
56
from fastapi.encoders import jsonable_encoder
67

78
from labelu.internal.domain.models.pre_annotation import TaskPreAnnotation
@@ -34,7 +35,7 @@ def list_by(
3435
query_filter.append(TaskPreAnnotation.task_id == task_id)
3536

3637
if sample_name:
37-
query_filter.append(TaskPreAnnotation.sample_name == sample_name)
38+
query_filter.append(or_(TaskPreAnnotation.sample_name == sample_name, TaskPreAnnotation.sample_name == sample_name[9:]))
3839

3940
query = db.query(TaskPreAnnotation).filter(*query_filter)
4041

@@ -73,6 +74,14 @@ def list_by_task_id_and_owner_id(db: Session, task_id: int, owner_id: int) -> Di
7374

7475
return pre_annotations
7576

77+
def list_by_task_id_and_file_id(db: Session, task_id: int, file_id: int, owner_id: int) -> List[TaskPreAnnotation]:
    """List the pre-annotations that belong to one file of a task, without pagination.

    Args:
        db: Session used to run the query.
        task_id: Value matched against ``TaskPreAnnotation.task_id``.
        file_id: Value matched against ``TaskPreAnnotation.file_id``.
        owner_id: Value matched against ``TaskPreAnnotation.created_by``.

    Returns:
        Every matching row whose ``deleted_at`` is NULL.
    """
    # All criteria are combined with AND by Query.filter.
    criteria = (
        TaskPreAnnotation.task_id == task_id,
        TaskPreAnnotation.created_by == owner_id,
        TaskPreAnnotation.deleted_at.is_(None),
        TaskPreAnnotation.file_id == file_id,
    )
    return db.query(TaskPreAnnotation).filter(*criteria).all()
84+
7685
def list_by_task_id_and_owner_id_and_sample_name(db: Session, task_id: int, owner_id: int, sample_name: str) -> List[TaskPreAnnotation]:
7786
"""list pre annotations by task_id, owner_id and sample_name without pagination
7887

labelu/internal/application/service/pre_annotation.py

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,32 @@
2121
from labelu.internal.application.response.pre_annotation import PreAnnotationResponse
2222
from labelu.internal.application.response.attachment import AttachmentResponse
2323

24-
def read_jsonl_file(attachment: TaskAttachment) -> List[dict]:
24+
def read_pre_annotation_file(attachment: TaskAttachment) -> List[dict]:
    """Read the pre-annotation records stored in an uploaded attachment.

    Two formats are recognised by the attachment's filename suffix:

    * ``.jsonl`` - one JSON document per line; each line becomes one record.
    * ``.json``  - a JSON array of items whose ``result`` field holds a
      JSON-encoded string; that string is decoded so ``result`` is a
      parsed value in the returned records.

    Args:
        attachment: The attachment to read, or ``None``.

    Returns:
        The parsed records. An empty list when ``attachment`` is ``None``,
        when the file does not exist under ``settings.MEDIA_ROOT``, or when
        the filename has neither supported suffix.

    Raises:
        LabelUException: 404 / ``CODE_51001_TASK_ATTACHMENT_NOT_FOUND`` if the
            file disappears between the existence check and opening it.
    """
    if attachment is None:
        return []

    filename = attachment.filename
    file_full_path = settings.MEDIA_ROOT.joinpath(attachment.path.lstrip("/"))

    # Missing file or unsupported extension: nothing to import.
    if not file_full_path.exists() or not filename.endswith((".jsonl", ".json")):
        return []

    try:
        with open(file_full_path, "r", encoding="utf-8") as f:
            if filename.endswith(".jsonl"):
                # Skip blank lines (e.g. a trailing newline at end of file),
                # which json.loads would otherwise reject.
                return [json.loads(line) for line in f if line.strip()]
            parsed_data = json.load(f)
    except FileNotFoundError as err:
        raise LabelUException(
            status_code=404,
            code=ErrorCode.CODE_51001_TASK_ATTACHMENT_NOT_FOUND,
        ) from err

    # Tolerate a file holding a single item instead of an array of items.
    if isinstance(parsed_data, dict):
        parsed_data = [parsed_data]

    records = []
    for item in parsed_data:
        raw_result = item.get("result")
        # "result" is normally a JSON-encoded string; leave it untouched when
        # it is absent or has already been decoded.
        if isinstance(raw_result, str):
            item = {**item, "result": json.loads(raw_result)}
        records.append(item)
    return records
4049

41-
parsed_data = [json.loads(line) for line in data]
42-
return parsed_data
43-
4450
async def create(
4551
db: Session, task_id: int, cmd: List[CreatePreAnnotationCommand], current_user: User
4652
) -> CreatePreAnnotationResponse:
@@ -65,15 +71,15 @@ async def create(
6571
status_code=status.HTTP_400_BAD_REQUEST,
6672
)
6773

68-
jsonl_contents = read_jsonl_file(jsonl_file)
74+
pre_annotation_contents = read_pre_annotation_file(jsonl_file)
6975

70-
for jsonl_content in jsonl_contents:
76+
for _item in pre_annotation_contents:
7177
pre_annotations.append(
7278
TaskPreAnnotation(
7379
task_id=task_id,
7480
file_id=pre_annotation.file_id,
75-
sample_name=jsonl_content.get("sample_name"),
76-
data=json.dumps(jsonl_content, ensure_ascii=False),
81+
sample_name= _item.get("sample_name") if jsonl_file.filename.endswith(".jsonl") else _item.get("fileName"),
82+
data=json.dumps(_item, ensure_ascii=False),
7783
created_by=current_user.id,
7884
updated_by=current_user.id,
7985
)
@@ -138,20 +144,37 @@ async def list_pre_annotation_files(
138144
sorting: Optional[str],
139145
current_user: User,
140146
) -> Tuple[List[TaskAttachment], int]:
141-
pre_annotations = crud_pre_annotation.list_by_task_id_and_owner_id(db=db, task_id=task_id, owner_id=current_user.id)
142-
file_ids = [pre_annotation.file_id for pre_annotation in pre_annotations]
143-
144-
attachments, total = crud_attachment.list_by(db=db, ids=file_ids, after=after, before=before, pageNo=pageNo, pageSize=pageSize, sorting=sorting)
145-
146-
return [
147-
PreAnnotationFileResponse(
148-
id=attachment.id,
149-
url=attachment.url,
150-
filename=attachment.filename,
151-
sample_names=[pre_annotation.sample_name for pre_annotation in pre_annotations if pre_annotation.file_id == attachment.id]
147+
try:
148+
pre_annotations = crud_pre_annotation.list_by_task_id_and_owner_id(db=db, task_id=task_id, owner_id=current_user.id)
149+
file_ids = [pre_annotation.file_id for pre_annotation in pre_annotations]
150+
151+
attachments, total = crud_attachment.list_by(db=db, ids=file_ids, after=after, before=before, pageNo=pageNo, pageSize=pageSize, sorting=sorting)
152+
153+
_attachment_ids = [attachment.id for attachment in attachments]
154+
def get_sample_names():
155+
_names = []
156+
for pre_annotation in pre_annotations:
157+
if pre_annotation.file_id in _attachment_ids and pre_annotation.sample_name is not None:
158+
_names.append(pre_annotation.sample_name)
159+
160+
return _names
161+
162+
return [
163+
PreAnnotationFileResponse(
164+
id=attachment.id,
165+
url=attachment.url,
166+
filename=attachment.filename,
167+
sample_names=get_sample_names(),
168+
)
169+
for attachment in attachments
170+
], total
171+
172+
except Exception as e:
173+
logger.error("list pre annotation files error: {}", e)
174+
raise LabelUException(
175+
code=ErrorCode.CODE_51001_TASK_ATTACHMENT_NOT_FOUND,
176+
status_code=status.HTTP_404_NOT_FOUND,
152177
)
153-
for attachment in attachments
154-
], total
155178

156179

157180
async def get(
@@ -189,7 +212,7 @@ async def delete_pre_annotation_file(
189212
db: Session, task_id: int, file_id: int, current_user: User
190213
) -> CommonDataResp:
191214
with db.begin():
192-
pre_annotations = crud_pre_annotation.list_by_task_id_and_owner_id_and_sample_name(db=db, task_id=task_id, owner_id=current_user.id, sample_name=crud_attachment.get(db, file_id).filename)
215+
pre_annotations = crud_pre_annotation.list_by_task_id_and_file_id(db=db, task_id=task_id, owner_id=current_user.id, file_id=file_id)
193216
pre_annotation_ids = [pre_annotation.id for pre_annotation in pre_annotations]
194217
crud_pre_annotation.delete(db=db, pre_annotation_ids=pre_annotation_ids)
195218
crud_attachment.delete(db=db, attachment_ids=[file_id])

labelu/internal/common/converter.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def convert_to_json(
149149
"id": sample.get("id"),
150150
"result": annotated_result_str,
151151
"url": file.get("url"),
152-
"fileName": file.get("filename", "")[9:],
152+
"fileName": file.get("filename", ""),
153153
}
154154
)
155155

@@ -222,7 +222,7 @@ def convert_to_coco(
222222
# coco image
223223
image = {
224224
"id": sample.get("id"),
225-
"fileName": file.get("filename", "")[9:],
225+
"fileName": file.get("filename", ""),
226226
"width": annotation_result.get("width", 0),
227227
"height": annotation_result.get("height", 0),
228228
"valid": False
@@ -319,7 +319,7 @@ def convert_to_mask(
319319
logger.info("data is: {}", sample)
320320
filename = file.get("filename")
321321
if filename and filename.split("/")[-1]:
322-
file_relative_path_base_name = filename.split("/")[-1].split(".")[0][9:]
322+
file_relative_path_base_name = filename.split("/")[-1].split(".")[0]
323323
else:
324324
file_relative_path_base_name = "result"
325325

@@ -444,7 +444,7 @@ def image_to_base64(file_path: str):
444444
if sample.get("state") == "SKIPPED":
445445
continue
446446

447-
labelme_item["imagePath"] = file.get("filename", "")[9:]
447+
labelme_item["imagePath"] = file.get("filename", "")
448448
labelme_item["imageData"] = image_to_base64(file.get("path"))
449449

450450
if annotated_result:
@@ -520,7 +520,7 @@ def image_to_base64(file_path: str):
520520
}
521521
labelme_item["shapes"].append(shape)
522522
result.append(labelme_item)
523-
file_basename = os.path.splitext(file.get("filename", "")[9:])[0]
523+
file_basename = os.path.splitext(file.get("filename", ""))[0]
524524
file_name = out_data_dir.joinpath(f"{file_basename}.json")
525525
with file_name.open("w") as outfile:
526526
# 格式化json,两个空格缩进
@@ -567,7 +567,7 @@ def convert_to_yolo(self, config: dict, input_data: List[dict], out_data_file_na
567567
continue
568568

569569
image_path = settings.MEDIA_ROOT.joinpath(file.get("path").lstrip("/"))
570-
file_basename = os.path.splitext(file.get("filename", "")[9:])[0]
570+
file_basename = os.path.splitext(file.get("filename", ""))[0]
571571
file_name = out_data_dir.joinpath(f"{file_basename}.txt")
572572
image_width = annotated_result.get("width", 0)
573573
image_height = annotated_result.get("height", 0)
@@ -730,7 +730,7 @@ def get_points(direction: dict):
730730
label_text = get_label(tool, label)
731731
rows.append([tool, label, label_text, direction, front, back, get_attributes(tool_result.get('attributes', {})), order])
732732

733-
file_basename = os.path.splitext(file.get("filename", "")[9:])[0]
733+
file_basename = os.path.splitext(file.get("filename", ""))[0]
734734
file_name = out_data_dir.joinpath(f"{file_basename}.csv")
735735
with file_name.open("w") as outfile:
736736
writer = csv.writer(outfile)
@@ -801,7 +801,7 @@ def convert_to_tf_record(self, config: dict, input_data: List[dict], out_data_fi
801801
for sample in input_data:
802802
file = sample.get("file", {})
803803
example = tf_record_examples.pop(0)
804-
file_basename = os.path.splitext(file.get("filename", "")[9:])[0]
804+
file_basename = os.path.splitext(file.get("filename", ""))[0]
805805
tf_record = f"{file_basename}.tfrecord"
806806
file_full_path = out_data_dir.joinpath(tf_record)
807807

@@ -834,7 +834,7 @@ def convert_to_pascal_voc(self, config: dict, input_data: List[dict], out_data_f
834834
continue
835835

836836
voc_xml = xml_converter.create_pascal_voc_xml(config, file, annotated_result)
837-
file_basename = os.path.splitext(file.get("filename", "")[9:])[0]
837+
file_basename = os.path.splitext(file.get("filename", ""))[0]
838838
file_name = out_data_dir.joinpath(f"{file_basename}.xml")
839839

840840
tree = ET.ElementTree(voc_xml)

0 commit comments

Comments
 (0)