Skip to content

Commit 3148ff0

Browse files
Using Pathology bundles for nuclick and classification models (#1172)
* support bundles for nuclick and classify models Signed-off-by: Sachidanand Alle <[email protected]> * sync up changes Signed-off-by: Sachidanand Alle <[email protected]> * remove nuclick transform copy and use monai app instead Signed-off-by: Sachidanand Alle <[email protected]> Signed-off-by: Sachidanand Alle <[email protected]>
1 parent 4e72ec6 commit 3148ff0

File tree

14 files changed

+158
-1241
lines changed

14 files changed

+158
-1241
lines changed

monailabel/interfaces/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(self):
3636
self.path = None
3737
self.labels = None
3838
self.label_colors = None
39+
self.bundle_path = None
3940

4041
def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **kwargs):
4142
self.name = name

monailabel/tasks/train/basic_train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,7 @@ def _create_trainer(self, context: Context):
631631
amp=self._amp,
632632
postprocessing=self._validate_transforms(self.train_post_transforms(context), "Training", "post"),
633633
key_train_metric=self.train_key_metric(context),
634+
additional_metrics=self.train_additional_metrics(context),
634635
train_handlers=train_handlers,
635636
iteration_update=self.train_iteration_update(context),
636637
event_names=self.event_names(context),

monailabel/tasks/train/bundle.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def config(self):
104104
"gpus": "all", # COMMA SEPARATE DEVICE INDEX
105105
}
106106

107-
def _fetch_datalist(self, datastore: Datastore):
107+
def _fetch_datalist(self, request, datastore: Datastore):
108108
return datastore.datalist()
109109

110110
def _partition_datalist(self, datalist, request, shuffle=False):
@@ -144,16 +144,18 @@ def _load_checkpoint(self, output_dir, pretrained, train_handlers):
144144
train_handlers.insert(0, loader)
145145

146146
def __call__(self, request, datastore: Datastore):
147-
ds = self._fetch_datalist(datastore)
147+
logger.info(f"Train Request: {request}")
148+
ds = self._fetch_datalist(request, datastore)
148149
train_ds, val_ds = self._partition_datalist(ds, request)
149150

150151
max_epochs = request.get("max_epochs", 50)
151152
pretrained = request.get("pretrained", True)
152-
multi_gpu = request.get("multi_gpu", False)
153+
multi_gpu = request.get("multi_gpu", True)
153154
multi_gpu = multi_gpu if torch.cuda.device_count() > 1 else False
154155

155156
gpus = request.get("gpus", "all")
156157
gpus = list(range(torch.cuda.device_count())) if gpus == "all" else [int(g) for g in gpus.split(",")]
158+
multi_gpu = True if multi_gpu and len(gpus) > 1 else False
157159
logger.info(f"Using Multi GPU: {multi_gpu}; GPUS: {gpus}")
158160
logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
159161

sample-apps/endoscopy/lib/trainers/inbody.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424

2525

2626
class InBody(BundleTrainTask):
27-
def _fetch_datalist(self, datastore: Datastore):
28-
ds = super()._fetch_datalist(datastore)
27+
def _fetch_datalist(self, request, datastore: Datastore):
28+
ds = super()._fetch_datalist(request, datastore)
2929

3030
out_body = datastore.label_map.get("OutBody", 3) if isinstance(datastore, CVATDatastore) else 1
3131
load = LoadImage(dtype=np.uint8, image_only=True)

sample-apps/pathology/lib/configs/classification_nuclei.py

Lines changed: 8 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,17 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
import json
1312
import logging
1413
import os
1514
from typing import Any, Dict, Optional, Union
1615

1716
import lib.infers
1817
import lib.trainers
19-
from monai.networks.nets import DenseNet121
18+
from monai.bundle import download
2019

2120
from monailabel.interfaces.config import TaskConfig
2221
from monailabel.interfaces.tasks.infer_v2 import InferTask
2322
from monailabel.interfaces.tasks.train import TrainTask
24-
from monailabel.utils.others.generic import download_file, strtobool
2523

2624
logger = logging.getLogger(__name__)
2725

@@ -30,81 +28,16 @@ class ClassificationNuclei(TaskConfig):
3028
def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **kwargs):
3129
super().init(name, model_dir, conf, planner, **kwargs)
3230

33-
# Labels
34-
self.labels = {
35-
"Neoplastic cells": 1,
36-
"Inflammatory": 2,
37-
"Connective/Soft tissue cells": 3,
38-
"Dead Cells": 4,
39-
"Epithelial": 5,
40-
}
41-
self.label_colors = {
42-
"Neoplastic cells": (255, 0, 0),
43-
"Inflammatory": (255, 255, 0),
44-
"Connective/Soft tissue cells": (0, 255, 0),
45-
"Dead Cells": (0, 0, 0),
46-
"Epithelial": (0, 0, 255),
47-
}
48-
49-
consep = strtobool(self.conf.get("consep", "false"))
50-
if consep:
51-
self.labels = {
52-
"Other": 1,
53-
"Inflammatory": 2,
54-
"Epithelial": 3,
55-
"Spindle-Shaped": 4,
56-
}
57-
self.label_colors = {
58-
"Other": (255, 0, 0),
59-
"Inflammatory": (255, 255, 0),
60-
"Epithelial": (0, 0, 255),
61-
"Spindle-Shaped": (0, 255, 0),
62-
}
63-
64-
# Model Files
65-
self.path = [
66-
os.path.join(self.model_dir, f"pretrained_{name}{'_consep' if consep else ''}.pt"), # pretrained
67-
os.path.join(self.model_dir, f"{name}{'_consep' if consep else ''}.pt"), # published
68-
]
69-
70-
# Download PreTrained Model
71-
if strtobool(self.conf.get("use_pretrained_model", "true")):
72-
url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}"
73-
url = f"{url}/pathology_classification_densenet121_nuclei{'_consep' if consep else ''}.pt"
74-
download_file(url, self.path[0])
75-
76-
# Network
77-
self.network = DenseNet121(spatial_dims=2, in_channels=4, out_channels=len(self.labels))
31+
bundle_name = conf.get("bundle_name", "pathology_nuclei_classification")
32+
bundle_version = conf.get("bundle_version", "0.0.1")
33+
self.bundle_path = os.path.join(self.model_dir, bundle_name)
34+
if not os.path.exists(self.bundle_path):
35+
download(name=bundle_name, version=bundle_version, bundle_dir=self.model_dir)
7836

