Skip to content

Commit 9df166e

Browse files
committed
Adding gguf validation and files endpoint
1 parent c267386 commit 9df166e

File tree

9 files changed

+376
-138
lines changed

9 files changed

+376
-138
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ repos:
4545
rev: v8.18.4
4646
hooks:
4747
- id: gitleaks
48-
exclude: .github/workflows/reusable-actions/set-dummy-conf.yml
48+
exclude: .github/workflows/reusable-actions/set-dummy-conf.yml|./tests/operators/common/test_load_data.py
4949
# Oracle copyright checker
5050
- repo: https://github.com/oracle-samples/oci-data-science-ai-samples/
5151
rev: 1bc5270a443b791c62f634233c0f4966dfcc0dd6

ads/aqua/common/enums.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ class Tags(str, metaclass=ExtendedEnumMeta):
2828
TASK = "task"
2929
LICENSE = "license"
3030
ORGANIZATION = "organization"
31-
PLATFORM = "platform"
3231
AQUA_TAG = "OCI_AQUA"
3332
AQUA_SERVICE_MODEL_TAG = "aqua_service_model"
3433
AQUA_FINE_TUNED_MODEL_TAG = "aqua_fine_tuned_model"
@@ -39,6 +38,7 @@ class Tags(str, metaclass=ExtendedEnumMeta):
3938
READY_TO_IMPORT = "ready_to_import"
4039
BASE_MODEL_CUSTOM = "aqua_custom_base_model"
4140
AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
41+
MODEL_FORMAT = "model_format"
4242

4343

4444
class InferenceContainerType(str, metaclass=ExtendedEnumMeta):

ads/aqua/common/utils.py

+29-2
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,17 @@
1010
import os
1111
import random
1212
import re
13+
from datetime import datetime, timedelta
1314
from functools import wraps
1415
from pathlib import Path
1516
from string import Template
1617
from typing import List, Union
1718

1819
import fsspec
19-
import oci
20-
from oci.data_science.models import JobRun, Model
20+
import ocifs
21+
from cachetools import TTLCache, cached
2122

