Skip to content

Commit 3e3a2a0

Browse files
authored
AutoMLx internal explainability mode (#1025)
2 parents b643faa + 6cfb305 commit 3e3a2a0

File tree

5 files changed

+138
-28
lines changed

5 files changed

+138
-28
lines changed

‎ads/opctl/operator/lowcode/forecast/const.py

+2
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
2727
HIGH_ACCURACY = "HIGH_ACCURACY"
2828
BALANCED = "BALANCED"
2929
FAST_APPROXIMATE = "FAST_APPROXIMATE"
30+
AUTOMLX = "AUTOMLX"
3031
ratio = {}
3132
ratio[HIGH_ACCURACY] = 1 # 100 % data used for generating explanations
3233
ratio[BALANCED] = 0.5 # 50 % data used for generating explanations
3334
ratio[FAST_APPROXIMATE] = 0 # constant
35+
ratio[AUTOMLX] = 0 # constant
3436

3537

3638
class SupportedMetrics(str, metaclass=ExtendedEnumMeta):

‎ads/opctl/operator/lowcode/forecast/model/automlx.py

+89-9
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ads.opctl.operator.lowcode.forecast.const import (
1818
AUTOMLX_METRIC_MAP,
1919
ForecastOutputColumns,
20+
SpeedAccuracyMode,
2021
SupportedModels,
2122
)
2223
from ads.opctl.operator.lowcode.forecast.utils import _label_encode_dataframe
@@ -241,18 +242,18 @@ def _generate_report(self):
241242
# If the key is present, call the "explain_model" method
242243
self.explain_model()
243244

244-
# Convert the global explanation data to a DataFrame
245-
global_explanation_df = pd.DataFrame(self.global_explanation)
245+
global_explanation_section = None
246+
if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX:
247+
# Convert the global explanation data to a DataFrame
248+
global_explanation_df = pd.DataFrame(self.global_explanation)
246249

247-
self.formatted_global_explanation = (
248-
global_explanation_df / global_explanation_df.sum(axis=0) * 100
249-
)
250-
self.formatted_global_explanation = (
251-
self.formatted_global_explanation.rename(
250+
self.formatted_global_explanation = (
251+
global_explanation_df / global_explanation_df.sum(axis=0) * 100
252+
)
253+
self.formatted_global_explanation = self.formatted_global_explanation.rename(
252254
{self.spec.datetime_column.name: ForecastOutputColumns.DATE},
253255
axis=1,
254256
)
255-
)
256257

257258
aggregate_local_explanations = pd.DataFrame()
258259
for s_id, local_ex_df in self.local_explanation.items():
@@ -293,8 +294,11 @@ def _generate_report(self):
293294
)
294295

295296
# Append the global explanation text and section to the "other_sections" list
297+
if global_explanation_section:
298+
other_sections.append(global_explanation_section)
299+
300+
# Append the local explanation text and section to the "other_sections" list
296301
other_sections = other_sections + [
297-
global_explanation_section,
298302
local_explanation_section,
299303
]
300304
except Exception as e:
@@ -375,3 +379,79 @@ def _custom_predict_automlx(self, data):
375379
return self.models.get(self.series_id).forecast(
376380
X=data_temp, periods=data_temp.shape[0]
377381
)[self.series_id]
382+
383+
@runtime_dependency(
384+
module="automlx",
385+
err_msg=(
386+
"Please run `python3 -m pip install automlx` to install the required dependencies for model explanation."
387+
),
388+
)
389+
def explain_model(self):
390+
"""
391+
Generates explanations for the model using the AutoMLx library.
392+
393+
Parameters
394+
----------
395+
None
396+
397+
Returns
398+
-------
399+
None
400+
401+
Notes
402+
-----
403+
This function works by generating local explanations for each series in the dataset.
404+
It uses the ``MLExplainer`` class from the AutoMLx library to generate feature attributions
405+
for each series. The feature attributions are then stored in the ``self.local_explanation`` dictionary.
406+
407+
If the accuracy mode is set to AutoMLX, it uses the AutoMLx library to generate explanations.
408+
Otherwise, it falls back to the default explanation generation method.
409+
"""
410+
import automlx
411+
412+
# Loop through each series in the dataset
413+
for s_id, data_i in self.datasets.get_data_by_series(
414+
include_horizon=False
415+
).items():
416+
try:
417+
if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
418+
# Use the MLExplainer class from AutoMLx to generate explanations
419+
explainer = automlx.MLExplainer(
420+
self.models[s_id],
421+
self.datasets.additional_data.get_data_for_series(series_id=s_id)
422+
.drop(self.spec.datetime_column.name, axis=1)
423+
.head(-self.spec.horizon)
424+
if self.spec.additional_data
425+
else None,
426+
pd.DataFrame(data_i[self.spec.target_column]),
427+
task="forecasting",
428+
)
429+
430+
# Generate explanations for the forecast
431+
explanations = explainer.explain_prediction(
432+
X=self.datasets.additional_data.get_data_for_series(series_id=s_id)
433+
.drop(self.spec.datetime_column.name, axis=1)
434+
.tail(self.spec.horizon)
435+
if self.spec.additional_data
436+
else None,
437+
forecast_timepoints=list(range(self.spec.horizon + 1)),
438+
)
439+
440+
# Convert the explanations to a DataFrame
441+
explanations_df = pd.concat(
442+
[exp.to_dataframe() for exp in explanations]
443+
)
444+
explanations_df["row"] = explanations_df.groupby("Feature").cumcount()
445+
explanations_df = explanations_df.pivot(
446+
index="row", columns="Feature", values="Attribution"
447+
)
448+
explanations_df = explanations_df.reset_index(drop=True)
449+
450+
# Store the explanations in the local_explanation dictionary
451+
self.local_explanation[s_id] = explanations_df
452+
else:
453+
# Fall back to the default explanation generation method
454+
super().explain_model()
455+
except Exception as e:
456+
logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.")
457+
logger.debug(f"Full Traceback: {traceback.format_exc()}")

‎ads/opctl/operator/lowcode/forecast/model/base_model.py

+36-6
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
SpeedAccuracyMode,
4949
SupportedMetrics,
5050
SupportedModels,
51-
BACKTEST_REPORT_NAME
51+
BACKTEST_REPORT_NAME,
5252
)
5353
from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
5454
from .forecast_datasets import ForecastDatasets
@@ -266,7 +266,11 @@ def generate_report(self):
266266
output_dir = self.spec.output_directory.url
267267
file_path = f"{output_dir}/{BACKTEST_REPORT_NAME}"
268268
if self.spec.model == AUTO_SELECT:
269-
backtest_sections.append(rc.Heading("Auto-Select Backtesting and Performance Metrics", level=2))
269+
backtest_sections.append(
270+
rc.Heading(
271+
"Auto-Select Backtesting and Performance Metrics", level=2
272+
)
273+
)
270274
if not os.path.exists(file_path):
271275
failure_msg = rc.Text(
272276
"auto-select could not be executed. Please check the "
@@ -275,15 +279,23 @@ def generate_report(self):
275279
backtest_sections.append(failure_msg)
276280
else:
277281
backtest_stats = pd.read_csv(file_path)
278-
model_metric_map = backtest_stats.drop(columns=['metric', 'backtest'])
279-
average_dict = {k: round(v, 4) for k, v in model_metric_map.mean().to_dict().items()}
282+
model_metric_map = backtest_stats.drop(
283+
columns=["metric", "backtest"]
284+
)
285+
average_dict = {
286+
k: round(v, 4)
287+
for k, v in model_metric_map.mean().to_dict().items()
288+
}
280289
best_model = min(average_dict, key=average_dict.get)
281290
summary_text = rc.Text(
282291
f"Overall, the average {self.spec.metric} scores for the models are {average_dict}, with"
283-
f" {best_model} being identified as the top-performing model during backtesting.")
292+
f" {best_model} being identified as the top-performing model during backtesting."
293+
)
284294
backtest_table = rc.DataTable(backtest_stats, index=True)
285295
liner_plot = get_auto_select_plot(backtest_stats)
286-
backtest_sections.extend([backtest_table, summary_text, liner_plot])
296+
backtest_sections.extend(
297+
[backtest_table, summary_text, liner_plot]
298+
)
287299

