Skip to content

Commit ccabc04

Browse files
authored
odsc-65115 : Transform columns names to make them compatible across models (#1046)
2 parents 221eb14 + fa5f53d commit ccabc04

File tree

1 file changed

+38
-3
lines changed

1 file changed

+38
-3
lines changed

‎ads/opctl/operator/lowcode/common/transformations.py

+38-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
InvalidParameterError,
1616
)
1717
from ads.opctl.operator.lowcode.common.utils import merge_category_columns
18+
from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorSpec
1819

1920

2021
class Transformations(ABC):
@@ -34,6 +35,7 @@ def __init__(self, dataset_info, name="historical_data"):
3435
self.dataset_info = dataset_info
3536
self.target_category_columns = dataset_info.target_category_columns
3637
self.target_column_name = dataset_info.target_column
38+
self.raw_column_names = None
3739
self.dt_column_name = (
3840
dataset_info.datetime_column.name if dataset_info.datetime_column else None
3941
)
@@ -60,7 +62,8 @@ def run(self, data):
6062
6163
"""
6264
clean_df = self._remove_trailing_whitespace(data)
63-
# clean_df = self._normalize_column_names(clean_df)
65+
if isinstance(self.dataset_info, ForecastOperatorSpec):
66+
clean_df = self._clean_column_names(clean_df)
6467
if self.name == "historical_data":
6568
self._check_historical_dataset(clean_df)
6669
clean_df = self._set_series_id_column(clean_df)
@@ -98,8 +101,36 @@ def run(self, data):
98101
def _remove_trailing_whitespace(self, df):
99102
return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)
100103

101-
# def _normalize_column_names(self, df):
102-
# return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x))
104+
def _clean_column_names(self, df):
105+
"""
106+
Remove all whitespaces from column names in a DataFrame and store the original names.
107+
108+
Parameters:
109+
df (pd.DataFrame): The DataFrame whose column names need to be cleaned.
110+
111+
Returns:
112+
pd.DataFrame: The DataFrame with cleaned column names.
113+
"""
114+
115+
self.raw_column_names = {
116+
col: col.replace(" ", "") for col in df.columns if " " in col
117+
}
118+
df.columns = [self.raw_column_names.get(col, col) for col in df.columns]
119+
120+
if self.target_column_name:
121+
self.target_column_name = self.raw_column_names.get(
122+
self.target_column_name, self.target_column_name
123+
)
124+
self.dt_column_name = self.raw_column_names.get(
125+
self.dt_column_name, self.dt_column_name
126+
)
127+
128+
if self.target_category_columns:
129+
self.target_category_columns = [
130+
self.raw_column_names.get(col, col)
131+
for col in self.target_category_columns
132+
]
133+
return df
103134

104135
def _set_series_id_column(self, df):
105136
self._target_category_columns_map = {}
@@ -233,6 +264,10 @@ def _check_historical_dataset(self, df):
233264
expected_names = [self.target_column_name, self.dt_column_name] + (
234265
self.target_category_columns if self.target_category_columns else []
235266
)
267+
268+
if self.raw_column_names:
269+
expected_names.extend(list(self.raw_column_names.values()))
270+
236271
if set(df.columns) != set(expected_names):
237272
raise DataMismatchError(
238273
f"Expected {self.name} to have columns: {expected_names}, but instead found column names: {df.columns}. Is the {self.name} path correct?"

0 commit comments

Comments
 (0)