13
13
import ads
14
14
from ads .opctl .operator .lowcode .common .utils import load_data
15
15
from ads .opctl .operator .common .operator_config import InputData
16
+ from ads .opctl .operator .lowcode .forecast .const import SupportedModels
16
17
17
18
ads .set_auth ("resource_principal" )
18
19
26
27
Inference script. This script is used for prediction by scoring server when schema is known.
27
28
"""
28
29
29
- AUTOTS = "autots"
30
- ARIMA = "arima"
31
- PROPHET = "prophet"
32
- NEURALPROPHET = "neuralprophet"
33
- AUTOMLX = "automlx"
34
-
35
30
36
31
@lru_cache (maxsize = 10 )
37
32
def load_model ():
@@ -142,7 +137,7 @@ def get_forecast(future_df, model_name, series_id, model_object, date_col, targe
142
137
future_df [date_col_name ] = pd .to_datetime (
143
138
future_df [date_col_name ], format = date_col_format
144
139
)
145
- if model_name == AUTOTS :
140
+ if model_name == SupportedModels . AutoTS :
146
141
series_id_col = "Series"
147
142
full_data_indexed = future_df .rename (columns = {target_cat_col : series_id_col })
148
143
additional_regressors = list (
@@ -155,12 +150,12 @@ def get_forecast(future_df, model_name, series_id, model_object, date_col, targe
155
150
)
156
151
pred_obj = model_object .predict (future_regressor = future_reg )
157
152
return pred_obj .forecast [series_id ].tolist ()
158
- elif model_name == PROPHET and series_id in model_object :
153
+ elif model_name == SupportedModels . Prophet and series_id in model_object :
159
154
model = model_object [series_id ]
160
155
processed = future_df .rename (columns = {date_col_name : 'ds' , target_column : 'y' })
161
156
forecast = model .predict (processed )
162
157
return forecast ['yhat' ].tolist ()
163
- elif model_name == NEURALPROPHET and series_id in model_object :
158
+ elif model_name == SupportedModels . NeuralProphet and series_id in model_object :
164
159
model = model_object [series_id ]
165
160
model .restore_trainer ()
166
161
accepted_regressors = list (model .config_regressors .regressors .keys ())
@@ -169,7 +164,7 @@ def get_forecast(future_df, model_name, series_id, model_object, date_col, targe
169
164
future ["y" ] = None
170
165
forecast = model .predict (future )
171
166
return forecast ['yhat1' ].tolist ()
172
- elif model_name == ARIMA and series_id in model_object :
167
+ elif model_name == SupportedModels . Arima and series_id in model_object :
173
168
model = model_object [series_id ]
174
169
future_df = future_df .set_index (date_col_name )
175
170
x_pred = future_df .drop (target_cat_col , axis = 1 )
@@ -180,7 +175,7 @@ def get_forecast(future_df, model_name, series_id, model_object, date_col, targe
180
175
)
181
176
yhat_clean = pd .DataFrame (yhat , index = yhat .index , columns = ["yhat" ])
182
177
return yhat_clean ['yhat' ].tolist ()
183
- elif model_name == AUTOMLX and series_id in model_object :
178
+ elif model_name == SupportedModels . AutoMLX and series_id in model_object :
184
179
# automlx model
185
180
model = model_object [series_id ]
186
181
x_pred = future_df .drop (target_cat_col , axis = 1 )
0 commit comments