|
17 | 17 | from ads.opctl.operator.lowcode.forecast.const import (
|
18 | 18 | AUTOMLX_METRIC_MAP,
|
19 | 19 | ForecastOutputColumns,
|
| 20 | +SpeedAccuracyMode, |
20 | 21 | SupportedModels,
|
21 | 22 | )
|
22 | 23 | from ads.opctl.operator.lowcode.forecast.utils import _label_encode_dataframe
|
@@ -241,18 +242,18 @@ def _generate_report(self):
|
241 | 242 | # If the key is present, call the "explain_model" method
|
242 | 243 | self.explain_model()
|
243 | 244 |
|
244 |
| -# Convert the global explanation data to a DataFrame |
245 |
| -global_explanation_df = pd.DataFrame(self.global_explanation) |
| 245 | +global_explanation_section = None |
| 246 | +if self.spec.explanations_accuracy_mode != SpeedAccuracyMode.AUTOMLX: |
| 247 | +# Convert the global explanation data to a DataFrame |
| 248 | +global_explanation_df = pd.DataFrame(self.global_explanation) |
246 | 249 |
|
247 |
| -self.formatted_global_explanation = ( |
248 |
| -global_explanation_df / global_explanation_df.sum(axis=0) * 100 |
249 |
| -) |
250 |
| -self.formatted_global_explanation = ( |
251 |
| -self.formatted_global_explanation.rename( |
| 250 | +self.formatted_global_explanation = ( |
| 251 | +global_explanation_df / global_explanation_df.sum(axis=0) * 100 |
| 252 | +) |
| 253 | +self.formatted_global_explanation = self.formatted_global_explanation.rename( |
252 | 254 | {self.spec.datetime_column.name: ForecastOutputColumns.DATE},
|
253 | 255 | axis=1,
|
254 | 256 | )
|
255 |
| -) |
256 | 257 |
|
257 | 258 | aggregate_local_explanations = pd.DataFrame()
|
258 | 259 | for s_id, local_ex_df in self.local_explanation.items():
|
@@ -293,8 +294,11 @@ def _generate_report(self):
|
293 | 294 | )
|
294 | 295 |
|
295 | 296 | # Append the global explanation text and section to the "other_sections" list
|
| 297 | +if global_explanation_section: |
| 298 | +other_sections.append(global_explanation_section) |
| 299 | + |
| 300 | +# Append the local explanation text and section to the "other_sections" list |
296 | 301 | other_sections = other_sections + [
|
297 |
| -global_explanation_section, |
298 | 302 | local_explanation_section,
|
299 | 303 | ]
|
300 | 304 | except Exception as e:
|
@@ -375,3 +379,79 @@ def _custom_predict_automlx(self, data):
|
375 | 379 | return self.models.get(self.series_id).forecast(
|
376 | 380 | X=data_temp, periods=data_temp.shape[0]
|
377 | 381 | )[self.series_id]
|
| 382 | + |
| 383 | +@runtime_dependency( |
| 384 | +module="automlx", |
| 385 | +err_msg=( |
| 386 | +"Please run `python3 -m pip install automlx` to install the required dependencies for model explanation." |
| 387 | +), |
| 388 | +) |
| 389 | +def explain_model(self): |
| 390 | +""" |
| 391 | +Generates explanations for the model using the AutoMLx library. |
| 392 | +
|
| 393 | +Parameters |
| 394 | +---------- |
| 395 | +None |
| 396 | +
|
| 397 | +Returns |
| 398 | +------- |
| 399 | +None |
| 400 | +
|
| 401 | +Notes |
| 402 | +----- |
| 403 | +This function works by generating local explanations for each series in the dataset. |
| 404 | +It uses the ``MLExplainer`` class from the AutoMLx library to generate feature attributions |
| 405 | +for each series. The feature attributions are then stored in the ``self.local_explanation`` dictionary. |
| 406 | +
|
| 407 | +If the accuracy mode is set to AutoMLX, it uses the AutoMLx library to generate explanations. |
| 408 | +Otherwise, it falls back to the default explanation generation method. |
| 409 | +""" |
| 410 | +import automlx |
| 411 | + |
| 412 | +# Loop through each series in the dataset |
| 413 | +for s_id, data_i in self.datasets.get_data_by_series( |
| 414 | +include_horizon=False |
| 415 | +).items(): |
| 416 | +try: |
| 417 | +if self.spec.explanations_accuracy_mode == SpeedAccuracyMode.AUTOMLX: |
| 418 | +# Use the MLExplainer class from AutoMLx to generate explanations |
| 419 | +explainer = automlx.MLExplainer( |
| 420 | +self.models[s_id], |
| 421 | +self.datasets.additional_data.get_data_for_series(series_id=s_id) |
| 422 | +.drop(self.spec.datetime_column.name, axis=1) |
| 423 | +.head(-self.spec.horizon) |
| 424 | +if self.spec.additional_data |
| 425 | +else None, |
| 426 | +pd.DataFrame(data_i[self.spec.target_column]), |
| 427 | +task="forecasting", |
| 428 | +) |
| 429 | + |
| 430 | +# Generate explanations for the forecast |
| 431 | +explanations = explainer.explain_prediction( |
| 432 | +X=self.datasets.additional_data.get_data_for_series(series_id=s_id) |
| 433 | +.drop(self.spec.datetime_column.name, axis=1) |
| 434 | +.tail(self.spec.horizon) |
| 435 | +if self.spec.additional_data |
| 436 | +else None, |
| 437 | +forecast_timepoints=list(range(self.spec.horizon + 1)), |
| 438 | +) |
| 439 | + |
| 440 | +# Convert the explanations to a DataFrame |
| 441 | +explanations_df = pd.concat( |
| 442 | +[exp.to_dataframe() for exp in explanations] |
| 443 | +) |
| 444 | +explanations_df["row"] = explanations_df.groupby("Feature").cumcount() |
| 445 | +explanations_df = explanations_df.pivot( |
| 446 | +index="row", columns="Feature", values="Attribution" |
| 447 | +) |
| 448 | +explanations_df = explanations_df.reset_index(drop=True) |
| 449 | + |
| 450 | +# Store the explanations in the local_explanation dictionary |
| 451 | +self.local_explanation[s_id] = explanations_df |
| 452 | +else: |
| 453 | +# Fall back to the default explanation generation method |
| 454 | +super().explain_model() |
| 455 | +except Exception as e: |
| 456 | +logger.warning(f"Failed to generate explanations for series {s_id} with error: {e}.") |
| 457 | +logger.debug(f"Full Traceback: {traceback.format_exc()}") |
0 commit comments