288300
forecast_plots = []
289301
if len(self.forecast_output.list_series_ids()) > 0:
@@ -646,6 +658,13 @@ def _save_model(self, output_dir, storage_options):
646658
storage_options=storage_options,
647659
)
648660

661+
def _validate_automlx_explanation_mode(self):
662+
if self.spec.model != SupportedModels.AutoMLX and self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX:
663+
raise ValueError(
664+
"AUTOMLX explanation accuracy mode is only supported for AutoMLX models. "
665+
"Please select mode other than AUTOMLX from the available explanations_accuracy_mode options"
666+
)
667+
649668
@runtime_dependency(
650669
module="shap",
651670
err_msg=(
@@ -674,6 +693,9 @@ def explain_model(self):
674693
)
675694
ratio = SpeedAccuracyMode.ratio[self.spec.explanations_accuracy_mode]
676695

696+
# validate the automlx mode is use for automlx model
697+
self._validate_automlx_explanation_mode()
698+
677699
for s_id, data_i in self.datasets.get_data_by_series(
678700
include_horizon=False
679701
).items():
@@ -708,6 +730,14 @@ def explain_model(self):
708730
logger.warn(
709731
"No explanations generated. Ensure that additional data has been provided."
710732
)
733+
elif (
734+
self.spec.model == SupportedModels.AutoMLX
735+
and self.spec.explanations_accuracy_mode
736+
== SpeedAccuracyMode.AUTOMLX
737+
):
738+
logger.warning(
739+
"Global explanations not available for AutoMLX models with inherent explainability"
740+
)
711741
else:
712742
self.global_explanation[s_id] = dict(
713743
zip(

‎ads/opctl/operator/lowcode/forecast/schema.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ spec:
332332
- HIGH_ACCURACY
333333
- BALANCED
334334
- FAST_APPROXIMATE
335+
- AUTOMLX
335336

336337
generate_report:
337338
type: boolean

‎tests/operators/forecast/test_errors.py

+10-13
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,8 @@ def test_all_series_failure(model):
591591
yaml_i["spec"]["preprocessing"] = {"enabled": True, "steps": preprocessing_steps}
592592
if yaml_i["spec"].get("additional_data") is not None and model != "autots":
593593
yaml_i["spec"]["generate_explanations"] = True
594+
else:
595+
yaml_i["spec"]["generate_explanations"] = False
594596
if model == "autots":
595597
yaml_i["spec"]["model_kwargs"] = {"model_list": "superfast"}
596598
if model == "automlx":
@@ -672,6 +674,7 @@ def test_arima_automlx_errors(operator_setup, model):
672674
yaml_i["spec"]["model_kwargs"] = {"model_list": "superfast"}
673675
if model == "automlx":
674676
yaml_i["spec"]["model_kwargs"] = {"time_budget": 1}
677+
yaml_i["spec"]["explanations_accuracy_mode"] = "AUTOMLX"
675678

676679
run_yaml(
677680
tmpdirname=tmpdirname,
@@ -699,21 +702,15 @@ def test_arima_automlx_errors(operator_setup, model):
699702
in error_content["13"]["error"]
700703
), "Error message mismatch"
701704

702-
if model not in ["autots", "automlx"]: # , "lgbforecast"
703-
global_fn = f"{tmpdirname}/results/global_explanation.csv"
704-
assert os.path.exists(
705-
global_fn
706-
), f"Global explanation file not found at {report_path}"
705+
if model not in ["autots"]: # , "lgbforecast"
706+
if yaml_i["spec"].get("explanations_accuracy_mode") != "AUTOMLX":
707+
global_fn = f"{tmpdirname}/results/global_explanation.csv"
708+
assert os.path.exists(global_fn), f"Global explanation file not found at {report_path}"
709+
assert not pd.read_csv(global_fn, index_col=0).empty
707710

708711
local_fn = f"{tmpdirname}/results/local_explanation.csv"
709-
assert os.path.exists(
710-
local_fn
711-
), f"Local explanation file not found at {report_path}"
712-
713-
glb_expl = pd.read_csv(global_fn, index_col=0)
714-
loc_expl = pd.read_csv(local_fn)
715-
assert not glb_expl.empty
716-
assert not loc_expl.empty
712+
assert os.path.exists(local_fn), f"Local explanation file not found at {report_path}"
713+
assert not pd.read_csv(local_fn).empty
717714

718715

719716
def test_smape_error():

0 commit comments

Comments
 (0)