Skip to content

Commit 92ef7d0

Browse files
committed
arguments for node type, host ipc, no symlinks
1 parent f40e72f commit 92ef7d0

File tree

1 file changed

+45
-11
lines changed

1 file changed

+45
-11
lines changed

csub.py

+45-11
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,25 @@
9292
type=int,
9393
help="specifies the number of retries before marking a workload as failed (default 0). only exists for train jobs",
9494
)
95+
# Optional flags: target node type, host IPC namespace, and symlink opt-out.
parser.add_argument(
    "--node_type",
    type=str,
    default="",
    choices=["", "G9", "G10"],
    # Adjacent string literals are used instead of backslash continuations
    # inside a single literal, so the indentation of the source lines does
    # not end up embedded in the help text.
    help=(
        "node type to run on (default is empty, which means any node). "
        "only exists for IC cluster: G9 for V100, G10 for A100. "
        "leave empty for RCP"
    ),
)
parser.add_argument(
    "--host_ipc",
    action="store_true",
    help="created workload will use the host's ipc namespace",
)
parser.add_argument(
    "--no_symlinks",
    action="store_true",
    help="do not create symlinks to the user's home directory",
)
95114

96115
if __name__ == "__main__":
97116
args = parser.parse_args()
@@ -130,17 +149,22 @@
130149
backofflimit = ""
131150

132151
working_dir = user_cfg["working_dir"]
133-
symlink_targets, symlink_destinations = zip(*user_cfg["symlinks"].items())
134-
symlink_targets = ":".join(
135-
[os.path.join(working_dir, target) for target in symlink_targets]
136-
)
137-
symlink_paths = ":".join(
138-
[
139-
os.path.join(f"/home/{user_cfg['user']}", dest[1])
140-
for dest in symlink_destinations
141-
]
142-
)
143-
symlink_types = ":".join([dest[0] for dest in symlink_destinations])
152+
# Build the three colon-separated symlink strings: targets (joined onto
# working_dir), link paths (joined onto /home/<user>) and the first element
# of each config value (stored as symlink_types).
#
# All three stay empty when --no_symlinks is given OR when the config has no
# symlink entries: zip(*{}.items()) yields nothing, so unpacking it into two
# names would raise ValueError on an empty mapping.
if args.no_symlinks or not user_cfg["symlinks"]:
    symlink_targets = ""
    symlink_paths = ""
    symlink_types = ""
else:
    symlinks = user_cfg["symlinks"]
    home_dir = f"/home/{user_cfg['user']}"
    # Keys are target paths relative to working_dir.
    symlink_targets = ":".join(
        os.path.join(working_dir, target) for target in symlinks
    )
    # Each value is indexed as a pair: dest[0] goes into symlink_types,
    # dest[1] is the link location relative to the user's home directory.
    symlink_paths = ":".join(
        os.path.join(home_dir, dest[1]) for dest in symlinks.values()
    )
    symlink_types = ":".join(dest[0] for dest in symlinks.values())
144168
cfg = f"""
145169
apiVersion: run.ai/v2alpha1
146170
kind: {workload_kind}
@@ -210,6 +234,16 @@
210234
username:
211235
value: {user_cfg['user']}
212236
"""
237+
# Optional fragments appended to the generated workload YAML in `cfg`:
#   * nodeType is appended only when --node_type is non-empty, with the
#     chosen value interpolated.
#   * hostIpc (value: true) is appended only when --host_ipc was passed.
# NOTE(review): the leading whitespace of the YAML lines inside these
# literals is significant for nesting and is not visible in this rendering;
# confirm it matches the indentation of the sibling fields in `cfg`.
# NOTE(review): the second literal has no placeholders, so its f-prefix is
# unnecessary; a plain triple-quoted string would behave the same.
if args.node_type:
238+
cfg += f"""
239+
nodeType:
240+
value: {args.node_type} # G10 for A100, G9 for V100 (on IC cluster)
241+
"""
242+
if args.host_ipc:
243+
cfg += f"""
244+
hostIpc:
245+
value: true
246+
"""
213247

214248
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f:
215249
f.write(cfg)

0 commit comments

Comments
 (0)