Skip to content

Commit 76fa41d

Browse files
committed
feat: add improvement based on review
1 parent 4f42271 commit 76fa41d

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

app/trainers/medcat_trainer.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,13 @@ def run(
174174
fn_accumulated += fns.get(cui, 0)
175175
tp_accumulated += tps.get(cui, 0)
176176
cc_accumulated += cc.get(cui, 0)
177+
cui_info = model.cdb.cui2info.get(cui)
177178
aggregated_metrics.append({
178179
"per_concept_fp": fps.get(cui, 0),
179180
"per_concept_fn": fns.get(cui, 0),
180181
"per_concept_tp": tps.get(cui, 0),
181182
"per_concept_counts": cc.get(cui, 0),
182-
"per_concept_count_train": cast(Dict[str, Any], model.cdb.cui2info.get(cui, {})).get("count_train", 0),
183+
"per_concept_count_train": cui_info.get("count_train", 0) if cui_info is not None else 0,
183184
"per_concept_acc_fp": fp_accumulated,
184185
"per_concept_acc_fn": fn_accumulated,
185186
"per_concept_acc_tp": tp_accumulated,
@@ -309,8 +310,9 @@ def _save_trained_concepts(
309310
annotation_ignorance_count = []
310311
concepts = list(training_concepts.keys())
311312
for c in concepts:
312-
train_count.append(model.cdb.cui2info.get(c, {}).get("count_train", 0)) # type: ignore
313-
concept_names.append(model.cdb.cui2info.get(c, {}).get("preferred_name", "")) # type: ignore
313+
cui_info = model.cdb.cui2info.get(c)
314+
train_count.append(cui_info.get("count_train", 0) if cui_info is not None else 0)
315+
concept_names.append(model.cdb.get_name(c))
314316
annotation_count.append(training_concepts[c])
315317
annotation_unique_count.append(training_unique_concepts[c])
316318
annotation_ignorance_count.append(training_ignorance_counts[c])
@@ -421,7 +423,7 @@ def run(
421423
logger.info("Performing unsupervised training...")
422424
step = 0
423425
self._tracker_client.send_model_stats(dict(model.cdb.get_basic_info()), step)
424-
before_cui2count_train = {c: info["count_train"] for c, info in model.cdb.cui2info.items()}
426+
before_cui2count_train = model.cdb.get_cui2count_train()
425427
num_of_docs = 0
426428
train_unsupervised_params = get_func_params_as_dict(model.trainer.train_unsupervised)
427429
train_unsupervised_params = {p_key: training_params[p_key] if p_key in training_params else p_val for p_key, p_val in train_unsupervised_params.items()}
@@ -437,13 +439,14 @@ def run(
437439

438440
self._tracker_client.log_document_size(num_of_docs)
439441
after_cui2count_train = {
440-
c: info["count_train"]
441-
for c, info in sorted(
442-
model.cdb.cui2info.items(),
443-
key=lambda item: item[1]["count_train"],
442+
c: ct
443+
for c, ct in sorted(
444+
model.cdb.get_cui2count_train().items(),
445+
key=lambda item: item[1],
444446
reverse=True,
445447
)
446448
}
449+
447450
aggregated_metrics = []
448451
cui_step = 0
449452
for cui, train_count in after_cui2count_train.items():

app/trainers/metacat_trainer.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,13 @@ def run(
8686
model = self._model_service.load_model(copied_model_pack_path)
8787
is_retrained = False
8888
model.config.meta.description = description or model.config.meta.description
89-
meta_cat_addons = [
90-
addon for addon in model.get_addons()
91-
if addon.addon_type == MetaCATAddon.addon_type
92-
]
89+
meta_cat_addons = model.get_addons_of_type(MetaCATAddon)
9390
for meta_cat_addon in meta_cat_addons:
9491
if self._cancel_event.is_set():
9592
self._cancel_event.clear()
9693
raise TrainingCancelledException("Training was cancelled by the user")
9794

98-
meta_cat = cast(MetaCATAddon, meta_cat_addon).mc
95+
meta_cat = meta_cat_addon.mc
9996
category_name = meta_cat.config.general.category_name
10097
assert category_name is not None, "Category name should not be None"
10198
if meta_cat.config.general.alternative_class_names == [[]]:
@@ -203,12 +200,9 @@ def run(
203200
logger.info("Evaluating the running model...")
204201
metrics: List[Dict] = []
205202
assert self._model_service.model is not None, "Model should not be None"
206-
meta_cat_addons = [
207-
addon for addon in self._model_service.model.get_addons()
208-
if addon.addon_type == MetaCATAddon.addon_type
209-
]
203+
meta_cat_addons = self._model_service.model.get_addons_of_type(MetaCATAddon)
210204
for meta_cat_addon in meta_cat_addons:
211-
meta_cat = cast(MetaCATAddon, meta_cat_addon).mc
205+
meta_cat = meta_cat_addon.mc
212206
category_name = meta_cat.config.general.category_name
213207
self._tracker_client.log_model_config(self.get_flattened_metacat_config(meta_cat, category_name))
214208
self._tracker_client.log_trainer_version(TrainerBackend.MEDCAT, medcat_version)

0 commit comments

Comments
 (0)