Skip to content

Commit 898111e

Browse files
committed
minor updates and refinements
1 parent 4eb3389 commit 898111e

File tree

4 files changed

+33
-31
lines changed

4 files changed

+33
-31
lines changed

‎ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,7 @@ def get_data_multi_indexed(self):
167167
self.historical_data.data,
168168
self.additional_data.data,
169169
],
170-
axis=1,
171-
join='inner'
170+
axis=1
172171
)
173172

174173
def get_data_by_series(self, include_horizon=True):

‎ads/opctl/operator/lowcode/forecast/schema.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,9 @@ spec:
348348
model_deployment:
349349
type: dict
350350
required: false
351-
meta: "If model_deployment_id is not specified, a new model deployment is created; otherwise, the model is linked to the specified model deployment."
351+
meta: "If model_deployment id is not specified, a new model deployment is created; otherwise, the model is linked to the specified model deployment."
352352
schema:
353-
model_deployment_id:
353+
id:
354354
type: string
355355
required: false
356356
display_name:

‎ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py

+24-16
Original file line numberDiff line numberDiff line change
@@ -53,22 +53,30 @@ def _sanity_test(self):
5353
"""
5454
Function perform sanity test for saved artifact
5555
"""
56-
sys.path.insert(0, f"{self.path_to_artifact}")
57-
from score import load_model, predict
58-
_ = load_model()
59-
60-
# Write additional data to tmp file and perform sanity check
61-
with tempfile.NamedTemporaryFile(suffix='.csv') as temp_file:
62-
one_series = next(iter(self.additional_data))
63-
sample_prediction_data = self.additional_data[one_series].tail(self.horizon)
64-
sample_prediction_data[self.spec.target_category_columns[0]] = one_series
65-
date_col_name = self.spec.datetime_column.name
66-
date_col_format = self.spec.datetime_column.format
67-
sample_prediction_data[date_col_name] = sample_prediction_data[date_col_name].dt.strftime(date_col_format)
68-
sample_prediction_data.to_csv(temp_file.name, index=False)
69-
input_data = {"additional_data": {"url": temp_file.name}}
70-
prediction_test = predict(input_data, _)
71-
logger.info(f"prediction test completed with result :{prediction_test}")
56+
org_sys_path = sys.path[:]
57+
try:
58+
sys.path.insert(0, f"{self.path_to_artifact}")
59+
from score import load_model, predict
60+
_ = load_model()
61+
62+
# Write additional data to tmp file and perform sanity check
63+
with tempfile.NamedTemporaryFile(suffix='.csv') as temp_file:
64+
one_series = next(iter(self.additional_data))
65+
sample_prediction_data = self.additional_data[one_series].tail(self.horizon)
66+
sample_prediction_data[self.spec.target_category_columns[0]] = one_series
67+
date_col_name = self.spec.datetime_column.name
68+
date_col_format = self.spec.datetime_column.format
69+
sample_prediction_data[date_col_name] = sample_prediction_data[date_col_name].dt.strftime(
70+
date_col_format)
71+
sample_prediction_data.to_csv(temp_file.name, index=False)
72+
input_data = {"additional_data": {"url": temp_file.name}}
73+
prediction_test = predict(input_data, _)
74+
logger.info(f"prediction test completed with result :{prediction_test}")
75+
except Exception as e:
76+
logger.error(f"An error occurred during the sanity test: {e}")
77+
raise
78+
finally:
79+
sys.path = org_sys_path
7280

7381
def _copy_score_file(self):
7482
"""

‎ads/opctl/operator/lowcode/forecast/whatifserve/score.py

+6-11
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import ads
1414
from ads.opctl.operator.lowcode.common.utils import load_data
1515
from ads.opctl.operator.common.operator_config import InputData
16+
from ads.opctl.operator.lowcode.forecast.const import SupportedModels
1617

1718
ads.set_auth("resource_principal")
1819

@@ -26,12 +27,6 @@
2627
Inference script. This script is used for prediction by scoring server when schema is known.
2728
"""
2829

29-
AUTOTS = "autots"
30-
ARIMA = "arima"
31-
PROPHET = "prophet"
32-
NEURALPROPHET = "neuralprophet"
33-
AUTOMLX = "automlx"
34-
3530

3631
@lru_cache(maxsize=10)
3732
def load_model():
@@ -142,7 +137,7 @@ def get_forecast(future_df, model_name, series_id, model_object, date_col, targe
142137
future_df[date_col_name] = pd.to_datetime(
143138
future_df[date_col_name], format=date_col_format
144139
)
145-
if model_name == AUTOTS:
140+
if model_name == SupportedModels.AutoTS:
146141
series_id_col = "Series"
147142
full_data_indexed = future_df.rename(columns={target_cat_col: series_id_col})
148143
additional_regressors = list(
@@ -155,12 +150,12 @@ def get_forecast(future_df, model_name, series_id, model_object, date_col, targe
155150
)
156151
pred_obj = model_object.predict(future_regressor=future_reg)
157152
return pred_obj.forecast[series_id].tolist()
158-
elif model_name == PROPHET and series_id in model_object:
153+
elif model_name == SupportedModels.Prophet and series_id in model_object:
159154
model = model_object[series_id]
160155
processed = future_df.rename(columns={date_col_name: 'ds', target_column: 'y'})
161156
forecast = model.predict(processed)
162157
return forecast['yhat'].tolist()
163-
elif model_name == NEURALPROPHET and series_id in model_object:
158+
elif model_name == SupportedModels.NeuralProphet and series_id in model_object:
164159
model = model_object[series_id]
165160
model.restore_trainer()
166161
accepted_regressors = list(model.config_regressors.regressors.keys())
@@ -169,7 +164,7 @@ def get_forecast(future_df, model_name, series_id, model_object, date_col, targe
169164
future["y"] = None
170165
forecast = model.predict(future)
171166
return forecast['yhat1'].tolist()
172-
elif model_name == ARIMA and series_id in model_object:
167+
elif model_name == SupportedModels.Arima and series_id in model_object:
173168
model = model_object[series_id]
174169
future_df = future_df.set_index(date_col_name)
175170
x_pred = future_df.drop(target_cat_col, axis=1)
@@ -180,7 +175,7 @@ def get_forecast(future_df, model_name, series_id, model_object, date_col, targe
180175
)
181176
yhat_clean = pd.DataFrame(yhat, index=yhat.index, columns=["yhat"])
182177
return yhat_clean['yhat'].tolist()
183-
elif model_name == AUTOMLX and series_id in model_object:
178+
elif model_name == SupportedModels.AutoMLX and series_id in model_object:
184179
# automlx model
185180
model = model_object[series_id]
186181
x_pred = future_df.drop(target_cat_col, axis=1)

0 commit comments

Comments
 (0)