Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F75367840
main.cpp
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Thu, Aug 1, 21:10
Size
4 KB
Mime Type
text/x-c
Expires
Sat, Aug 3, 21:10 (1 d, 21 h)
Engine
blob
Format
Raw Data
Handle
19532343
Attached To
R9868 DeepHealth_UC13_seizure_detection
main.cpp
View Options
#include <iostream>
#include <string>

#include "database.h"
#include "basicnet.h"
// Channel count. In this file it is only referenced from the commented-out
// reference code in main(), where it is the second dimension of the
// (n, NB_CHNS, 1280, 1) reshape.
#define NB_CHNS 4
// Processing parameters

// Share of the training samples set aside for validation; the reference script
// kept in main() passes it as test_size to train_test_split.
constexpr double validation_size{0.15};

// Stride used when cutting the test signals into windows: 640 corresponds to
// 50% overlap, can be set to 1280 to have no overlap.
constexpr int test_stride{640};
// Model parameters

// NOTE(review): presumably the optimizer learning rate — nothing in the visible
// code references it yet; confirm once training is wired in.
constexpr double lr{0.00075};

// Epoch count and batch size; the reference script kept in main() forwards both
// to train().
constexpr int epochs{1};
constexpr int batch_size{32};
// Entry point: default-constructs the dataset wrapper and the network, printing
// a progress line before and after each step. The per-patient training and
// evaluation loop is not implemented yet; see the outline and the reference
// script at the end of the function.
int main()
{
    // Every progress line ends with std::endl, so it is flushed before the next
    // step starts.
    const auto announce = [](const char* message) {
        std::cout << message << std::endl;
    };

    // NOTE(review): going by the log text, constructing Database is what loads
    // the dataset — confirm in database.h.
    announce("[INFO] Reading Dataset...");
    Database database;
    announce("[INFO] Reading Dataset Completed");

    // NOTE(review): likewise, constructing BasicNet is expected to build the
    // network — confirm in basicnet.h.
    announce("[INFO] Building Net...");
    BasicNet basic_net;
    announce("[INFO] Building Net Completed");

    // Remaining work (not implemented yet):
    //   - count the number of unique patients
    //   - instantiate a results class
    //   - for each patient:
    //       - write info
    //       - test_data  = data of the current patient
    //       - train_data = Database EXCEPT the data of the current patient
    //       - ...

    /*
    Reference Python implementation, kept as a guide for the steps above.
    NOTE(review): the indentation of this snippet was reconstructed from
    context — confirm the loop/branch nesting against the original script.

    patients = np.unique(data.patient)

    #Load seizure times
    seizures = pd.read_csv(data_folder+"seizures.csv", delimiter='\t')
    seizures_ = seizures[seizures.Patient.isin(patients)]
    seizures_["length"] = seizures_.apply(lambda x: (x.end_seizure - x.start_seizure), axis=1)

    results = {}
    for patient in patients:
        print("Patient: ", patient)
        patient_data = data[data.patient == patient]
        files = np.unique(patient_data.file)
        print(' ', len(files), ' files.')
        test_data = patient_data
        train_data = data.drop(patient_data.index)

        #Build train/test set by cutting each signals in pieces of 5 seconds, with 50% overlap
        x_train, y_train, x_test, y_test = cut_signal_data(train_data, test_data, stride_test=test_stride)

        #Shuffle and balance classes
        x_train, y_train = shuffle(x_train, y_train)
        x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=validation_size, stratify = y_train)
        x_train = np.array(x_train)
        x_val = np.array(x_val)
        // x_data->reshape_({x_data->shape[0], NB_CHNS, 1280, 1});
        x_train = x_train.reshape((len(x_train),NB_CHNS,1280, 1))
        x_val = x_val.reshape((len(x_val),NB_CHNS,1280, 1))

        #Create and train new model
        model = getModel()
        es = EarlyStopping(monitor='val_loss', min_delta=0.005, patience=20, restore_best_weights=True)
        train(model, x_train, y_train, [x_val, y_val], epochs=epochs, batch_size = batch_size)

        x_test = np.array(x_test)
        lengths = [len(x[0]) for x in x_test]
        x_test_splitted = np.array_split(x_test, len(files))
        y_test_splitted = np.array_split(np.array(y_test), len(files))

        for x_test, y_test, file in zip(x_test_splitted, y_test_splitted, files):
            x_test = np.array(x_test)
            x_test = x_test.reshape((len(x_test),NB_CHNS,1280, 1))
            predict(model, x_test, y_test)
            pred = model.predict(x_test) - 0.25
            #show_confusion_matrix(y_test, pred, str(patient) +' '+ str(file))
            print("Accuracy (segments) :", compute_accuracy(pred, y_test))
            detection_time = compute_detect_time(np.max(np.rint(pred), axis=1), y_test, [3,4])
            print("Detection time window (positive if 23<d<45) :", detection_time)
            seizure_length = seizures_.length[(seizures_.File == file) & (seizures_.Patient == patient)].values

            plt.plot(np.linspace(0, len(y_test)-1, len(y_test)),np.max(np.rint(pred), axis=1),'.')
            plt.axvline(x=int(len(y_test)/2), color='k')
            if(len(seizure_length)==1):
                seizure_chunk = np.rint(seizure_length/5)
                plt.axvline(x=int(len(y_test)/2)+seizure_chunk, color='k')
            plt.axvline(detection_time, color='r')
            plt.title(str(patient)+" "+str(file))
            plt.show()

            #Record accuracy (temp measure. Will change)
            if(str(patient) not in results.keys()):
                results[str(patient)] = []
            results[str(patient)].append([str(file), detection_time, compute_accuracy(pred, y_test)])

    detection_times = []
    counter = 0
    for patient in results.keys():
        for res in results[patient]:
            counter+=1
            detect_time = res[1]
            if((detect_time>23) and (detect_time<=45)):
                detection_times.append(detect_time)

    plt.hist(detection_times, bins=23, range=[23,46])
    plt.xlabel("Per segments of 5 seconds")
    plt.show()
    detected = len(detection_times)/counter
    print("Detection accuracy: ", detected)
    */
}
Event Timeline
Log In to Comment