Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F120261672
model_training.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Thu, Jul 3, 01:57
Size
2 KB
Mime Type
text/x-python
Expires
Sat, Jul 5, 01:57 (2 d)
Engine
blob
Format
Raw Data
Handle
27163544
Attached To
R13271 Optical_Trapping_ML
model_training.py
View Options
import
optuna
from
xgboost
import
XGBClassifier
from
sklearn.model_selection
import
StratifiedKFold
,
cross_val_score
from
sklearn.ensemble
import
RandomForestClassifier
from
sklearn.utils
import
shuffle
from
data_processing
import
split_data
from
feature_extraction
import
calculate_means_and_interval_model1
,
extract_features_model2
,
find_cluster_centers
from
sklearn.preprocessing
import
LabelEncoder
# Objective function for Optuna optimization for both Gram and bacteria classification
#function to perform shuffling and cross validation
def repeated_random_splitting_evaluation(docs, iterations=2):
    """Repeatedly shuffle and split ``docs``, then cross-validate two forests.

    For each iteration the documents are shuffled with ``random_state=i``,
    split with ``split_data``, and features are extracted from the training
    portion with ``extract_features_model2``. Two ``RandomForestClassifier``
    models (one on the gram labels, one on the label-encoded bacteria labels)
    are scored with 5-fold stratified cross-validation on that training
    portion, and the mean fold accuracy of each is recorded.

    Args:
        docs: Collection of documents accepted by ``sklearn.utils.shuffle``
            and ``split_data``.  # NOTE(review): exact schema is defined by
            # the project helpers; confirm against data_processing.
        iterations: Number of shuffle/split/cross-validate rounds to run.

    Returns:
        A tuple ``(gram_accuracies, bacteria_accuracies)``; each is a list
        with one mean cross-validation accuracy per iteration.
    """
    gram_accuracies = []
    bacteria_accuracies = []

    # The same seeded splitter is reused for both classifiers and every round.
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    for i in range(iterations):
        print(f"Iteration: {i+1}/{iterations}")

        # Seeding with the iteration index gives a different but
        # reproducible ordering on every round.
        shuffled_docs = shuffle(docs, random_state=i)

        # Only the training portion is evaluated here. The held-out portion
        # was previously featurised and label-encoded but never used; that
        # encoding step could raise ValueError whenever the held-out portion
        # contained a class absent from the training portion, so it has been
        # removed together with the unused test-side variables.
        train_docs, _unused_test_docs = split_data(shuffled_docs)

        train_features, train_gram_labels, train_bacteria_labels = (
            extract_features_model2(train_docs)
        )

        # Bacteria labels are encoded to integers before cross-validation.
        encoded_train_bacteria_labels = LabelEncoder().fit_transform(
            train_bacteria_labels
        )

        gram_classifier = RandomForestClassifier(
            n_estimators=100,
            random_state=42,
            max_depth=3,
            max_features='sqrt',
            min_samples_split=20,
            min_samples_leaf=10,
            max_samples=0.8,
        )
        gram_cv_scores = cross_val_score(
            gram_classifier,
            train_features,
            train_gram_labels,
            cv=skf,
            scoring='accuracy',
        )
        gram_accuracies.append(gram_cv_scores.mean())

        bacteria_classifier = RandomForestClassifier(
            n_estimators=300,
            random_state=42,
            max_depth=4,
            max_features='sqrt',
            min_samples_split=20,
            min_samples_leaf=10,
            max_samples=0.8,
        )
        bacteria_cv_scores = cross_val_score(
            bacteria_classifier,
            train_features,
            encoded_train_bacteria_labels,
            cv=skf,
            scoring='accuracy',
        )
        bacteria_accuracies.append(bacteria_cv_scores.mean())

    return gram_accuracies, bacteria_accuracies
Event Timeline
Log In to Comment