Page MenuHomec4science

main.py
No OneTemporary

File Metadata

Created
Fri, Jan 3, 06:57
import warnings
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report
from Z_HELPERS.DM_communication import *
from data_processing import *
from feature_extraction import *
from model_training import *
from visualization import *
from optuna_objectives import *
import pickle
# Ignore future warnings
warnings.filterwarnings("ignore", category=FutureWarning)
def baseline_model_model1(normalized_transmission, interval_min, interval_max):
    """Rule-based Gram-type baseline; no learning is involved.

    The sample is reduced to the mean of its normalised transmission, which
    is then compared with the interval ``[interval_min, interval_max]``.

    Args:
        normalized_transmission: Array-like of normalised transmission values.
        interval_min: Lower bound of the decision interval.
        interval_max: Upper bound of the decision interval.

    Returns:
        ``1`` (Gram-positive) when the mean is at or below ``interval_min``,
        or lies inside the interval at least as close to ``interval_min`` as
        to ``interval_max``; ``0`` (Gram-negative) otherwise.
    """
    sample_mean = np.mean(normalized_transmission)

    # Outside the interval the decision is immediate.
    if sample_mean <= interval_min:
        return 1  # Gram-positive
    if sample_mean >= interval_max:
        return 0  # Gram-negative

    # Inside the interval: pick the class of the nearer bound (ties -> 1).
    gap_to_upper = interval_max - sample_mean
    gap_to_lower = sample_mean - interval_min
    return 1 if gap_to_upper >= gap_to_lower else 0
# Predict the nearest cluster for bacteria family classification
def predict_bacteria_family_model1(normalized_transmission, cluster_centers):
    """Return the key of the cluster centre closest to the sample.

    The sample is described by the (mean, standard deviation) of its
    normalised transmission; closeness is the Euclidean distance to each
    centre in that two-dimensional space.

    Args:
        normalized_transmission: Array-like of normalised transmission values.
        cluster_centers: Mapping from bacteria key to an indexable centre
            whose element 0 is compared with the mean and element 1 with the
            standard deviation.

    Returns:
        The key of the nearest centre (the first one encountered on ties),
        or ``None`` when no centre yields a distance below infinity, e.g.
        when ``cluster_centers`` is empty.
    """
    sample_mean = np.mean(normalized_transmission)
    sample_std = np.std(normalized_transmission)

    closest_name = None
    closest_distance = float('inf')
    for name, centre in cluster_centers.items():
        separation = np.sqrt((sample_mean - centre[0]) ** 2 + (sample_std - centre[1]) ** 2)
        # Only a strictly smaller distance replaces the current best.
        if not separation < closest_distance:
            continue
        closest_name, closest_distance = name, separation
    return closest_name
def _drop_excluded_docs(docs, excluded_names, excluded_bacteria='pp6'):
    """Return the docs whose name is not excluded and whose strain is not *excluded_bacteria*."""
    return [
        doc for doc in docs
        if doc['name'] not in excluded_names and doc['bacteria'] != excluded_bacteria
    ]