7937
def infer(self) -> Union[InferTask, Dict[str, InferTask]]:
80-
task: InferTask = lib.infers.ClassificationNuclei(
81-
path=self.path,
82-
network=self.network,
83-
labels=self.labels,
84-
preload=strtobool(self.conf.get("preload", "false")),
85-
roi_size=json.loads(self.conf.get("roi_size", "[128, 128]")),
86-
config={
87-
"label_colors": self.label_colors,
88-
},
89-
)
38+
task: InferTask = lib.infers.ClassificationNuclei(self.bundle_path, self.conf)
9039
return task
9140

9241
def trainer(self) -> Optional[TrainTask]:
93-
output_dir = os.path.join(self.model_dir, self.name)
94-
load_path = self.path[0] if os.path.exists(self.path[0]) else self.path[1]
95-
96-
task: TrainTask = lib.trainers.ClassificationNuclei(
97-
model_dir=output_dir,
98-
network=self.network,
99-
load_path=load_path,
100-
publish_path=self.path[1],
101-
labels=self.labels,
102-
description="Train Nuclei Classification Model",
103-
train_save_interval=1,
104-
config={
105-
"max_epochs": 10,
106-
"train_batch_size": 16,
107-
"val_batch_size": 16,
108-
},
109-
)
42+
task: TrainTask = lib.trainers.ClassificationNuclei(self.bundle_path, self.conf)
11043
return task

sample-apps/pathology/lib/configs/nuclick.py

Lines changed: 8 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,17 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
import json
1312
import logging
1413
import os
1514
from typing import Any, Dict, Optional, Union
1615

1716
import lib.infers
1817
import lib.trainers
19-
from monai.networks.nets import BasicUNet
18+
from monai.bundle import download
2019

2120
from monailabel.interfaces.config import TaskConfig
2221
from monailabel.interfaces.tasks.infer_v2 import InferTask
2322
from monailabel.interfaces.tasks.train import TrainTask
24-
from monailabel.utils.others.generic import download_file, strtobool
2523

2624
logger = logging.getLogger(__name__)
2725

@@ -30,59 +28,16 @@ class NuClick(TaskConfig):
3028
def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **kwargs):
3129
super().init(name, model_dir, conf, planner, **kwargs)
3230

33-
# Labels
34-
self.labels = {"Nuclei": 1}
35-
self.label_colors = {"Nuclei": (0, 255, 255)}
36-
37-
consep = strtobool(self.conf.get("consep", "false"))
38-
39-
# Model Files
40-
self.path = [
41-
os.path.join(self.model_dir, f"pretrained_{name}{'_consep' if consep else ''}.pt"), # pretrained
42-
os.path.join(self.model_dir, f"{name}{'_consep' if consep else ''}.pt"), # published
43-
]
44-
45-
# Download PreTrained Model
46-
if strtobool(self.conf.get("use_pretrained_model", "true")):
47-
url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}"
48-
url = f"{url}/pathology_nuclick_bunet_nuclei{'_consep' if consep else ''}.pt"
49-
download_file(url, self.path[0])
50-
51-
# Network
52-
self.network = BasicUNet(
53-
spatial_dims=2,
54-
in_channels=5,
55-
out_channels=1,
56-
features=(32, 64, 128, 256, 512, 32),
57-
)
31+
bundle_name = conf.get("bundle_name", "pathology_nuclick_annotation")
32+
bundle_version = conf.get("bundle_version", "0.0.1")
33+
self.bundle_path = os.path.join(self.model_dir, bundle_name)
34+
if not os.path.exists(self.bundle_path):
35+
download(name=bundle_name, version=bundle_version, bundle_dir=self.model_dir)
5836

5937
def infer(self) -> Union[InferTask, Dict[str, InferTask]]:
60-
task: InferTask = lib.infers.NuClick(
61-
path=self.path,
62-
network=self.network,
63-
labels=self.labels,
64-
preload=strtobool(self.conf.get("preload", "false")),
65-
roi_size=json.loads(self.conf.get("roi_size", "[512, 512]")),
66-
config={"label_colors": self.label_colors, "ignore_non_click_patches": True},
67-
)
38+
task: InferTask = lib.infers.NuClick(self.bundle_path, self.conf)
6839
return task
6940

7041
def trainer(self) -> Optional[TrainTask]:
71-
output_dir = os.path.join(self.model_dir, self.name)
72-
load_path = self.path[0] if os.path.exists(self.path[0]) else self.path[1]
73-
74-
task: TrainTask = lib.trainers.NuClick(
75-
model_dir=output_dir,
76-
network=self.network,
77-
load_path=load_path,
78-
publish_path=self.path[1],
79-
labels=self.labels,
80-
description="Train Nuclei DeepEdit Model",
81-
train_save_interval=1,
82-
config={
83-
"max_epochs": 10,
84-
"train_batch_size": 16,
85-
"val_batch_size": 16,
86-
},
87-
)
42+
task: TrainTask = lib.trainers.NuClick(self.bundle_path, self.conf)
8843
return task

0 commit comments

Comments
 (0)