"""Train the model on the full dataset and export it to ONNX.

Script reconstructed from the diff of ``src_pyeddl/train_full_model_onnx.py``.
It loads the signals, labels and patient/file infos, cuts the signals into
training segments through ``ModelHandler``, holds out a stratified validation
split, trains, evaluates on the training set and saves the network as ONNX.
"""

# Imports
import numpy as np
import pandas as pd

# Utility file
import ModelHandler as mh

# sklearn
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split


# Misc
MODEL_FILE = "../model/model.bin"
DATA_FOLDER = "/shares/eslfiler1/scratch/teijeiro/chbmit_deephealth/"
SIGNAL_LENGTH = "1min"
ONNX_FILE = 'uc13_epilepsy_net.onnx'

# Processing
VALIDATION_SIZE = 0.15
# Number of samples per segment fed to the network. It has to match what
# handler.cut_signal_data produces (the script describes these as 5-second
# segments) -- NOTE(review): confirm against ModelHandler.
SEGMENT_SAMPLES = 1280

# Model
EPOCHS = 50
BATCH_SIZE = 32
LEARNING_RATE = 0.0075  # 0.00075


def _load_dataset(handler, data_folder, signal_length):
    """Load signals, labels and infos into one sorted DataFrame.

    Args:
        handler: ModelHandler instance; only ``handler.NB_CHNS`` is read here.
        data_folder: Folder prefix (with trailing separator) of the data files.
        signal_length: Suffix identifying the file set, e.g. ``"1min"``.

    Returns:
        DataFrame with columns ``label``, ``patient``, ``file`` and ``signal``
        (one 2D array of ``NB_CHNS`` rows per record), sorted by patient,
        file and label with a fresh index.
    """
    with open(data_folder + "signal_mit_" + signal_length + ".csv", "rb") as file:
        raw_signals = file.read().splitlines()

    y_data = np.loadtxt(data_folder + "labels_mit_" + signal_length + ".txt")
    info_data = np.loadtxt(
        data_folder + "infos_mit_" + signal_length + ".txt", dtype="str"
    )

    # Each csv line is one flattened record: parse it and reshape it in 2D
    # (channels x samples). Integer division replaces int(len / n).
    x_data = []
    for raw in raw_signals:
        flat = np.fromstring(raw, sep=',')
        x_data.append(
            np.reshape(flat, (handler.NB_CHNS, len(flat) // handler.NB_CHNS))
        )

    # Create the pandas df
    data = pd.concat(
        [
            pd.Series(y_data),
            pd.Series([info[0] for info in info_data]),
            pd.Series([info[1] for info in info_data]),
        ],
        axis=1,
    )
    data.columns = ["label", "patient", "file"]
    data["signal"] = ""
    # astype returns a new Series: the result has to be assigned back,
    # otherwise the cast is silently lost and the column may not be able to
    # hold arrays.
    data["signal"] = data["signal"].astype(object)
    for i, sig in enumerate(x_data):
        data.at[i, "signal"] = sig

    data.sort_values(["patient", "file", "label"], inplace=True)
    return data.reset_index(drop=True)


def main():
    """Train on every patient, evaluate on the training set, export to ONNX."""
    handler = mh.ModelHandler()
    data = _load_dataset(handler, DATA_FOLDER, SIGNAL_LENGTH)

    # Cutting the signals in segments for training. The full model uses all
    # the records, so the test partition is an explicitly empty frame (the
    # former `data[data.patient==None]` mask was always all-False).
    x_train, y_train, _, _ = handler.cut_signal_data(data, data.iloc[:0])

    # Train/validation splitting
    x_train, x_val, y_train, y_val = train_test_split(
        x_train, y_train, test_size=VALIDATION_SIZE, stratify=y_train
    )
    x_train = np.array(x_train)
    x_val = np.array(x_val)
    x_train = x_train.reshape((len(x_train), handler.NB_CHNS, SEGMENT_SAMPLES, 1))
    x_val = x_val.reshape((len(x_val), handler.NB_CHNS, SEGMENT_SAMPLES, 1))
    # One-hot encoding of the two classes
    y_train = np.eye(2)[np.array(y_train).astype(int)]
    y_val = np.eye(2)[np.array(y_val).astype(int)]

    # Model training
    handler.train(
        MODEL_FILE, EPOCHS, BATCH_SIZE, LEARNING_RATE,
        x_train, y_train, x_val, y_val,
    )

    # Plain evaluation with the training set
    handler.evaluate(MODEL_FILE, x_train, y_train)

    # Model storage in ONNX
    handler.save_to_onnx(ONNX_FILE)


if __name__ == "__main__":
    main()