diff --git a/rats-devtools/src/rats/aml/_app.py b/rats-devtools/src/rats/aml/_app.py index 88c35e4a..983842de 100644 --- a/rats-devtools/src/rats/aml/_app.py +++ b/rats-devtools/src/rats/aml/_app.py @@ -546,6 +546,127 @@ def _submit( time.sleep(2) + @cli.command() + @click.argument("args", nargs=-1) + @click.option("--wait", is_flag=True, default=False, help="wait for completion of aml job.") + def _submit_cmd( + self, + args: tuple[str, ...], + wait: bool, + ) -> None: + """Submit a single cli command to aml.""" + from azure.ai.ml import Input, Output, command # type: ignore[reportUnknownVariableType] + from azure.ai.ml.entities import Environment + + if self._request is not None: + print(self._request) + raise runtime.DuplicateRequestError() + + if len(args) == 0: + raise RuntimeError("No cli args were provided to the command") + + env: dict[str, str] = {} + for env_map in self._app.get_group(AppConfigs.CLI_ENVS): + env.update(env_map) + + cli_command = cli.Command( + cwd=self._app.get(AppConfigs.CLI_CWD), + argv=tuple([*args]), + env=env, + ) + + pre_cmds: list[str] = [] + for pre_cmd in self._app.get_group(AppConfigs.CLI_PRE_CMD): + pre_cmds.append(shlex.join(["cd", pre_cmd.cwd])) + pre_cmds.append( + shlex.join( + [ + *[f"{k}={shlex.quote(v)}" for k, v in pre_cmd.env.items()], + *pre_cmd.argv, + ] + ) + ) + + post_cmds: list[str] = [] + for post_cmd in self._app.get_group(AppConfigs.CLI_POST_CMD): + post_cmds.append(shlex.join(["cd", post_cmd.cwd])) + post_cmds.append( + shlex.join( + [ + *[f"{k}={shlex.quote(v)}" for k, v in post_cmd.env.items()], + *post_cmd.argv, + ] + ) + ) + + config = self._app.get(AppConfigs.JOB_DETAILS) + env_ops = self._app.get(AppServices.AML_ENVIRONMENT_OPS) + job_ops = self._app.get(AppServices.AML_JOB_OPS) + + input_keys = config.inputs.keys() + output_keys = config.outputs.keys() + + input_envs = ( + [ + # double escape double curly braces for aml to later replace with dataset values + f"export RATS_AML_PATH_{k.upper()}=${{{{inputs.{k}}}}}" + for k in input_keys + ] + if len(input_keys) > 0 + else [] + ) + + output_keys = ( + [ + # double escape double curly braces for aml to later replace with dataset values + f"export RATS_AML_PATH_{k.upper()}=${{{{outputs.{k}}}}}" + for k in output_keys + ] + if len(output_keys) > 0 + else [] + ) + + cmd = " && ".join( + [ + # make sure we know the original directory and any input/output paths + "export RATS_AML_ORIGINAL_PWD=${PWD}", + *input_envs, + *output_keys, + *pre_cmds, + shlex.join(["cd", cli_command.cwd]), + shlex.join(cli_command.argv), + *post_cmds, + ] + ) + + env_ops.create_or_update(Environment(**config.environment._asdict())) + extra_aml_command_args = self._app.get(AppConfigs.COMMAND_KWARGS) + + job = command( + command=cmd, + compute=config.compute, + environment=config.environment.full_name, + outputs={ + k: Output(type=v.type, path=v.path, mode=v.mode) for k, v in config.outputs.items() + }, + inputs={ + k: Input(type=v.type, path=v.path, mode=v.mode) for k, v in config.inputs.items() + }, + environment_variables=dict(cli_command.env), + **extra_aml_command_args, + ) + returned_job = job_ops.create_or_update(job) # type: ignore[reportUnknownMemberType] + logger.info(f"created job: {returned_job.name}") + + self._request = Request( + job_name=str(returned_job.name), + wait=wait, + ) + + if wait: + logger.info(f"waiting for completion of job: {returned_job.name}") + job_ops.stream(str(returned_job.name)) + @cache # noqa: B019 def _runtime_list(self) -> tuple[str, ...]: """