|
10 | 10 | import os
|
11 | 11 | import random
|
12 | 12 | import re
|
| 13 | +from datetime import datetime, timedelta |
13 | 14 | from functools import wraps
|
14 | 15 | from pathlib import Path
|
15 | 16 | from string import Template
|
16 | 17 | from typing import List, Union
|
17 | 18 |
|
18 | 19 | import fsspec
|
19 |
| -import oci |
20 |
| -from oci.data_science.models import JobRun, Model |
| 20 | +import ocifs |
| 21 | +from cachetools import TTLCache, cached |
21 | 22 |
|
| 23 | +import oci |
22 | 24 | from ads.aqua.common.enums import (
|
23 | 25 | InferenceContainerParamType,
|
24 | 26 | InferenceContainerType,
|
|
52 | 54 | from ads.common.utils import copy_file, get_console_link, upload_to_os
|
53 | 55 | from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
|
54 | 56 | from ads.model import DataScienceModel, ModelVersionSet
|
| 57 | +from oci.data_science.models import JobRun, Model |
55 | 58 |
|
56 | 59 | logger = logging.getLogger("ads.aqua")
|
57 | 60 |
|
@@ -228,6 +231,29 @@ def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
|
228 | 231 | return config
|
229 | 232 |
|
230 | 233 |
|
| 234 | +def list_os_files_with_extension(oss_path: str, extension: str) -> [str]: |
| 235 | + """ |
| 236 | + List files in the specified directory with the given extension. |
| 237 | +
|
| 238 | + Parameters: |
| 239 | + - oss_path: The path to the directory where files are located. |
| 240 | + - extension: The file extension to filter by (e.g., 'txt' for text files). |
| 241 | +
|
| 242 | + Returns: |
| 243 | + - A list of file paths matching the specified extension. |
| 244 | + """ |
| 245 | + |
| 246 | + signer = default_signer() |
| 247 | + |
| 248 | + # Ensure the extension is prefixed with a dot if not already |
| 249 | + if not extension.startswith("."): |
| 250 | + extension = "." + extension |
| 251 | + fs = ocifs.OCIFileSystem(**signer) |
| 252 | + files: [str] = fs.ls(oss_path) |
| 253 | + |
| 254 | + return [file for file in files if file.endswith(extension)] |
| 255 | + |
| 256 | + |
231 | 257 | def is_valid_ocid(ocid: str) -> bool:
|
232 | 258 | """Checks if the given ocid is valid.
|
233 | 259 |
|
@@ -503,6 +529,7 @@ def container_config_path():
|
503 | 529 | return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
|
504 | 530 |
|
505 | 531 |
|
| 532 | +@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now)) |
506 | 533 | def get_container_config():
|
507 | 534 | config = load_config(
|
508 | 535 | file_path=container_config_path(),
|
|
0 commit comments