Skip to content

Commit 621971f

Browse files
JobDefinition env parameter to set pip index url and host
1 parent 745772e commit 621971f

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

Diff for: src/codeflare_sdk/job/jobs.py

+8
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from torchx.schedulers.ray_scheduler import RayScheduler
2323
from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo
2424

25+
from ..utils.generate_yaml import update_pip_requirements
26+
2527

2628
if TYPE_CHECKING:
2729
from ..cluster.cluster import Cluster
@@ -90,6 +92,12 @@ def __init__(
9092
)
9193
self.image = image
9294
self.workspace = workspace
95+
if 'PIP_INDEX_URL' in self.env or 'PIP_TRUSTED_HOST' in self.env:
96+
update_pip_requirements(self)
97+
else:
98+
self.env.setdefault('PIP_INDEX_URL', 'https://pypi.org/simple')
99+
self.env.setdefault('PIP_TRUSTED_HOST', 'pypi.org')
100+
update_pip_requirements(self)
93101

94102
def _dry_run(self, cluster: "Cluster"):
95103
j = f"{cluster.config.num_workers}x{max(cluster.config.num_gpus, 1)}" # # of proc. = # of gpus

Diff for: src/codeflare_sdk/utils/generate_yaml.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import sys
2222
import os
2323
import argparse
24+
from pathlib import Path
2425
import uuid
2526
from kubernetes import client, config
2627
from .kube_api_helpers import _kube_api_error_handling
@@ -607,7 +608,6 @@ def write_components(user_yaml: dict, output_file_name: str):
607608
)
608609
print(f"Written to: {output_file_name}")
609610

610-
611611
def generate_appwrapper(
612612
name: str,
613613
namespace: str,
@@ -689,3 +689,37 @@ def generate_appwrapper(
689689
else:
690690
write_user_appwrapper(user_yaml, outfile)
691691
return outfile
692+
693+
def update_pip_requirements(self):
694+
pip_index_url = self.env.get('PIP_INDEX_URL')
695+
pip_trusted_host = self.env.get('PIP_TRUSTED_HOST')
696+
requirements_path = Path('requirements.txt')
697+
698+
if requirements_path.exists():
699+
with requirements_path.open('r') as file:
700+
requirements = file.readlines()
701+
702+
# Check and replace or add --trusted-host and --index-url
703+
trusted_host = f"--trusted-host {pip_trusted_host}\n"
704+
index_url = f"--index-url {pip_index_url}\n"
705+
modified_requirements = []
706+
707+
for line in requirements:
708+
if line.startswith("--trusted-host"):
709+
modified_requirements.append(trusted_host)
710+
trusted_host = None
711+
elif line.startswith("--index-url"):
712+
modified_requirements.append(index_url)
713+
index_url = None
714+
else:
715+
modified_requirements.append(line)
716+
717+
# Append the lines if they were not replaced
718+
if index_url:
719+
modified_requirements.insert(0, index_url)
720+
if trusted_host:
721+
modified_requirements.insert(0, trusted_host)
722+
723+
# Write back the modified requirements
724+
with requirements_path.open('w') as file:
725+
file.writelines(modified_requirements)

Diff for: tests/unit_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2091,7 +2091,7 @@ def test_DDPJobDefinition_creation():
20912091
assert ddp.memMB == 1024
20922092
assert ddp.h == None
20932093
assert ddp.j == "2x1"
2094-
assert ddp.env == {"test": "test"}
2094+
assert ddp.env == {"PIP_INDEX_URL":"https://pypi.org/simple", "PIP_TRUSTED_HOST": "pypi.org", "test": "test"}
20952095
assert ddp.max_retries == 0
20962096
assert ddp.mounts == []
20972097
assert ddp.rdzv_port == 29500

0 commit comments

Comments
 (0)