|
92 | 92 | type=int,
|
93 | 93 | help="specifies the number of retries before marking a workload as failed (default 0). only exists for train jobs",
|
94 | 94 | )
|
| 95 | +parser.add_argument( |
| 96 | + "--node_type", |
| 97 | + type=str, |
| 98 | + default="", |
| 99 | + choices=["", "G9", "G10"], |
| 100 | + help="node type to run on (default is empty, which means any node). \ |
| 101 | + only exists for IC cluster: G9 for V100, G10 for A100. \ |
| 102 | + leave empty for RCP", |
| 103 | +) |
| 104 | +parser.add_argument( |
| 105 | + "--host_ipc", |
| 106 | + action="store_true", |
| 107 | + help="created workload will use the host's ipc namespace", |
| 108 | +) |
| 109 | +parser.add_argument( |
| 110 | + "--no_symlinks", |
| 111 | + action="store_true", |
| 112 | + help="do not create symlinks to the user's home directory", |
| 113 | +) |
95 | 114 |
|
96 | 115 | if __name__ == "__main__":
|
97 | 116 | args = parser.parse_args()
|
|
130 | 149 | backofflimit = ""
|
131 | 150 |
|
132 | 151 | working_dir = user_cfg["working_dir"]
|
133 |
| - symlink_targets, symlink_destinations = zip(*user_cfg["symlinks"].items()) |
134 |
| - symlink_targets = ":".join( |
135 |
| - [os.path.join(working_dir, target) for target in symlink_targets] |
136 |
| - ) |
137 |
| - symlink_paths = ":".join( |
138 |
| - [ |
139 |
| - os.path.join(f"/home/{user_cfg['user']}", dest[1]) |
140 |
| - for dest in symlink_destinations |
141 |
| - ] |
142 |
| - ) |
143 |
| - symlink_types = ":".join([dest[0] for dest in symlink_destinations]) |
| 152 | + if not args.no_symlinks: |
| 153 | + symlink_targets, symlink_destinations = zip(*user_cfg["symlinks"].items()) |
| 154 | + symlink_targets = ":".join( |
| 155 | + [os.path.join(working_dir, target) for target in symlink_targets] |
| 156 | + ) |
| 157 | + symlink_paths = ":".join( |
| 158 | + [ |
| 159 | + os.path.join(f"/home/{user_cfg['user']}", dest[1]) |
| 160 | + for dest in symlink_destinations |
| 161 | + ] |
| 162 | + ) |
| 163 | + symlink_types = ":".join([dest[0] for dest in symlink_destinations]) |
| 164 | + else: |
| 165 | + symlink_targets = "" |
| 166 | + symlink_paths = "" |
| 167 | + symlink_types = "" |
144 | 168 | cfg = f"""
|
145 | 169 | apiVersion: run.ai/v2alpha1
|
146 | 170 | kind: {workload_kind}
|
|
210 | 234 | username:
|
211 | 235 | value: {user_cfg['user']}
|
212 | 236 | """
|
| 237 | + if args.node_type: |
| 238 | + cfg += f""" |
| 239 | + nodeType: |
| 240 | + value: {args.node_type} # G10 for A100, G9 for V100 (on IC cluster) |
| 241 | +""" |
| 242 | + if args.host_ipc: |
| 243 | + cfg += f""" |
| 244 | + hostIpc: |
| 245 | + value: true |
| 246 | +""" |
213 | 247 |
|
214 | 248 | with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f:
|
215 | 249 | f.write(cfg)
|
|
0 commit comments