if __name__ == '__main__':
    # Script entry point: load the pickled measurements, ask which model to
    # run (1, 2 or 3) and run the matching train/evaluate pipeline.
    print("Loading data from pickle file...")
    # NOTE(review): pickle.load executes arbitrary code embedded in the file;
    # only run this against a data.pkl from a trusted source.
    with open('data.pkl', 'rb') as f:
        docs = pickle.load(f)
    print("Data loaded successfully.")

    # Strip surrounding whitespace so that e.g. "1 " still selects model 1.
    modelnb = input("Type the number of the model you want to use : ").strip()

    # Names of problematic measurements to exclude
    problematic_names = {
        'ns139_p8_2', 'ys134_p7_1', 'pp46_p3_0', 'li142_p2_3',
        'li142_p2_2', 'li142_p3_4', 'li142_p3_5', 'se9_p2_0', 'se9_p7_1',
    }
    # Human-readable labels passed to the plotting helpers
    display_names = {
        'bs134': 'B. subtilis',
        'coli': 'E. coli',
        'li142': 'L. innocua',
        'ns139': 'N. sicca',
        'pp46': 'P. putida',
        'pp6': 'Pseudomonas putida',
        'se26': 'S. epidermidis B',
        'se9': 'S. epidermidis A',
        'ys134': 'Y. ruckeri',
    }

    # Model 1 not using any learning
    if modelnb == '1':
        normalized_docs = normalize_data_model1(docs)
        filtered_docs = _drop_excluded_docs(normalized_docs, problematic_names)
        train_docs, test_docs = split_data(filtered_docs)
        interval_min, interval_max = calculate_means_and_interval_model1(train_docs)
        print(f"Interval: ({interval_min}, {interval_max})")

        # Gram type classification
        true_labels = [doc['label'] for doc in test_docs]
        predicted_labels = [
            baseline_model_model1(doc['transmission_normalized'], interval_min, interval_max)
            for doc in test_docs
        ]
        accuracy = np.mean(np.array(true_labels) == np.array(predicted_labels))
        print(f'Baseline model accuracy (Gram type): {accuracy * 100:.2f}%')

        # Bacteria family classification
        cluster_centers = find_cluster_centers(train_docs)
        #plot_training_clusters(train_docs, cluster_centers, display_names)
        true_bacteria_labels = [doc['bacteria'] for doc in test_docs]
        predicted_bacteria_labels = [
            predict_bacteria_family_model1(doc['transmission_normalized'], cluster_centers)
            for doc in test_docs
        ]
        bacteria_accuracy = np.mean(np.array(true_bacteria_labels) == np.array(predicted_bacteria_labels))
        print(f'Baseline model accuracy (Bacteria type): {bacteria_accuracy * 100:.2f}%')

        plot_all_clusters(train_docs, test_docs, cluster_centers, display_names, true_bacteria_labels,
                          predicted_bacteria_labels)
        # Scatter plot for Gram type classification
        plot_gram_type_classification(train_docs, interval_min, interval_max)
        # Confusion matrix for bacteria family classification
        plot_confusion_matrix(true_bacteria_labels, predicted_bacteria_labels, display_names)
        # Illustration of how the cluster centers were calculated
        plot_cluster_center_calculation(train_docs, cluster_centers, display_names)

    # Model 2 using basic, simple learning methods
    elif modelnb == '2':
        normalized_docs = normalize_data_model2(docs)
        filtered_docs = _drop_excluded_docs(normalized_docs, problematic_names)
        chunked_docs = chunk_time_series(filtered_docs, num_chunks=6)
        # Augment the data
        augmented_docs = augment_data(chunked_docs, noise_level=0.05, shift_max=10)
        # Perform repeated random splitting and evaluation
        gram_accuracies, bacteria_accuracies = repeated_random_splitting_evaluation(augmented_docs, iterations=2)

        # Print the accuracy scores for each iteration
        print("Gram type classifier cross-validation accuracies for each iteration:")
        print(gram_accuracies)
        print(f"Mean accuracy: {np.mean(gram_accuracies):.2f} ± {np.std(gram_accuracies):.2f}")
        print("Bacteria type classifier cross-validation accuracies for each iteration:")
        print(bacteria_accuracies)
        print(f"Mean accuracy: {np.mean(bacteria_accuracies):.2f} ± {np.std(bacteria_accuracies):.2f}")

    # Model 3: XGBoost classifiers with hyperparameters tuned by Optuna
    elif modelnb == '3':
        normalized_docs = normalize_data_model2(docs)
        filtered_docs = _drop_excluded_docs(normalized_docs, problematic_names)
        chunked_docs = chunk_time_series(filtered_docs, num_chunks=6)
        augmented_docs = augment_data(chunked_docs, noise_level=0.05, shift_max=10)
        train_docs, test_docs = split_data(augmented_docs)

        train_features, train_gram_labels, train_bacteria_labels = extract_features_model2(train_docs)
        test_features, test_gram_labels, test_bacteria_labels = extract_features_model2(test_docs)

        # Encode the bacteria names as integers; the encoder is fitted on the
        # training labels only and reused for the test labels.
        label_encoder = LabelEncoder()
        encoded_train_bacteria_labels = label_encoder.fit_transform(train_bacteria_labels)
        encoded_test_bacteria_labels = label_encoder.transform(test_bacteria_labels)

        # Both classifiers share the same features and differ only in target.
        X_train, X_test, y_train, y_test = train_features, test_features, train_gram_labels, test_gram_labels
        X_train_bacteria, X_test_bacteria = train_features, test_features
        y_train_bacteria, y_test_bacteria = encoded_train_bacteria_labels, encoded_test_bacteria_labels

        objective = ObjectiveXGB(X_train, X_test, y_train, y_test, X_train_bacteria, X_test_bacteria,
                                 y_train_bacteria, y_test_bacteria)

        # --- Gram type classifier ---
        study_gram = optuna.create_study(direction='maximize')
        study_gram.optimize(objective.objective_xgb_gram, n_trials=50)
        print(f"Best trial (Gram type): {study_gram.best_trial.value}")
        print("Best hyperparameters (Gram type): ", study_gram.best_trial.params)
        best_params_gram = study_gram.best_trial.params

        xgb_gram = XGBClassifier(**best_params_gram, use_label_encoder=False, eval_metric='logloss')
        xgb_gram.fit(X_train, y_train)
        train_accuracy_gram = xgb_gram.score(X_train, y_train)
        print(f'Training accuracy for Gram type classifier: {train_accuracy_gram:.2f}')
        y_pred_gram = xgb_gram.predict(X_test)
        print("Gram Type Classification Report:")
        print(classification_report(y_test, y_pred_gram, zero_division=0))

        # --- Bacteria type classifier ---
        study_bacteria = optuna.create_study(direction='maximize')
        study_bacteria.optimize(objective.objective_xgb_bacteria, n_trials=50)
        print(f"Best trial (Bacteria type): {study_bacteria.best_trial.value}")
        print("Best hyperparameters (Bacteria type): ", study_bacteria.best_trial.params)
        best_params_bacteria = study_bacteria.best_trial.params

        xgb_bacteria = XGBClassifier(**best_params_bacteria, use_label_encoder=False, eval_metric='mlogloss',
                                     objective='multi:softmax', num_class=len(np.unique(y_train_bacteria)))
        xgb_bacteria.fit(X_train_bacteria, y_train_bacteria)
        train_accuracy_bacteria = xgb_bacteria.score(X_train_bacteria, y_train_bacteria)
        print(f'Training accuracy for Bacteria type classifier: {train_accuracy_bacteria:.2f}')
        y_pred_bacteria = xgb_bacteria.predict(X_test_bacteria)
        print("Bacteria Type Classification Report:")
        print(classification_report(y_test_bacteria, y_pred_bacteria, target_names=label_encoder.classes_,
                                    zero_division=0))

        # 5-fold cross-validation on the training split
        gram_cv_scores = cross_val_score(xgb_gram, X_train, y_train, cv=5, scoring='accuracy')
        print(
            f'Gram type classifier cross-validation accuracy: {gram_cv_scores.mean():.2f} ± {gram_cv_scores.std():.2f}')
        bacteria_cv_scores = cross_val_score(xgb_bacteria, X_train_bacteria, y_train_bacteria, cv=5,
                                             scoring='accuracy')
        print(
            f'Bacteria type classifier cross-validation accuracy: {bacteria_cv_scores.mean():.2f} ± {bacteria_cv_scores.std():.2f}')

        test_accuracy_gram = xgb_gram.score(X_test, y_test)
        test_accuracy_bacteria = xgb_bacteria.score(X_test_bacteria, y_test_bacteria)
        print(f'Test accuracy for Gram type classifier: {test_accuracy_gram:.2f}')
        print(f'Test accuracy for Bacteria type classifier: {test_accuracy_bacteria:.2f}')

        # NOTE(review): this is a very coarse heuristic. Any train/test gap,
        # however small, prints a message, and test > train is reported as
        # "underfitting". Consider a tolerance before relying on these hints.
        if train_accuracy_gram > test_accuracy_gram:
            print("Potential overfitting detected in Gram type classifier.")
        elif train_accuracy_gram < test_accuracy_gram:
            print("Potential underfitting detected in Gram type classifier.")
        if train_accuracy_bacteria > test_accuracy_bacteria:
            print("Potential overfitting detected in Bacteria type classifier.")
        elif train_accuracy_bacteria < test_accuracy_bacteria:
            print("Potential underfitting detected in Bacteria type classifier.")

    else:
        # Previously an unknown choice fell through silently and the script
        # ended with a success status without running anything.
        raise SystemExit(f"Unknown model number {modelnb!r}; expected '1', '2' or '3'.")

Event Timeline