Skip to content

Commit 16cbcec

Browse files
committed
Change func params from list to tuple, clean up docs
1 parent d872ee5 commit 16cbcec

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

label_maker/package.py

+18-14
Original file line numberDiff line numberDiff line change
@@ -9,43 +9,46 @@
99
from label_maker.utils import is_tif
1010

1111

12-
def package_directory(dest_folder, classes, imagery, ml_type, seed=False, split_names=['train', 'test'],
13-
split_vals=[0.8, .2], **kwargs):
12+
def package_directory(dest_folder, classes, imagery, ml_type, seed=False,
13+
split_names=('train', 'test'), split_vals=(0.8, .2),
14+
**kwargs):
1415
"""Generate an .npz file containing arrays for training machine learning algorithms
1516
1617
Parameters
1718
------------
1819
dest_folder: str
1920
Folder to save labels, tiles, and final numpy arrays into
2021
classes: list
21-
A list of classes for machine learning training. Each class is defined as a dict
22-
with two required properties:
22+
A list of classes for machine learning training. Each class is defined
23+
as a dict with two required properties:
2324
- name: class name
2425
- filter: A Mapbox GL Filter.
2526
See the README for more details
2627
imagery: str
2728
Imagery template to download satellite images from.
2829
Ex: http://a.tiles.mapbox.com/v4/mapbox.satellite/{z}/{x}/{y}.jpg?access_token=ACCESS_TOKEN
2930
ml_type: str
30-
Defines the type of machine learning. One of "classification", "object-detection", or "segmentation"
31+
Defines the type of machine learning. One of "classification",
32+
"object-detection", or "segmentation"
3133
seed: int
3234
Random generator seed. Optional, use to make results reproducible.
33-
split_vals: list
34-
Default: [0.8, 0.2]
35-
Percentage of data to put in each catagory listed in split_names.
36-
Must be floats and must sum to one.
37-
split_names: list
38-
Default: ['train', 'test']
35+
split_vals: tuple
36+
Percentage of data to put in each category listed in split_names. Must
37+
be floats and must sum to one. Default: (0.8, 0.2)
38+
split_names: tuple
39+
Default: ('train', 'test')
3940
List of names for each subset of the data.
4041
**kwargs: dict
41-
Other properties from CLI config passed as keywords to other utility functions
42+
Other properties from CLI config passed as keywords to other utility
43+
functions.
4244
"""
4345
# if a seed is given, use it
4446
if seed:
4547
np.random.seed(seed)
4648

4749
if len(split_names) != len(split_vals):
48-
raise ValueError('`split_names` and `split_vals` must be the same length. Please update your config.')
50+
raise ValueError('`split_names` and `split_vals` must be the same '
51+
'length. Please update your config.')
4952
if not np.isclose(sum(split_vals), 1):
5053
raise ValueError('`split_vals` must sum to one. Please update your config.')
5154

@@ -105,7 +108,8 @@ def package_directory(dest_folder, classes, imagery, ml_type, seed=False, split_
105108
split_n_samps = [len(x_vals) * val for val in split_vals]
106109

107110
if np.any(split_n_samps == 0):
108-
raise ValueError('split must not generate zero samples per partition, change ratio of values in config file.')
111+
raise ValueError('Split must not generate zero samples per partition. '
112+
'Change ratio of values in config file.')
109113

110114
# Convert into a cumulative sum to get indices
111115
split_inds = np.cumsum(split_n_samps).astype(np.integer)

0 commit comments

Comments
 (0)