model_basic_lopo.py
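#Leave-one-patient-out (LOPO) training and evaluation of the basic seizure-detection model with pyEDDL:
#for each patient, the model is trained on all other patients and evaluated on the held-out patient's files.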

#Imports
import random, time, os
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import json
#Utility file
from utils import *
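#NB_CHNS, cut_signal_data, train and evaluate used below are assumed to come from this star import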
import pyeddl._core.eddl as eddl
import pyeddl._core.eddlT as eddlT
#sklearn
from sklearn.utils import shuffle
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, roc_auc_score
from sklearn.model_selection import train_test_split
from scipy.signal import spectrogram, stft
#Misc: paths and data files
temp_data_folder = "../temp_data/"
model_file = "../model/model.bin"
data_folder = "../dataset/"
signal_length = "1mn"
#processing
validation_size=0.15
test_stride = 640 #Corresponds to 50% overlap, can be set to 1280 to have no overlap
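#(each segment is 1280 samples per channel, i.e. one 5-second window, so a 640-sample stride overlaps segments by half)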
#Model
epochs = 50
batch_size = 32
learning_rate = 0.0075 # 0.00075
#Load the raw signals (one comma-separated signal per line), the labels and the patient/file metadata
with open(data_folder+"signal_mit_"+signal_length+".csv", "rb") as file:
    x_data_ = file.read().splitlines()
y_data = np.loadtxt(data_folder+"labels_mit_"+signal_length+".txt")
info_data = np.loadtxt(data_folder+"infos_mit_"+signal_length+".txt", dtype="str")
#Convert each line from string to a float array
x_data = []
for sig in x_data_:
    x_data.append(np.fromstring(sig, sep=','))
#Reshape in 2D (NB_CHNS x samples) as data are stored 1D in the csv file
x_data = [np.reshape(np.array(x_data_i), (NB_CHNS, int(len(x_data_i)/NB_CHNS))) for x_data_i in x_data]
#Create the pandas dataframe: one row per recording with its label, patient, file and signal
data = pd.concat([pd.Series(y_data), pd.Series([info[0] for info in info_data]), pd.Series([info[1] for info in info_data])], axis=1)
data.columns = ["label", "patient", "file"]
data["signal"] = ""
data["signal"] = data["signal"].astype(object)
for i, sig in enumerate(x_data):
    data.at[i, "signal"] = sig
patients = np.unique(data.patient)
data.sort_values(["patient", "file", "label"], inplace=True)
data = data.reset_index(drop=True)
#Load seizure times
seizures = pd.read_csv(data_folder+"seizures.csv", delimiter='\t')
seizures_ = seizures[seizures.Patient.isin(patients)].copy()
seizures_["length"] = seizures_.apply(lambda x: (x.end_seizure - x.start_seizure), axis=1)
results = {}
for patient in patients:
    print("Patient: ", patient)
    patient_data = data[data.patient == patient]
    files = np.unique(patient_data.file)
    print(' ', len(files), ' files.')
    test_data = patient_data
    train_data = data.drop(patient_data.index)
    #Build train/test sets by cutting each signal into 5-second pieces, with 50% overlap
    x_train, y_train, x_test, y_test = cut_signal_data(train_data, test_data, stride_test=test_stride)
    #Shuffle and split off a stratified validation set
    x_train, y_train = shuffle(x_train, y_train)
    x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=validation_size, stratify=y_train)
    x_train = np.array(x_train)
    x_val = np.array(x_val)
    x_train = x_train.reshape((len(x_train), NB_CHNS, 1280, 1))
    x_val = x_val.reshape((len(x_val), NB_CHNS, 1280, 1))
    #One-hot encode the two classes
    y_train = np.eye(2)[np.array(y_train).astype(int)]
    y_val = np.eye(2)[np.array(y_val).astype(int)]
    train(temp_data_folder, model_file, epochs, batch_size, learning_rate, x_train, y_train, x_val, y_val)
    x_test = np.array(x_test)
    lengths = [len(x[0]) for x in x_test]
    #Evaluate the held-out patient file by file
    x_test_splitted = np.array_split(x_test, len(files))
    y_test_splitted = np.array_split(np.array(y_test), len(files))
    for x_test, y_test, file in zip(x_test_splitted, y_test_splitted, files):
        x_test = np.array(x_test)
        x_test = x_test.reshape((len(x_test), NB_CHNS, 1280, 1))
        y_test = np.eye(2)[np.array(y_test).astype(int)]
        evaluate(temp_data_folder, model_file, x_test, y_test)
        quit() #NOTE: stops after the first evaluated file
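        # The commented-out block below computes per-file predictions, detection times and plots,
        # and records per-patient results; it is currently disabled.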
        # pred = model.predict(x_test) - 0.25
        # show_confusion_matrix(y_test, pred, str(patient) +' '+ str(file))
        # print("Accuracy (segments) :", compute_accuracy(pred, y_test))
        # detection_time = compute_detect_time(np.max(np.rint(pred), axis=1), y_test, [3,4])
        # print("Detection time window (positive if 23<d<45) :", detection_time)
        # seizure_length = seizures_.length[(seizures_.File == file) & (seizures_.Patient == patient)].values
        # plt.plot(np.linspace(0, len(y_test)-1, len(y_test)), np.max(np.rint(pred), axis=1), '.')
        # plt.axvline(x=int(len(y_test)/2), color='k')
        # if(len(seizure_length)==1):
        #     seizure_chunk = np.rint(seizure_length/5)
        #     plt.axvline(x=int(len(y_test)/2)+seizure_chunk, color='k')
        # plt.axvline(detection_time, color='r')
        # plt.title(str(patient)+" "+str(file))
        # plt.show()
        # Record accuracy (temp measure, will change)
        # if(str(patient) not in results.keys()):
        #     results[str(patient)] = []
        # results[str(patient)].append([str(file), detection_time, compute_accuracy(pred, y_test)])
#detection_times = []
#counter = 0
#for patient in results.keys():
#    for res in results[patient]:
#        counter += 1
#        detect_time = res[1]
#        if((detect_time>23) and (detect_time<=45)):
#            detection_times.append(detect_time)
#plt.hist(detection_times, bins=23, range=[23,46])
#plt.xlabel("Per 5-second segment")
#plt.show()
#detected = len(detection_times)/counter
#print("Detection accuracy: ", detected)
