Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F97187545
main.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Fri, Jan 3, 06:57
Size
10 KB
Mime Type
text/x-python
Expires
Sun, Jan 5, 06:57 (1 d, 23 h)
Engine
blob
Format
Raw Data
Handle
23349633
Attached To
R13271 Optical_Trapping_ML
main.py
View Options
import
warnings
from
sklearn.model_selection
import
GridSearchCV
from
sklearn.preprocessing
import
LabelEncoder
from
sklearn.metrics
import
classification_report
from
Z_HELPERS.DM_communication
import
*
from
data_processing
import
*
from
feature_extraction
import
*
from
model_training
import
*
from
visualization
import
*
from
optuna_objectives
import
*
import
pickle
# Suppress every FutureWarning for the remainder of the process, whichever module emits it.
warnings.simplefilter("ignore", category=FutureWarning)
def baseline_model_model1(normalized_transmission, interval_min, interval_max):
    """Classify Gram type from the mean normalized transmission, without learning.

    This is the non-learning baseline: only the mean of the signal is compared
    against a precomputed interval.

    Args:
        normalized_transmission: sequence / array of normalized transmission values.
        interval_min: lower bound of the decision interval.
        interval_max: upper bound of the decision interval.

    Returns:
        1 (Gram-positive) when the mean is at or below ``interval_min``,
        0 (Gram-negative) when it is at or above ``interval_max``; inside the
        interval, the label of the nearer bound, with ties going to 1.
    """
    avg = np.mean(normalized_transmission)

    if avg <= interval_min:
        return 1  # Gram-positive side
    if avg >= interval_max:
        return 0  # Gram-negative side

    # Strictly inside the interval: pick whichever bound is closer.
    dist_to_upper = interval_max - avg
    dist_to_lower = avg - interval_min
    return 1 if dist_to_upper >= dist_to_lower else 0
# Predict the nearest cluster for bacteria family classification
def predict_bacteria_family_model1(normalized_transmission, cluster_centers):
    """Return the key of the cluster centre nearest to the sample's (mean, std).

    The sample is summarised by the mean and (population) standard deviation
    of ``normalized_transmission``. Each centre is indexed as
    ``center[0]`` = mean and ``center[1]`` = std, and the Euclidean distance
    in that 2-D space decides the winner.

    Args:
        normalized_transmission: sequence / array of normalized transmission values.
        cluster_centers: mapping of label -> indexable centre ``(mean, std)``.

    Returns:
        The label of the closest centre. The first one seen wins on ties.
        ``None`` if ``cluster_centers`` is empty.
    """
    sample_mean = np.mean(normalized_transmission)
    sample_std = np.std(normalized_transmission)

    nearest_label = None
    nearest_distance = float('inf')
    for label, centre in cluster_centers.items():
        d_mean = sample_mean - centre[0]
        d_std = sample_std - centre[1]
        dist = np.sqrt(d_mean ** 2 + d_std ** 2)
        # Strict '<' keeps the earliest centre when distances tie.
        if dist < nearest_distance:
            nearest_distance = dist
            nearest_label = label
    return nearest_label
def _filter_docs(normalized_docs, problematic_names, excluded_bacteria='pp6'):
    """Return the recordings every model trains and evaluates on.

    Drops documents whose ``'name'`` appears in *problematic_names* and all
    documents whose ``'bacteria'`` equals *excluded_bacteria*. The default
    ``'pp6'`` is the strain every model in this script excludes.
    """
    excluded_names = set(problematic_names)  # O(1) membership test per document
    return [
        doc
        for doc in normalized_docs
        if doc['name'] not in excluded_names and doc['bacteria'] != excluded_bacteria
    ]


def _report_fit(classifier_name, train_accuracy, test_accuracy, tolerance=0.05):
    """Print a warning when train and test accuracy differ by more than *tolerance*.

    Training accuracy well above test accuracy suggests overfitting. Test
    accuracy above training accuracy is *not* underfitting (that would be a
    low training accuracy). It more often points to leakage between the splits
    or to a very small test set, so it is reported as such.
    """
    gap = train_accuracy - test_accuracy
    if gap > tolerance:
        print(f"Potential overfitting detected in {classifier_name} classifier.")
    elif -gap > tolerance:
        print(f"Test accuracy exceeds training accuracy for {classifier_name} classifier "
              "(possible train/test leakage or too small a test set).")


