Page Menu · Home · c4science

model_xmlr.py
No One · Temporary

File Metadata

Created
Fri, Nov 8, 23:16

model_xmlr.py

import os
import shlex
import subprocess

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
def train(model_path, epochs, name, end="", start=None, tok_loc="tok"):
    """Fine-tune a question-answering model by running ``run_qa.py`` in a subprocess.

    The command is built as an argument list and executed without a shell, so
    paths containing spaces, quotes or shell metacharacters are passed through
    verbatim instead of being re-interpreted by the shell.

    Args:
        model_path: Directory under which the output directory ``name`` is
            created, and against which non-hub ``start`` checkpoints are resolved.
        epochs: Number of training epochs, forwarded as ``--num_train_epochs``.
        name: Sub-directory of ``model_path`` used as ``--output_dir``.
        end: Suffix inserted before ``.json`` in the data file names
            (``train<end>.json`` for training, ``eval<end>.json`` for evaluation).
        start: Checkpoint to start from. ``"xlm-roberta-base"`` and
            ``"roberta-base"`` are passed through unchanged; any other value is
            treated as a sub-directory of ``model_path``. When ``None``,
            ``--model_name_or_path`` is not passed at all
            (NOTE(review): run_qa.py presumably requires it — TODO confirm).
        tok_loc: Unused; kept only for backward compatibility of the signature.

    Returns:
        The exit code of the ``run_qa.py`` process (0 on success).
    """
    print("train ...")
    args = [
        "python", "run_qa.py",
        "--train_file", "train" + end + ".json",
        "--do_train",
        "--validation_file", "eval" + end + ".json",
        "--do_eval",
        "--num_train_epochs", str(epochs),
        "--output_dir=" + model_path + "/" + name,
        "--fp16",
    ]
    if start is not None:
        # Hub model ids are used as-is; anything else is a local checkpoint
        # produced under model_path.
        if start not in ("xlm-roberta-base", "roberta-base"):
            start = model_path + "/" + start
        args.append("--model_name_or_path=" + start)
    # Print a copy-pasteable, correctly quoted version of the command.
    print(" ".join(shlex.quote(arg) for arg in args))
    # os.system's status used to be discarded; return the exit code so callers
    # can detect a failed training run.
    return subprocess.run(args).returncode
def get_model(model_path, name):
    """Load a question-answering model and its tokenizer from a local checkpoint.

    Args:
        model_path: Directory containing the checkpoint sub-directory.
        name: Sub-directory of ``model_path`` holding the checkpoint. This is
            the same ``model_path + "/" + name`` layout that ``train`` passes
            as ``--output_dir``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    checkpoint = model_path + "/" + name
    # ``from_pt`` was dropped: it is an option of the TF/Flax ``from_pretrained``
    # (convert a PyTorch checkpoint). AutoModelForQuestionAnswering is the
    # PyTorch class, which does not recognize it; unrecognized kwargs are
    # forwarded to config/model construction and can raise TypeError.
    model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model, tokenizer

Event Timeline