diff --git a/src/neuron_proofreader/machine_learning/train.py b/src/neuron_proofreader/machine_learning/train.py index 7798bbb..c18114e 100644 --- a/src/neuron_proofreader/machine_learning/train.py +++ b/src/neuron_proofreader/machine_learning/train.py @@ -331,7 +331,7 @@ def check_model_performance(self, stats, epoch): True if the model achieved a new best F1 score and was saved. False otherwise. """ - if stats["f1"] > self.best_f1 and stats["recall"] > 0.82: + if stats["f1"] > self.best_f1 and stats["recall"] > 0.83: self.best_f1 = stats["f1"] self.save_model(epoch) return True diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py index 92050bc..fe3ee79 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ b/src/neuron_proofreader/merge_proofreading/merge_inference.py @@ -133,7 +133,7 @@ def filter_with_nms(self, merge_sites, likelihoods): ) if iou > 0.35: merge_sites_set.remove(i) - self.node_preds[i] = 1 + self.node_preds[i] = 1e-2 # Populate queue for j in self.dataset.graph.neighbors(i): @@ -147,6 +147,10 @@ def remove_merge_sites(self, detected_merge_sites): pass # --- Helpers --- + def get_detected_sites(self, threshold): + nodes = np.where(self.node_preds >= threshold)[0] + return [self.dataset.graph.node_xyz[i] for i in nodes] + def save_results(self, output_dir, output_prefix_s3=None): # Get predicted merge sites nodes = np.where(self.node_preds >= self.threshold)[0]