def _run_model1(docs, problematic_names, display_names):
    """Model 1: interval / nearest-centre baseline that uses no learning."""
    normalized_docs = normalize_data_model1(docs)
    filtered_docs = _filter_docs(normalized_docs, problematic_names)
    train_docs, test_docs = split_data(filtered_docs)

    interval_min, interval_max = calculate_means_and_interval_model1(train_docs)
    print(f"Interval: ({interval_min}, {interval_max})")

    # Gram type classification
    true_labels = [doc['label'] for doc in test_docs]
    predicted_labels = [
        baseline_model_model1(doc['transmission_normalized'], interval_min, interval_max)
        for doc in test_docs
    ]
    accuracy = np.mean(np.array(true_labels) == np.array(predicted_labels))
    print(f'Baseline model accuracy (Gram type): {accuracy * 100:.2f}%')

    # Bacteria family classification
    cluster_centers = find_cluster_centers(train_docs)
    true_bacteria_labels = [doc['bacteria'] for doc in test_docs]
    predicted_bacteria_labels = [
        predict_bacteria_family_model1(doc['transmission_normalized'], cluster_centers)
        for doc in test_docs
    ]
    bacteria_accuracy = np.mean(np.array(true_bacteria_labels) == np.array(predicted_bacteria_labels))
    print(f'Baseline model accuracy (Bacteria type): {bacteria_accuracy * 100:.2f}%')

    plot_all_clusters(train_docs, test_docs, cluster_centers, display_names,
                      true_bacteria_labels, predicted_bacteria_labels)
    plot_gram_type_classification(train_docs, interval_min, interval_max)
    plot_confusion_matrix(true_bacteria_labels, predicted_bacteria_labels, display_names)
    plot_cluster_center_calculation(train_docs, cluster_centers, display_names)


def _run_model2(docs, problematic_names):
    """Model 2: basic learning methods evaluated over repeated random splits."""
    normalized_docs = normalize_data_model2(docs)
    filtered_docs = _filter_docs(normalized_docs, problematic_names)
    chunked_docs = chunk_time_series(filtered_docs, num_chunks=6)
    # NOTE(review): chunking and augmentation happen before
    # repeated_random_splitting_evaluation splits the data. Unless that helper
    # groups by source recording, chunks / noisy copies of one recording can
    # land in both train and test (the same leakage fixed for model 3) — confirm.
    augmented_docs = augment_data(chunked_docs, noise_level=0.05, shift_max=10)

    gram_accuracies, bacteria_accuracies = repeated_random_splitting_evaluation(
        augmented_docs, iterations=2)

    print("Gram type classifier cross-validation accuracies for each iteration:")
    print(gram_accuracies)
    print(f"Mean accuracy: {np.mean(gram_accuracies):.2f} ± {np.std(gram_accuracies):.2f}")
    print("Bacteria type classifier cross-validation accuracies for each iteration:")
    print(bacteria_accuracies)
    print(f"Mean accuracy: {np.mean(bacteria_accuracies):.2f} ± {np.std(bacteria_accuracies):.2f}")