23+
import oci
2224
from ads.aqua.common.enums import (
2325
InferenceContainerParamType,
2426
InferenceContainerType,
@@ -52,6 +54,7 @@
5254
from ads.common.utils import copy_file, get_console_link, upload_to_os
5355
from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
5456
from ads.model import DataScienceModel, ModelVersionSet
57+
from oci.data_science.models import JobRun, Model
5558

5659
logger = logging.getLogger("ads.aqua")
5760

@@ -228,6 +231,29 @@ def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
228231
return config
229232

230233

234+
def list_os_files_with_extension(oss_path: str, extension: str) -> [str]:
235+
"""
236+
List files in the specified directory with the given extension.
237+
238+
Parameters:
239+
- oss_path: The path to the directory where files are located.
240+
- extension: The file extension to filter by (e.g., 'txt' for text files).
241+
242+
Returns:
243+
- A list of file paths matching the specified extension.
244+
"""
245+
246+
signer = default_signer()
247+
248+
# Ensure the extension is prefixed with a dot if not already
249+
if not extension.startswith("."):
250+
extension = "." + extension
251+
fs = ocifs.OCIFileSystem(**signer)
252+
files: [str] = fs.ls(oss_path)
253+
254+
return [file for file in files if file.endswith(extension)]
255+
256+
231257
def is_valid_ocid(ocid: str) -> bool:
232258
"""Checks if the given ocid is valid.
233259
@@ -503,6 +529,7 @@ def container_config_path():
503529
return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
504530

505531

532+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
506533
def get_container_config():
507534
config = load_config(
508535
file_path=container_config_path(),

ads/aqua/constants.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
DEFAULT_FT_REPLICA = 1
2222
DEFAULT_FT_BATCH_SIZE = 1
2323
DEFAULT_FT_VALIDATION_SET_SIZE = 0.1
24-
ARM_CPU="arm_cpu"
25-
NVIDIA_GPU="nvidia_gpu"
2624
MAXIMUM_ALLOWED_DATASET_IN_BYTE = 52428800 # 1024 x 1024 x 50 = 50MB
2725
JOB_INFRASTRUCTURE_TYPE_DEFAULT_NETWORKING = "ME_STANDALONE"
2826
NB_SESSION_IDENTIFIER = "NB_SESSION_OCID"
@@ -35,6 +33,7 @@
3533
AQUA_MODEL_ARTIFACT_CONFIG = "config.json"
3634
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
3735
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
36+
AQUA_MODEL_ARTIFACT_FILE = "model_file"
3837

3938
TRAINING_METRICS_FINAL = "training_metrics_final"
4039
VALIDATION_METRICS_FINAL = "validation_metrics_final"

ads/aqua/extension/errors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -8,3 +7,4 @@ class Errors(str):
87
INVALID_INPUT_DATA_FORMAT = "Invalid format of input data."
98
NO_INPUT_DATA = "No input data provided."
109
MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'"
10+
MISSING_ONEOF_REQUIRED_PARAMETER = "Either '{}' or '{}' is required."

ads/aqua/extension/model_handler.py

+31-6
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,50 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

6-
import re
7-
from typing import Optional
85
from urllib.parse import urlparse
96

107
from tornado.web import HTTPError
11-
from ads.aqua.extension.errors import Errors
8+
129
from ads.aqua.common.decorator import handle_exceptions
10+
from ads.aqua.common.errors import AquaValueError
1311
from ads.aqua.extension.base_handler import AquaAPIhandler
12+
from ads.aqua.extension.errors import Errors
1413
from ads.aqua.model import AquaModelApp
14+
from ads.aqua.ui import ModelFormat
1515

1616

1717
class AquaModelHandler(AquaAPIhandler):
1818
"""Handler for Aqua Model REST APIs."""
1919

2020
@handle_exceptions
21-
def get(self, model_id=""):
21+
def get(
22+
self,
23+
model_id="",
24+
):
2225
"""Handle GET request."""
23-
if not model_id:
26+
url_parse = urlparse(self.request.path)
27+
paths = url_parse.path.strip("/")
28+
if paths.startswith("aqua/model/files"):
29+
os_path = self.get_argument("os_path")
30+
if not os_path:
31+
raise HTTPError(
32+
400, Errors.MISSING_REQUIRED_PARAMETER.format("os_path")
33+
)
34+
model_format = self.get_argument("model_format")
35+
if not model_format:
36+
raise HTTPError(
37+
400, Errors.MISSING_REQUIRED_PARAMETER.format("model_format")
38+
)
39+
try:
40+
model_format = ModelFormat(model_format.upper())
41+
except ValueError:
42+
raise AquaValueError(f"Invalid model format: {model_format}")
43+
else:
44+
return self.finish(AquaModelApp.get_model_files(os_path, model_format))
45+
elif not model_id:
2446
return self.list()
47+
2548
return self.read(model_id)
2649

2750
def read(self, model_id):
@@ -81,6 +104,7 @@ def post(self, *args, **kwargs):
81104
finetuning_container = input_data.get("finetuning_container")
82105
compartment_id = input_data.get("compartment_id")
83106
project_id = input_data.get("project_id")
107+
model_file = input_data.get("model_file")
84108

85109
return self.finish(
86110
AquaModelApp().register(
@@ -90,6 +114,7 @@ def post(self, *args, **kwargs):
90114
finetuning_container=finetuning_container,
91115
compartment_id=compartment_id,
92116
project_id=project_id,
117+
model_file=model_file,
93118
)
94119
)
95120

ads/aqua/model/entities.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
from typing import List, Optional
1515

1616
import oci
17-
1817
from ads.aqua import logger
1918
from ads.aqua.app import CLIBuilderMixin
2019
from ads.aqua.common import utils
2120
from ads.aqua.constants import LIFECYCLE_DETAILS_MISSING_JOBRUN, UNKNOWN_VALUE
2221
from ads.aqua.data import AquaResourceIdentifier
2322
from ads.aqua.model.enums import FineTuningDefinedMetadata
2423
from ads.aqua.training.exceptions import exit_code_dict
24+
from ads.aqua.ui import ModelFormat
2525
from ads.common.serializer import DataClassSerializable
2626
from ads.common.utils import get_log_links
2727
from ads.model.datascience_model import DataScienceModel
@@ -76,7 +76,9 @@ class AquaModelSummary(DataClassSerializable):
7676
ready_to_deploy: bool = True
7777
ready_to_finetune: bool = False
7878
ready_to_import: bool = False
79-
platform: List[str] = field(default_factory=lambda: ["nvidia_gpu"])
79+
nvidia_gpu_supported: bool = False
80+
arm_cpu_supported: bool = False
81+
model_format: ModelFormat = ModelFormat.UNKNOWN
8082

8183

8284
@dataclass(repr=False)
@@ -260,6 +262,7 @@ class ImportModelDetails(CLIBuilderMixin):
260262
finetuning_container: Optional[str] = None
261263
compartment_id: Optional[str] = None
262264
project_id: Optional[str] = None
265+
model_file: Optional[str] = None
263266

264267
def __post_init__(self):
265268
self._command = "model register"

0 commit comments

Comments
 (0)