From 2ae855aa2612a1901b9e4a95792da18988be294d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Assun=C3=A7ao=20Jeshon?= Date: Tue, 2 Oct 2018 07:44:02 +0200 Subject: [PATCH] Fix bug for labelised/splitted datas --- load_data.py | 79 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 47 insertions(+), 32 deletions(-) diff --git a/load_data.py b/load_data.py index 179ce8f..13180bb 100644 --- a/load_data.py +++ b/load_data.py @@ -140,36 +140,51 @@ def get_splitted_datas(df, training_size=4, validation_size=1, test_size=1): def set_labels_Y(df): - len_df = len(df) - - # Sort the DF by time - df.sort_values(by=['datetime'], inplace=True) - - # Create labels array - good_array = [1] * (len_df // 3) # 1 = Good quality - middle_array = [2] * (len_df // 3) # 2 = Middle quality - bad_array = [3] * (len_df // 3) # 3 = Bad quality - - missing_datas_nb = len_df - ((len_df // 3) * 3) - good_array = good_array + ([1] * missing_datas_nb) - - # Insert labels to df - df['quality'] = good_array + middle_array + bad_array - - return df - - -def main(): - batteries_to_keep = [25, 26, 27, 28, 33, 34] - folder_to_exclude = ['BatteryAgingARC_25_26_27_28_P1'] - dict_files = build_files(folder_to_exclude) - - df = mat_to_pandas(dict_files, batteries_to_keep) - - df = set_labels_Y(df) - - df_training, df_validation, df_test = get_splitted_datas(df) - + df_res = pd.DataFrame() + for batt_nb in df['battery_nb'].unique(): + df_batt = df[ + df['battery_nb'] == batt_nb + ] + + len_df = len(df_batt) + + # Sort the DF by time + df_batt.sort_values(by=['datetime', 'charge_nb', 'discharge_nb'], + inplace=True) + + # Create labels array + good_array = [1] * (len_df // 3) # 1 = Good quality + middle_array = [2] * (len_df // 3) # 1 = Middle quality + bad_array = [3] * (len_df // 3) # 1 = Bad quality + + missing_datas_nb = len_df - ((len_df // 3) * 3) + good_array = good_array + ([1] * missing_datas_nb) + + # Insert labels to df + df_batt['quality'] = good_array + middle_array + bad_array + + df_res = pd.concat([df_res, df_batt]) + return df_res + + +def get_datas( + folder_to_exclude=['BatteryAgingARC_25_26_27_28_P1'], + batteries_to_keep=[25, 26, 27, 28, 33, 34], + src_dir=os.getcwd(), + training_size=4, + validation_size=1, + test_size=1 +): + dict_files = build_files( + folder_to_exclude=folder_to_exclude, + src_dir=src_dir + ) + df = mat_to_pandas(files=dict_files, bat_to_keep=batteries_to_keep) + df_training, df_validation, df_test = get_splitted_datas( + df, + training_size=training_size, + validation_size=validation_size, + test_size=test_size + ) -if __name__ == "__main__": - main() + return set_labels_Y(df_training), set_labels_Y(df_validation), set_labels_Y(df_test) -- GitLab