csub refactor #15

Open
wants to merge 2 commits into main
295 changes: 155 additions & 140 deletions csub.py
@@ -7,7 +7,11 @@
import subprocess
import tempfile
import yaml
import json
import os
import time
import sys


parser = argparse.ArgumentParser(description="Cluster Submit Utility")
parser.add_argument(
@@ -104,7 +108,7 @@
"--node_type",
type=str,
default="",
choices=["", "g9", "g10", "h100", "default"],
choices=["", "g9", "g10", "h100", "v100", "default"],
help="node type to run on (default is empty, which means any node). \
IC cluster: g9 for V100, g10 for A100. \
RCP-Prod cluster: h100 for H100, use 'default' to get A100 on interactive jobs",
@@ -124,10 +128,29 @@
action="store_true",
help="use large shared memory /dev/shm for the job",
)
parser.add_argument(
"--follow",
action="store_true",
help="follow logs of the launched job",
)
parser.add_argument(
"--github-secret-file",
default="~/.ssh/github",
help="private ssh key for github to be set as env",
)

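# Illustrative submission (a sketch: the --name/--command spellings are assumed
# from how args.name and args.command are used below; --follow and
# --github-secret-file are the flags added here):
#   python csub.py --name myjob --command "python train.py" --follow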

def to_string_values(d):
"""Recursively cast every leaf value of a nested dict to str."""
if not isinstance(d, dict):
return str(d)
return {k: to_string_values(v) for k, v in d.items()}
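# A minimal sketch of the helper above, assuming the Run:ai backend expects
# env var values as strings rather than ints:
#   to_string_values({"NB_UID": {"value": 1234}})
#   -> {"NB_UID": {"value": "1234"}}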


if __name__ == "__main__":
args = parser.parse_args()

args.user = os.path.expanduser(args.user)

if not os.path.exists(args.user):
print(
f"User file {args.user} does not exist, use the template in `template/user.yaml` to create your user file."
@@ -195,155 +218,147 @@
symlink_types = ""

# this is the config that will be rendered to YAML and submitted to the cluster
cfg = f"""
apiVersion: run.ai/v2alpha1
kind: {workload_kind}
metadata:
annotations:
runai-cli-version: {runai_cli_version}
labels:
PreviousJob: "true"
name: {args.name}
namespace: runai-mlo-{user_cfg['user']}
spec:
name:
value: {args.name}
arguments:
value: "/bin/zsh -c 'source ~/.zshrc && {args.command}'" # zshrc is just loaded to have some env variables ready
environment:
items:
HOME:
value: "/home/{user_cfg['user']}"
NB_USER:
value: {user_cfg['user']}
NB_UID:
value: "{user_cfg['uid']}"
NB_GROUP:
value: {user_cfg['group']}
NB_GID:
value: "{user_cfg['gid']}"
WORKING_DIR:
value: "{working_dir}"
SYMLINK_TARGETS:
value: "{symlink_targets}"
SYMLINK_PATHS:
value: "{symlink_paths}"
SYMLINK_TYPES:
value: "{symlink_types}"
WANDB_API_KEY:
value: {user_cfg['wandb_api_key']}
HF_HOME:
value: /mloscratch/hf_cache
HF_TOKEN:
value: {user_cfg['hf_token']}
EPFML_LDAP:
value: {user_cfg['user']}
gpu:
value: "{args.gpus}"
cpu:
value: "{args.cpus}"
memory:
value: "{args.memory}"
image:
value: {args.image}
imagePullPolicy:
value: Always
pvcs:
items:
pvc--0:
value:
claimName: {scratch_name}
existingPvc: true
path: /mloscratch
readOnly: false
## these two lines are necessary on RCP, not on the new IC
runAsGid:
value: {user_cfg['gid']}
runAsUid:
value: {user_cfg['uid']}
##
runAsUser:
value: true
serviceType:
value: ClusterIP
username:
value: {user_cfg['user']}
allowPrivilegeEscalation: # allow sudo
value: true
"""
cfg: dict = dict(
apiVersion="run.ai/v2alpha1",
kind=workload_kind,
metadata=dict(
annotations={"runai-cli-version": runai_cli_version},
labels={"PreviousJob": "true"},
name=args.name,
namespace=f"runai-mlo-{user_cfg['user']}",
),
spec=dict(
name={"value": args.name},
arguments={"value": f"/bin/zsh -c 'source ~/.zshrc && {args.command}'"},
environment=dict(
items=to_string_values(
{
"HOME": {"value": f"/home/{user_cfg['user']}"},
"NB_USER": {"value": user_cfg["user"]},
"NB_UID": {"value": user_cfg["uid"]},
"NB_GROUP": {"value": user_cfg["group"]},
"NB_GID": {"value": user_cfg["gid"]},
"WORKING_DIR": {"value": working_dir},
"SYMLINK_TARGETS": {"value": symlink_targets},
"SYMLINK_PATHS": {"value": symlink_paths},
"SYMLINK_TYPES": {"value": symlink_types},
"WANDB_API_KEY": {"value": user_cfg["wandb_api_key"]},
"HF_HOME": {"value": "/mloscratch/hf_cache"},
"HF_TOKEN": {"value": user_cfg["hf_token"]},
"EPFML_LDAP": {"value": user_cfg["user"]},
}
)
),
gpu={"value": str(args.gpus)},
cpu={"value": str(args.cpus)},
memory={"value": str(args.memory)},
image={"value": args.image},
imagePullPolicy={"value": "Always"},
pvcs={
"items": {
"pvc--0": {
"value": {
"claimName": scratch_name,
"existingPvc": True,
"path": "/mloscratch",
"readOnly": False,
}
}
}
},
runAsGid={"value": user_cfg["gid"]},
runAsUid={"value": user_cfg["uid"]},
runAsUser={"value": True},
serviceType={"value": "ClusterIP"},
username={"value": user_cfg["user"]},
allowPrivilegeEscalation={"value": True},
),
)
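# For orientation, a rough sketch (not verbatim output) of the manifest that
# yaml.dump(cfg) produces from the dict above; quoting and key order are left
# to the serializer:
#   apiVersion: run.ai/v2alpha1
#   kind: <workload_kind>
#   spec:
#     gpu:
#       value: '1'
#     image:
#       value: <args.image>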

#### some additional flags that can be added at the end of the config
if args.node_type in ["g10", "g9", "h100", "default"]:
cfg += f"""
nodePools:
value: {args.node_type} # g10 for A100, g9 for V100 (only on IC cluster)
"""
if args.node_type in ["g10", "h100", "default"] and not args.train:
# for interactive jobs on A100s (g10 nodes), the job must be marked as preemptible,
# see the table "Types of Workloads": https://inside.epfl.ch/ic-it-docs/ic-cluster/caas/submit-jobs/
cfg += f"""
preemptible:
value: true
"""
if args.host_ipc:
cfg += f"""
hostIpc:
value: true
"""

if args.node_type:
cfg["spec"]["nodePools"] = {"value": args.node_type}
if args.node_type in ["g10", "h100", "default"] and not args.train:
cfg["spec"]["preemptible"] = {"value": True}
cfg["spec"]["hostIpc"] = {"value": args.host_ipc}
if args.train:
cfg += f"""
backoffLimit:
value: {args.backofflimit}
"""
cfg["spec"]["backoffLimit"] = {"value": args.backofflimit}

if args.large_shm:
cfg += f"""
largeShm:
value: true
"""
cfg["spec"]["largeShm"] = {"value": True}

github_key = os.path.expanduser(args.github_secret_file)
if os.path.exists(github_key):
with open(github_key) as f:
cfg["spec"]["environment"]["items"]["GITHUB_KEY"] = {"value": f.read()}

with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f:
f.write(cfg)
with tempfile.NamedTemporaryFile(mode="w+", suffix=".yaml") as f:
yaml.dump(cfg, f)
f.flush()
if args.dry:
print(cfg)
else:
# Run the subprocess and capture stdout and stderr
f.seek(0)
print(f.read())
exit()
# Run the subprocess and capture stdout and stderr
result = subprocess.run(
["kubectl", "apply", "-f", f.name],
# check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)

# Check if there was an error
if result.returncode != 0:
print("Error encountered:")
# Prettify and print the stderr
pprint(result.stderr)
exit(1)

print("Output:")
# Prettify and print the stdout
print(result.stdout)

print("If the above says 'created', the job has been submitted.")

print(
f"If the above says 'job unchanged', the job with name {args.name} "
f"already exists (and you might need to delete it)."
)

print("\nThe following commands may come in handy:")
print(
f"runai exec {args.name} -it zsh # opens an interactive shell on the pod"
)
print(
f"runai delete job {args.name} # kills the job and removes it from the list of jobs"
)
print(
f"runai describe job {args.name} # shows information on the status/execution of the job"
)
print("runai list jobs # list all jobs and their status")
print(f"runai logs {args.name} # shows the output/logs for the job")

if args.follow:
print("Waiting for start...", end="", flush=True)
started = False
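# poll `runai describe` once per second until the job reports a status other
# than Pending/ContainerCreating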
while not started:
time.sleep(1)
result = subprocess.run(
["kubectl", "apply", "-f", f.name],
# check=True,
["runai", "describe", "job", args.name, "--output", "json"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
latest_status = json.loads(result.stdout)["status"]
started = latest_status not in {"Pending", "ContainerCreating"}
print(".", end="", flush=True)

# Check if there was an error
if result.returncode != 0:
print("Error encountered:")
# Prettify and print the stderr
pprint(result.stderr)
exit(1)
else:
print("Output:")
# Prettify and print the stdout
print(result.stdout)

print("If the above says 'created', the job has been submitted.")

print(
f"If the above says 'job unchanged', the job with name {args.name} "
f"already exists (and you might need to delete it)."
)

print("\nThe following commands may come in handy:")
print(
f"runai exec {args.name} -it zsh # opens an interactive shell on the pod"
)
print(
f"runai delete job {args.name} # kills the job and removes it from the list of jobs"
)
print(
f"runai describe job {args.name} # shows information on the status/execution of the job"
)
print("runai list jobs # list all jobs and their status")
print(f"runai logs {args.name} # shows the output/logs for the job")
print(latest_status, "\n=================================\n", flush=True)
following = subprocess.run(
["runai", "logs", args.name, "--follow"],
stdout=sys.stdout,
stderr=sys.stderr,
text=True,
)
exit(following.returncode)