Page MenuHomec4science

model_training.py
No OneTemporary

File Metadata

Created
Thu, Jul 3, 01:57

model_training.py

import optuna
from xgboost import XGBClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle
from data_processing import split_data
from feature_extraction import calculate_means_and_interval_model1, extract_features_model2, find_cluster_centers
from sklearn.preprocessing import LabelEncoder
# Objective function for Optuna optimization for both gram and bacteria classification
#function to perform shuffling and cross validation
def repeated_random_splitting_evaluation(docs, iterations: int = 2):
    """Cross-validate the gram and bacteria classifiers over repeated random splits.

    Each iteration shuffles ``docs`` (seeded with the iteration index), splits
    them with ``split_data``, extracts features from the training portion with
    ``extract_features_model2`` and scores two random forests on that portion
    using stratified 5-fold cross-validation.

    The held-out portion returned by ``split_data`` is not scored by this
    function, so no features or encoded labels are computed for it. Encoding
    held-out labels with an encoder fitted only on the training labels raises
    ``ValueError`` whenever a class lands exclusively in the held-out split,
    which would abort the whole run for values that are never read.

    Args:
        docs: Documents accepted by ``sklearn.utils.shuffle``, ``split_data``
            and ``extract_features_model2``.
        iterations: Number of shuffle/split/cross-validate rounds.

    Returns:
        A ``(gram_accuracies, bacteria_accuracies)`` tuple of lists, each
        holding one mean cross-validation accuracy per iteration.
    """
    gram_accuracies = []
    bacteria_accuracies = []

    # The fold seed is fixed; iterations differ through the shuffle seed below.
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    # cross_val_score fits a clone of the estimator for every fold, so a single
    # unfitted template per task can be shared by all iterations.
    gram_classifier = RandomForestClassifier(
        n_estimators=100, random_state=42, max_depth=3, max_features='sqrt',
        min_samples_split=20, min_samples_leaf=10, max_samples=0.8
    )
    bacteria_classifier = RandomForestClassifier(
        n_estimators=300, random_state=42, max_depth=4, max_features='sqrt',
        min_samples_split=20, min_samples_leaf=10, max_samples=0.8
    )

    for i in range(iterations):
        print(f"Iteration: {i+1}/{iterations}")
        shuffled_docs = shuffle(docs, random_state=i)

        # Only the training portion is cross-validated here.
        train_docs, _ = split_data(shuffled_docs)
        train_features, train_gram_labels, train_bacteria_labels = extract_features_model2(train_docs)

        # Bacteria labels are mapped to integer class ids before scoring.
        encoded_train_bacteria_labels = LabelEncoder().fit_transform(train_bacteria_labels)

        gram_cv_scores = cross_val_score(
            gram_classifier, train_features, train_gram_labels,
            cv=skf, scoring='accuracy'
        )
        gram_accuracies.append(gram_cv_scores.mean())

        bacteria_cv_scores = cross_val_score(
            bacteria_classifier, train_features, encoded_train_bacteria_labels,
            cv=skf, scoring='accuracy'
        )
        bacteria_accuracies.append(bacteria_cv_scores.mean())

    return gram_accuracies, bacteria_accuracies

Event Timeline