Skip to content

Commit 1a81431

Browse files
authored
Refactor Container Index JSON Loader: Move to Config Package & Switch to Pydantic (#1092)
2 parents 7db15d0 + 34844b3 commit 1a81431

File tree

12 files changed

+312
-314
lines changed

12 files changed

+312
-314
lines changed

‎ads/aqua/common/enums.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,6 @@
22
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5-
"""
6-
aqua.common.enums
7-
~~~~~~~~~~~~~~
8-
This module contains the set of enums used in AQUA.
9-
"""
10-
115
from ads.common.extended_enum import ExtendedEnum
126

137

@@ -88,7 +82,8 @@ class RqsAdditionalDetails(ExtendedEnum):
8882

8983
class TextEmbeddingInferenceContainerParams(ExtendedEnum):
9084
"""Contains a subset of params that are required for enabling model deployment in OCI Data Science. More options
91-
are available at https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments"""
85+
are available at https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments
86+
"""
9287

9388
MODEL_ID = "model-id"
9489
PORT = "port"
@@ -97,3 +92,14 @@ class TextEmbeddingInferenceContainerParams(ExtendedEnum):
9792
class ConfigFolder(ExtendedEnum):
9893
CONFIG = "config"
9994
ARTIFACT = "artifact"
95+
96+
97+
class ModelFormat(ExtendedEnum):
98+
GGUF = "GGUF"
99+
SAFETENSORS = "SAFETENSORS"
100+
UNKNOWN = "UNKNOWN"
101+
102+
103+
class Platform(ExtendedEnum):
104+
ARM_CPU = "ARM_CPU"
105+
NVIDIA_GPU = "NVIDIA_GPU"

‎ads/aqua/common/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def service_config_path():
553553
return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
554554

555555

556-
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
556+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=10), timer=datetime.now))
557557
def get_container_config():
558558
config = load_config(
559559
file_path=service_config_path(),

‎ads/aqua/config/container_config.py

+198
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2025 Oracle and/or its affiliates.
3+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4+
5+
from typing import Dict, List, Optional
6+
7+
from pydantic import Field
8+
9+
from ads.aqua.common.entities import ContainerSpec
10+
from ads.aqua.config.utils.serializer import Serializable
11+
12+
13+
class AquaContainerConfigSpec(Serializable):
14+
"""
15+
Represents container specification details.
16+
17+
Attributes
18+
----------
19+
cli_param (Optional[str]): CLI parameter for container configuration.
20+
server_port (Optional[str]): The server port for the container.
21+
health_check_port (Optional[str]): The health check port for the container.
22+
env_vars (Optional[List[Dict]]): Environment variables for the container.
23+
restricted_params (Optional[List[str]]): Restricted parameters for container configuration.
24+
"""
25+
26+
cli_param: Optional[str] = Field(
27+
default=None, description="CLI parameter for container configuration."
28+
)
29+
server_port: Optional[str] = Field(
30+
default=None, description="Server port for the container."
31+
)
32+
health_check_port: Optional[str] = Field(
33+
default=None, description="Health check port for the container."
34+
)
35+
env_vars: Optional[List[Dict]] = Field(
36+
default_factory=list, description="List of environment variables."
37+
)
38+
restricted_params: Optional[List[str]] = Field(
39+
default_factory=list, description="List of restricted parameters."
40+
)
41+
42+
class Config:
43+
extra = "allow"
44+
45+
46+
class AquaContainerConfigItem(Serializable):
47+
"""
48+
Represents an item of the AQUA container configuration.
49+
50+
Attributes
51+
----------
52+
name (Optional[str]): Name of the container configuration item.
53+
version (Optional[str]): Version of the container.
54+
display_name (Optional[str]): Display name for UI.
55+
family (Optional[str]): Container family or category.
56+
platforms (Optional[List[str]]): Supported platforms.
57+
model_formats (Optional[List[str]]): Supported model formats.
58+
spec (Optional[AquaContainerConfigSpec]): Container specification details.
59+
"""
60+
61+
name: Optional[str] = Field(
62+
default=None, description="Name of the container configuration item."
63+
)
64+
version: Optional[str] = Field(
65+
default=None, description="Version of the container."
66+
)
67+
display_name: Optional[str] = Field(
68+
default=None, description="Display name of the container."
69+
)
70+
family: Optional[str] = Field(
71+
default=None, description="Container family or category."
72+
)
73+
platforms: Optional[List[str]] = Field(
74+
default_factory=list, description="Supported platforms."
75+
)
76+
model_formats: Optional[List[str]] = Field(
77+
default_factory=list, description="Supported model formats."
78+
)
79+
spec: Optional[AquaContainerConfigSpec] = Field(
80+
default_factory=AquaContainerConfigSpec,
81+
description="Detailed container specification.",
82+
)
83+
84+
class Config:
85+
extra = "allow"
86+
87+
88+
class AquaContainerConfig(Serializable):
89+
"""
90+
Represents a configuration of AQUA containers to be returned to the client.
91+
92+
Attributes
93+
----------
94+
inference (Dict[str, AquaContainerConfigItem]): Inference container configuration items.
95+
finetune (Dict[str, AquaContainerConfigItem]): Fine-tuning container configuration items.
96+
evaluate (Dict[str, AquaContainerConfigItem]): Evaluation container configuration items.
97+
"""
98+
99+
inference: Dict[str, AquaContainerConfigItem] = Field(
100+
default_factory=dict, description="Inference container configuration items."
101+
)
102+
finetune: Dict[str, AquaContainerConfigItem] = Field(
103+
default_factory=dict, description="Fine-tuning container configuration items."
104+
)
105+
evaluate: Dict[str, AquaContainerConfigItem] = Field(
106+
default_factory=dict, description="Evaluation container configuration items."
107+
)
108+
109+
def to_dict(self):
110+
return {
111+
"inference": list(self.inference.values()),
112+
"finetune": list(self.finetune.values()),
113+
"evaluate": list(self.evaluate.values()),
114+
}
115+
116+
@classmethod
117+
def from_container_index_json(
118+
cls,
119+
config: Dict,
120+
enable_spec: Optional[bool] = False,
121+
) -> "AquaContainerConfig":
122+
"""
123+
Creates an AquaContainerConfig instance from a container index JSON.
124+
125+
Parameters
126+
----------
127+
config (Optional[Dict]): The container index JSON.
128+
enable_spec (Optional[bool]): If True, fetch container specification details.
129+
130+
Returns
131+
-------
132+
AquaContainerConfig: The constructed container configuration.
133+
"""
134+
#TODO: Return this logic back if necessary in the next iteraion.
135+
# if not config:
136+
# config = get_container_config()
137+
138+
inference_items: Dict[str, AquaContainerConfigItem] = {}
139+
finetune_items: Dict[str, AquaContainerConfigItem] = {}
140+
evaluate_items: Dict[str, AquaContainerConfigItem] = {}
141+
142+
for container_type, containers in config.items():
143+
if isinstance(containers, list):
144+
for container in containers:
145+
platforms = container.get("platforms", [])
146+
model_formats = container.get("modelFormats", [])
147+
container_spec = (
148+
config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
149+
container_type, {}
150+
)
151+
if enable_spec
152+
else None
153+
)
154+
container_item = AquaContainerConfigItem(
155+
name=container.get("name", ""),
156+
version=container.get("version", ""),
157+
display_name=container.get(
158+
"displayName", container.get("version", "")
159+
),
160+
family=container_type,
161+
platforms=platforms,
162+
model_formats=model_formats,
163+
spec=(
164+
AquaContainerConfigSpec(
165+
cli_param=container_spec.get(
166+
ContainerSpec.CLI_PARM, ""
167+
),
168+
server_port=container_spec.get(
169+
ContainerSpec.SERVER_PORT, ""
170+
),
171+
health_check_port=container_spec.get(
172+
ContainerSpec.HEALTH_CHECK_PORT, ""
173+
),
174+
env_vars=container_spec.get(ContainerSpec.ENV_VARS, []),
175+
restricted_params=container_spec.get(
176+
ContainerSpec.RESTRICTED_PARAMS, []
177+
),
178+
)
179+
if container_spec
180+
else None
181+
),
182+
)
183+
if container.get("type") == "inference":
184+
inference_items[container_type] = container_item
185+
elif (
186+
container.get("type") == "fine-tune"
187+
or container_type == "odsc-llm-fine-tuning"
188+
):
189+
finetune_items[container_type] = container_item
190+
elif (
191+
container.get("type") == "evaluate"
192+
or container_type == "odsc-llm-evaluate"
193+
):
194+
evaluate_items[container_type] = container_item
195+
196+
return cls(
197+
inference=inference_items, finetune=finetune_items, evaluate=evaluate_items
198+
)

‎ads/aqua/evaluation/evaluation.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,13 @@
4040
from ads.aqua.common.utils import (
4141
extract_id_and_name_from_tag,
4242
fire_and_forget,
43+
get_container_config,
4344
get_container_image,
4445
is_valid_ocid,
4546
upload_local_to_os,
4647
)
4748
from ads.aqua.config.config import get_evaluation_service_config
49+
from ads.aqua.config.container_config import AquaContainerConfig
4850
from ads.aqua.constants import (
4951
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
5052
EVALUATION_REPORT,
@@ -75,7 +77,6 @@
7577
CreateAquaEvaluationDetails,
7678
)
7779
from ads.aqua.evaluation.errors import EVALUATION_JOB_EXIT_CODE_MESSAGE
78-
from ads.aqua.ui import AquaContainerConfig
7980
from ads.common.auth import default_signer
8081
from ads.common.object_storage_details import ObjectStorageDetails
8182
from ads.common.utils import get_console_link, get_files, get_log_links
@@ -192,7 +193,7 @@ def create(
192193
evaluation_source.runtime.to_dict()
193194
)
194195
inference_config = AquaContainerConfig.from_container_index_json(
195-
enable_spec=True
196+
config=get_container_config(), enable_spec=True
196197
).inference
197198
for container in inference_config.values():
198199
if container.name == runtime.image[: runtime.image.rfind(":")]:

‎ads/aqua/extension/model_handler.py

+19-28
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,13 @@
88
from tornado.web import HTTPError
99

1010
from ads.aqua.common.decorator import handle_exceptions
11-
from ads.aqua.common.enums import (
12-
CustomInferenceContainerTypeFamily,
13-
)
11+
from ads.aqua.common.enums import CustomInferenceContainerTypeFamily
1412
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
15-
from ads.aqua.common.utils import (
16-
get_hf_model_info,
17-
is_valid_ocid,
18-
list_hf_models,
19-
)
13+
from ads.aqua.common.utils import get_hf_model_info, is_valid_ocid, list_hf_models
2014
from ads.aqua.extension.base_handler import AquaAPIhandler
2115
from ads.aqua.extension.errors import Errors
2216
from ads.aqua.model import AquaModelApp
2317
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
24-
from ads.aqua.ui import ModelFormat
2518

2619

2720
class AquaModelHandler(AquaAPIhandler):
@@ -44,26 +37,24 @@ def get(
4437
raise HTTPError(
4538
400, Errors.MISSING_REQUIRED_PARAMETER.format("model_format")
4639
)
47-
try:
48-
model_format = ModelFormat(model_format.upper())
49-
except ValueError as err:
50-
raise AquaValueError(f"Invalid model format: {model_format}") from err
40+
41+
model_format = model_format.upper()
42+
43+
if os_path:
44+
return self.finish(
45+
AquaModelApp.get_model_files(os_path, model_format)
46+
)
47+
elif model_name:
48+
return self.finish(
49+
AquaModelApp.get_hf_model_files(model_name, model_format)
50+
)
5151
else:
52-
if os_path:
53-
return self.finish(
54-
AquaModelApp.get_model_files(os_path, model_format)
55-
)
56-
elif model_name:
57-
return self.finish(
58-
AquaModelApp.get_hf_model_files(model_name, model_format)
59-
)
60-
else:
61-
raise HTTPError(
62-
400,
63-
Errors.MISSING_ONEOF_REQUIRED_PARAMETER.format(
64-
"os_path", "model_name"
65-
),
66-
)
52+
raise HTTPError(
53+
400,
54+
Errors.MISSING_ONEOF_REQUIRED_PARAMETER.format(
55+
"os_path", "model_name"
56+
),
57+
)
6758
elif not model_id:
6859
return self.list()
6960

‎ads/aqua/model/entities.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from ads.aqua.data import AquaResourceIdentifier
2424
from ads.aqua.model.enums import FineTuningDefinedMetadata
2525
from ads.aqua.training.exceptions import exit_code_dict
26-
from ads.aqua.ui import ModelFormat
2726
from ads.common.serializer import DataClassSerializable
2827
from ads.common.utils import get_log_links
2928
from ads.model.datascience_model import DataScienceModel
@@ -46,7 +45,7 @@ class AquaFineTuneValidation(DataClassSerializable):
4645
@dataclass(repr=False)
4746
class ModelValidationResult:
4847
model_file: Optional[str] = None
49-
model_formats: List[ModelFormat] = field(default_factory=list)
48+
model_formats: List[str] = field(default_factory=list)
5049
telemetry_model_name: str = None
5150
tags: Optional[dict] = None
5251

@@ -89,7 +88,7 @@ class AquaModelSummary(DataClassSerializable):
8988
nvidia_gpu_supported: bool = False
9089
arm_cpu_supported: bool = False
9190
model_file: Optional[str] = None
92-
model_formats: List[ModelFormat] = field(default_factory=list)
91+
model_formats: List[str] = field(default_factory=list)
9392

9493

9594
@dataclass(repr=False)

0 commit comments

Comments
 (0)