def _run_model3(docs, problematic_names):
    """Model 3: XGBoost classifiers whose hyperparameters are tuned with Optuna."""
    import optuna
    from sklearn.model_selection import cross_val_score
    from xgboost import XGBClassifier

    normalized_docs = normalize_data_model2(docs)
    filtered_docs = _filter_docs(normalized_docs, problematic_names)

    # Split whole recordings *before* chunking / augmenting. Doing it the other
    # way round puts chunks and noisy copies of the same recording into both
    # splits, which leaks training data into the test set.
    train_recordings, test_recordings = split_data(filtered_docs)
    train_docs = augment_data(chunk_time_series(train_recordings, num_chunks=6),
                              noise_level=0.05, shift_max=10)
    test_docs = augment_data(chunk_time_series(test_recordings, num_chunks=6),
                             noise_level=0.05, shift_max=10)

    train_features, train_gram_labels, train_bacteria_labels = extract_features_model2(train_docs)
    test_features, test_gram_labels, test_bacteria_labels = extract_features_model2(test_docs)

    label_encoder = LabelEncoder()
    encoded_train_bacteria_labels = label_encoder.fit_transform(train_bacteria_labels)
    encoded_test_bacteria_labels = label_encoder.transform(test_bacteria_labels)

    X_train, X_test = train_features, test_features
    y_train, y_test = train_gram_labels, test_gram_labels
    X_train_bacteria, X_test_bacteria = train_features, test_features
    y_train_bacteria, y_test_bacteria = encoded_train_bacteria_labels, encoded_test_bacteria_labels

    # NOTE(review): ObjectiveXGB receives the test split, so if it scores trials
    # on it the test accuracies below are optimistically biased. A validation
    # split carved from the training data would keep the test set held out.
    objective = ObjectiveXGB(X_train, X_test, y_train, y_test,
                             X_train_bacteria, X_test_bacteria, y_train_bacteria, y_test_bacteria)

    # --- Gram type classifier ---
    study_gram = optuna.create_study(direction='maximize')
    study_gram.optimize(objective.objective_xgb_gram, n_trials=50)
    print(f"Best trial (Gram type): {study_gram.best_trial.value}")
    print("Best hyperparameters (Gram type): ", study_gram.best_trial.params)
    best_params_gram = study_gram.best_trial.params

    xgb_gram = XGBClassifier(**best_params_gram, use_label_encoder=False, eval_metric='logloss')
    xgb_gram.fit(X_train, y_train)
    train_accuracy_gram = xgb_gram.score(X_train, y_train)
    print(f'Training accuracy for Gram type classifier: {train_accuracy_gram:.2f}')
    y_pred_gram = xgb_gram.predict(X_test)
    print("Gram Type Classification Report:")
    print(classification_report(y_test, y_pred_gram, zero_division=0))

    # --- Bacteria type classifier ---
    study_bacteria = optuna.create_study(direction='maximize')
    study_bacteria.optimize(objective.objective_xgb_bacteria, n_trials=50)
    print(f"Best trial (Bacteria type): {study_bacteria.best_trial.value}")
    print("Best hyperparameters (Bacteria type): ", study_bacteria.best_trial.params)
    best_params_bacteria = study_bacteria.best_trial.params

    xgb_bacteria = XGBClassifier(**best_params_bacteria, use_label_encoder=False,
                                 eval_metric='mlogloss', objective='multi:softmax',
                                 num_class=len(np.unique(y_train_bacteria)))
    xgb_bacteria.fit(X_train_bacteria, y_train_bacteria)
    train_accuracy_bacteria = xgb_bacteria.score(X_train_bacteria, y_train_bacteria)
    print(f'Training accuracy for Bacteria type classifier: {train_accuracy_bacteria:.2f}')
    y_pred_bacteria = xgb_bacteria.predict(X_test_bacteria)
    print("Bacteria Type Classification Report:")
    print(classification_report(
        y_test_bacteria, y_pred_bacteria,
        # Pin the full label set. Otherwise the report raises when a class is
        # missing from both the test labels and the predictions, because
        # target_names would then be longer than the inferred label list.
        labels=np.arange(len(label_encoder.classes_)),
        target_names=label_encoder.classes_,
        zero_division=0,
    ))

    # --- Cross-validation on the training split ---
    gram_cv_scores = cross_val_score(xgb_gram, X_train, y_train, cv=5, scoring='accuracy')
    print(f'Gram type classifier cross-validation accuracy: {gram_cv_scores.mean():.2f} ± {gram_cv_scores.std():.2f}')
    bacteria_cv_scores = cross_val_score(xgb_bacteria, X_train_bacteria, y_train_bacteria,
                                         cv=5, scoring='accuracy')
    print(f'Bacteria type classifier cross-validation accuracy: {bacteria_cv_scores.mean():.2f} ± {bacteria_cv_scores.std():.2f}')

    # --- Held-out test accuracy ---
    test_accuracy_gram = xgb_gram.score(X_test, y_test)
    test_accuracy_bacteria = xgb_bacteria.score(X_test_bacteria, y_test_bacteria)
    print(f'Test accuracy for Gram type classifier: {test_accuracy_gram:.2f}')
    print(f'Test accuracy for Bacteria type classifier: {test_accuracy_bacteria:.2f}')

    _report_fit("Gram type", train_accuracy_gram, test_accuracy_gram)
    _report_fit("Bacteria type", train_accuracy_bacteria, test_accuracy_bacteria)


if __name__ == '__main__':
    print("Loading data from pickle file...")
    # pickle.load can execute arbitrary code: only point this at a trusted file.
    with open('data.pkl', 'rb') as f:
        docs = pickle.load(f)
    print("Data loaded successfully.")

    # strip() so that e.g. "1 " (stray whitespace) still selects model 1.
    modelnb = input("Type the number of the model you want to use : ").strip()

    # Recording names excluded from every model.
    problematic_names = ['ns139_p8_2', 'ys134_p7_1', 'pp46_p3_0', 'li142_p2_3', 'li142_p2_2',
                         'li142_p3_4', 'li142_p3_5', 'se9_p2_0', 'se9_p7_1']

    # Short strain code -> human-readable name used in plots.
    display_names = {
        'bs134': 'B. subtilis',
        'coli': 'E. coli',
        'li142': 'L. innocua',
        'ns139': 'N. sicca',
        'pp46': 'P. putida',
        'pp6': 'Pseudomonas putida',
        'se26': 'S. epidermidis B',
        'se9': 'S. epidermidis A',
        'ys134': 'Y. ruckeri',
    }

    if modelnb == '1':
        # Model 1: no learning
        _run_model1(docs, problematic_names, display_names)
    elif modelnb == '2':
        # Model 2: basic, simple learning methods
        _run_model2(docs, problematic_names)
    elif modelnb == '3':
        # Model 3: XGBoost tuned with Optuna
        _run_model3(docs, problematic_names)
    else:
        print(f"Unknown model {modelnb!r}; expected 1, 2 or 3.")
Event Timeline
Log In to Comment