15
15
InvalidParameterError ,
16
16
)
17
17
from ads .opctl .operator .lowcode .common .utils import merge_category_columns
18
+ from ads .opctl .operator .lowcode .forecast .operator_config import ForecastOperatorSpec
18
19
19
20
20
21
class Transformations (ABC ):
@@ -34,6 +35,7 @@ def __init__(self, dataset_info, name="historical_data"):
34
35
self .dataset_info = dataset_info
35
36
self .target_category_columns = dataset_info .target_category_columns
36
37
self .target_column_name = dataset_info .target_column
38
+ self .raw_column_names = None
37
39
self .dt_column_name = (
38
40
dataset_info .datetime_column .name if dataset_info .datetime_column else None
39
41
)
@@ -60,7 +62,8 @@ def run(self, data):
60
62
61
63
"""
62
64
clean_df = self ._remove_trailing_whitespace (data )
63
- # clean_df = self._normalize_column_names(clean_df)
65
+ if isinstance (self .dataset_info , ForecastOperatorSpec ):
66
+ clean_df = self ._clean_column_names (clean_df )
64
67
if self .name == "historical_data" :
65
68
self ._check_historical_dataset (clean_df )
66
69
clean_df = self ._set_series_id_column (clean_df )
@@ -98,8 +101,36 @@ def run(self, data):
98
101
def _remove_trailing_whitespace (self , df ):
99
102
return df .apply (lambda x : x .str .strip () if x .dtype == "object" else x )
100
103
101
- # def _normalize_column_names(self, df):
102
- # return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x))
104
+ def _clean_column_names (self , df ):
105
+ """
106
+ Remove all whitespaces from column names in a DataFrame and store the original names.
107
+
108
+ Parameters:
109
+ df (pd.DataFrame): The DataFrame whose column names need to be cleaned.
110
+
111
+ Returns:
112
+ pd.DataFrame: The DataFrame with cleaned column names.
113
+ """
114
+
115
+ self .raw_column_names = {
116
+ col : col .replace (" " , "" ) for col in df .columns if " " in col
117
+ }
118
+ df .columns = [self .raw_column_names .get (col , col ) for col in df .columns ]
119
+
120
+ if self .target_column_name :
121
+ self .target_column_name = self .raw_column_names .get (
122
+ self .target_column_name , self .target_column_name
123
+ )
124
+ self .dt_column_name = self .raw_column_names .get (
125
+ self .dt_column_name , self .dt_column_name
126
+ )
127
+
128
+ if self .target_category_columns :
129
+ self .target_category_columns = [
130
+ self .raw_column_names .get (col , col )
131
+ for col in self .target_category_columns
132
+ ]
133
+ return df
103
134
104
135
def _set_series_id_column (self , df ):
105
136
self ._target_category_columns_map = {}
@@ -233,6 +264,10 @@ def _check_historical_dataset(self, df):
233
264
expected_names = [self .target_column_name , self .dt_column_name ] + (
234
265
self .target_category_columns if self .target_category_columns else []
235
266
)
267
+
268
+ if self .raw_column_names :
269
+ expected_names .extend (list (self .raw_column_names .values ()))
270
+
236
271
if set (df .columns ) != set (expected_names ):
237
272
raise DataMismatchError (
238
273
f"Expected { self .name } to have columns: { expected_names } , but instead found column names: { df .columns } . Is the { self .name } path correct?"
0 commit comments