Skip to content

Commit 570b19e

Browse files
Update python_bundle_workflow (#1656)
Fixes # . ### Description - add description for `self._set_prop` in python `python_bundle_workflow`. - remove generate data outside of the Workflow class. ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [ ] Avoid including large-size files in the PR. - [ ] Clean up long text outputs from code cells in the notebook. - [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private key. - [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder - [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>` --------- Signed-off-by: YunLiu <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8186739 commit 570b19e

File tree

2 files changed

+29
-26
lines changed

2 files changed

+29
-26
lines changed

bundle/python_bundle_workflow/scripts/inference.py

+9-13
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
ScaleIntensityd,
3939
)
4040
from monai.utils import BundleProperty
41+
from scripts.train import prepare_data
4142

4243

4344
class InferenceWorkflow(BundleWorkflow):
@@ -46,28 +47,22 @@ class InferenceWorkflow(BundleWorkflow):
4647
4748
"""
4849

49-
def __init__(self, dataset_dir: str = "."):
50+
def __init__(self, dataset_dir: str = "./infer"):
5051
super().__init__(workflow="inference")
5152
print_config()
5253
# set root log level to INFO and init a evaluation logger, will be used in `StatsHandler`
5354
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
5455
get_logger("eval_log")
5556

5657
# create a temporary directory and 40 random image, mask pairs
57-
print(f"generating synthetic data to {dataset_dir} (this may take a while)")
58-
for i in range(5):
59-
im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
60-
n = nib.Nifti1Image(im, np.eye(4))
61-
nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz"))
62-
n = nib.Nifti1Image(seg, np.eye(4))
63-
nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz"))
58+
prepare_data(dataset_dir=dataset_dir)
6459

6560
self._props = {}
6661
self._set_props = {}
6762
self.dataset_dir = dataset_dir
6863

6964
def initialize(self):
70-
self.props = {}
65+
self._props = {}
7166

7267
def run(self):
7368
self.evaluator.run()
@@ -76,6 +71,7 @@ def finalize(self):
7671
pass
7772

7873
def _set_property(self, name, property, value):
74+
# stores user-reset initialized objects that should not be re-initialized.
7975
self._set_props[name] = value
8076

8177
def _get_property(self, name, property):
@@ -88,11 +84,11 @@ def _get_property(self, name, property):
8884
8985
"""
9086
value = None
91-
if name in self._props:
92-
value = self._props[name]
93-
elif name in self._set_props:
87+
if name in self._set_props:
9488
value = self._set_props[name]
9589
self._props[name] = value
90+
elif name in self._props:
91+
value = self._props[name]
9692
else:
9793
try:
9894
value = getattr(self, f"get_{name}")()
@@ -112,7 +108,7 @@ def get_device(self):
112108
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
113109

114110
def get_dataset_dir(self):
115-
return "."
111+
return self.dataset_dir
116112

117113
def get_network_def(self):
118114
return UNet(

bundle/python_bundle_workflow/scripts/train.py

+20-13
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import logging
1313
import os
1414
import sys
15+
from pathlib import Path
1516
from glob import glob
1617

1718
import nibabel as nib
@@ -48,27 +49,32 @@
4849
from monai.utils import BundleProperty, set_determinism
4950

5051

52+
def prepare_data(dataset_dir):
53+
Path(dataset_dir).mkdir(exist_ok=True)
54+
print(f"generating synthetic data to {dataset_dir} (this may take a while)")
55+
for i in range(40):
56+
im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
57+
n = nib.Nifti1Image(im, np.eye(4))
58+
nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz"))
59+
n = nib.Nifti1Image(seg, np.eye(4))
60+
nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz"))
61+
62+
5163
class TrainWorkflow(BundleWorkflow):
5264
"""
5365
Test class simulates the bundle training workflow defined by Python script directly.
5466
5567
"""
5668

57-
def __init__(self, dataset_dir: str = "."):
69+
def __init__(self, dataset_dir: str = "./train"):
5870
super().__init__(workflow="train")
5971
print_config()
6072
# set root log level to INFO and init a train logger, will be used in `StatsHandler`
6173
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
6274
get_logger("train_log")
6375

6476
# create a temporary directory and 40 random image, mask pairs
65-
print(f"generating synthetic data to {dataset_dir} (this may take a while)")
66-
for i in range(40):
67-
im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
68-
n = nib.Nifti1Image(im, np.eye(4))
69-
nib.save(n, os.path.join(dataset_dir, f"img{i:d}.nii.gz"))
70-
n = nib.Nifti1Image(seg, np.eye(4))
71-
nib.save(n, os.path.join(dataset_dir, f"seg{i:d}.nii.gz"))
77+
prepare_data(dataset_dir=dataset_dir)
7278

7379
# define buckets to store the generated properties and set properties
7480
self._props = {}
@@ -82,7 +88,7 @@ def __init__(self, dataset_dir: str = "."):
8288

8389
def initialize(self):
8490
set_determinism(0)
85-
self.props = {}
91+
self._props = {}
8692

8793
def run(self):
8894
self.trainer.run()
@@ -91,6 +97,7 @@ def finalize(self):
9197
set_determinism(None)
9298

9399
def _set_property(self, name, property, value):
100+
# stores user-reset initialized objects that should not be re-initialized.
94101
self._set_props[name] = value
95102

96103
def _get_property(self, name, property):
@@ -103,11 +110,11 @@ def _get_property(self, name, property):
103110
104111
"""
105112
value = None
106-
if name in self._props:
107-
value = self._props[name]
108-
elif name in self._set_props:
113+
if name in self._set_props:
109114
value = self._set_props[name]
110115
self._props[name] = value
116+
elif name in self._props:
117+
value = self._props[name]
111118
else:
112119
try:
113120
value = getattr(self, f"get_{name}")()
@@ -127,7 +134,7 @@ def get_device(self):
127134
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
128135

129136
def get_dataset_dir(self):
130-
return "."
137+
return self.dataset_dir
131138

132139
def get_network(self):
133140
return UNet(

0 commit comments

Comments
 (0)