eval_rag.py

""" Evaluation script for RAG models."""
import
argparse
import
ast
import
logging
import
os
import
sys
import
pandas
as
pd
import
torch
from
tqdm
import
tqdm
from
transformers
import
BartForConditionalGeneration
,
RagRetriever
,
RagSequenceForGeneration
,
RagTokenForGeneration
from
transformers
import
logging
as
transformers_logging
sys
.
path
.
append
(
os
.
path
.
join
(
os
.
getcwd
()))
# noqa: E402 # isort:skip
from
utils_rag
import
exact_match_score
,
f1_score
# noqa: E402 # isort:skip
logger
=
logging
.
getLogger
(
__name__
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
transformers_logging
.
set_verbosity_info
()
def infer_model_type(model_name_or_path):
    if "token" in model_name_or_path:
        return "rag_token"
    if "sequence" in model_name_or_path:
        return "rag_sequence"
    if "bart" in model_name_or_path:
        return "bart"
    return None


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    return max(metric_fn(prediction, gt) for gt in ground_truths)

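# End-to-end scoring: compare each predicted answer against its gold answers and
# report corpus-level Exact Match and F1 (taking the best score over all ground truths).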
def get_scores(args, preds_path, gold_data_path):
    hypos = [line.strip() for line in open(preds_path, "r").readlines()]
    answers = []

    if args.gold_data_mode == "qa":
        data = pd.read_csv(gold_data_path, sep="\t", header=None)
        for answer_list in data[1]:
            ground_truths = ast.literal_eval(answer_list)
            answers.append(ground_truths)
    else:
        references = [line.strip() for line in open(gold_data_path, "r").readlines()]
        answers = [[reference] for reference in references]

    f1 = em = total = 0

    for prediction, ground_truths in zip(hypos, answers):
        total += 1
        em += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
        f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)

    em = 100.0 * em / total
    f1 = 100.0 * f1 / total

    logger.info(f"F1: {f1:.2f}")
    logger.info(f"EM: {em:.2f}")

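# Retrieval scoring: predictions and gold files hold tab-separated document titles per
# question; report the average overlap of the top-k predicted titles with the gold
# titles (precision@k).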
def get_precision_at_k(args, preds_path, gold_data_path):
    k = args.k
    hypos = [line.strip() for line in open(preds_path, "r").readlines()]
    references = [line.strip() for line in open(gold_data_path, "r").readlines()]

    em = total = 0

    for hypo, reference in zip(hypos, references):
        hypo_provenance = set(hypo.split("\t")[:k])
        ref_provenance = set(reference.split("\t"))
        total += 1
        em += len(hypo_provenance & ref_provenance) / k

    em = 100.0 * em / total
    logger.info(f"Precision@{k}: {em: .2f}")

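# Retrieval mode: encode the questions, query the retriever directly, and return one
# tab-separated string of retrieved document titles per question.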
def evaluate_batch_retrieval(args, rag_model, questions):
    def strip_title(title):
        if title.startswith('"'):
            title = title[1:]
        if title.endswith('"'):
            title = title[:-1]
        return title

    retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
        questions,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )["input_ids"].to(args.device)

    question_enc_outputs = rag_model.rag.question_encoder(retriever_input_ids)
    question_enc_pool_output = question_enc_outputs[0]

    result = rag_model.retriever(
        retriever_input_ids,
        question_enc_pool_output.cpu().detach().to(torch.float32).numpy(),
        prefix=rag_model.rag.generator.config.prefix,
        n_docs=rag_model.config.n_docs,
        return_tensors="pt",
    )
    all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
    provenance_strings = []
    for docs in all_docs:
        provenance = [strip_title(title) for title in docs["title"]]
        provenance_strings.append("\t".join(provenance))
    return provenance_strings

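# End-to-end mode: batch-encode the questions, generate answers with beam search,
# and return the decoded answer strings.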
def evaluate_batch_e2e(args, rag_model, questions):
    with torch.no_grad():
        inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
            questions, return_tensors="pt", padding=True, truncation=True
        )

        input_ids = inputs_dict.input_ids.to(args.device)
        attention_mask = inputs_dict.attention_mask.to(args.device)
        outputs = rag_model.generate(  # rag_model overwrites generate
            input_ids,
            attention_mask=attention_mask,
            num_beams=args.num_beams,
            min_length=args.min_length,
            max_length=args.max_length,
            early_stopping=False,
            num_return_sequences=1,
            bad_words_ids=[[0, 0]],  # BART likes to repeat BOS tokens, don't allow it to generate more than one
        )
        answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)

        if args.print_predictions:
            for q, a in zip(questions, answers):
                logger.info("Q: {} - A: {}".format(q, a))

        return answers

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        choices=["rag_sequence", "rag_token", "bart"],
        type=str,
        help=(
            "RAG model type: rag_sequence, rag_token or bart; if none is specified, the type is inferred from the"
            " model_name_or_path"
        ),
    )
    parser.add_argument(
        "--index_name",
        default=None,
        choices=["exact", "compressed", "legacy"],
        type=str,
        help="RAG model retriever type",
    )
    parser.add_argument(
        "--index_path",
        default=None,
        type=str,
        help="Path to the retrieval index",
    )
    parser.add_argument("--n_docs", default=5, type=int, help="Number of retrieved docs")
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained checkpoints or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--eval_mode",
        choices=["e2e", "retrieval"],
        default="e2e",
        type=str,
        help=(
            "Evaluation mode: e2e calculates exact match and F1 of the downstream task, retrieval calculates"
            " precision@k."
        ),
    )
    parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
    parser.add_argument(
        "--evaluation_set",
        default=None,
        type=str,
        required=True,
        help="Path to a file containing evaluation samples",
    )
    parser.add_argument(
        "--gold_data_path",
        default=None,
        type=str,
        required=True,
        help="Path to a tab-separated file with gold samples",
    )
    parser.add_argument(
        "--gold_data_mode",
        default="qa",
        type=str,
        choices=["qa", "ans"],
        help=(
            "Format of the gold data file. "
            "qa - a single line in the following format: question [tab] answer_list. "
            "ans - a single line of the gold file contains the expected answer string"
        ),
    )
    parser.add_argument(
        "--predictions_path",
        type=str,
        default="predictions.txt",
        help="Name of the predictions file, to be stored in the checkpoints directory",
    )
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name and ending with the step number",
    )
    parser.add_argument(
        "--eval_batch_size",
        default=8,
        type=int,
        help="Batch size per GPU/CPU for evaluation.",
    )
    parser.add_argument(
        "--recalculate",
        help="Recalculate predictions even if the prediction file exists",
        action="store_true",
    )
    parser.add_argument(
        "--num_beams",
        default=4,
        type=int,
        help="Number of beams to be used when generating answers",
    )
    parser.add_argument("--min_length", default=1, type=int, help="Min length of the generated answers")
    parser.add_argument("--max_length", default=50, type=int, help="Max length of the generated answers")
    parser.add_argument(
        "--print_predictions",
        action="store_true",
        help="If True, prints predictions while evaluating.",
    )
    parser.add_argument(
        "--print_docs",
        action="store_true",
        help="If True, prints docs retrieved while generating.",
    )
    args = parser.parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return args

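# Driver: resolve the model class from --model_type, load each checkpoint (plus its
# retriever for RAG models), write batched predictions to --predictions_path, then score them.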
def main(args):
    model_kwargs = {}
    if args.model_type is None:
        args.model_type = infer_model_type(args.model_name_or_path)
        assert args.model_type is not None

    if args.model_type.startswith("rag"):
        model_class = RagTokenForGeneration if args.model_type == "rag_token" else RagSequenceForGeneration
        model_kwargs["n_docs"] = args.n_docs
        if args.index_name is not None:
            model_kwargs["index_name"] = args.index_name
        if args.index_path is not None:
            model_kwargs["index_path"] = args.index_path
    else:
        model_class = BartForConditionalGeneration

    checkpoints = (
        [f.path for f in os.scandir(args.model_name_or_path) if f.is_dir()]
        if args.eval_all_checkpoints
        else [args.model_name_or_path]
    )

    logger.info("Evaluate the following checkpoints: %s", checkpoints)

    score_fn = get_scores if args.eval_mode == "e2e" else get_precision_at_k
    evaluate_batch_fn = evaluate_batch_e2e if args.eval_mode == "e2e" else evaluate_batch_retrieval

    for checkpoint in checkpoints:
        if os.path.exists(args.predictions_path) and (not args.recalculate):
            logger.info("Calculating metrics based on an existing predictions file: {}".format(args.predictions_path))
            score_fn(args, args.predictions_path, args.gold_data_path)
            continue

        logger.info("***** Running evaluation for {} *****".format(checkpoint))
        logger.info("  Batch size = %d", args.eval_batch_size)
        logger.info("  Predictions will be stored under {}".format(args.predictions_path))

        if args.model_type.startswith("rag"):
            retriever = RagRetriever.from_pretrained(checkpoint, **model_kwargs)
            model = model_class.from_pretrained(checkpoint, retriever=retriever, **model_kwargs)
            model.retriever.init_retrieval()
        else:
            model = model_class.from_pretrained(checkpoint, **model_kwargs)
        model.to(args.device)

        with open(args.evaluation_set, "r") as eval_file, open(args.predictions_path, "w") as preds_file:
            questions = []
            for line in tqdm(eval_file):
                questions.append(line.strip())
                if len(questions) == args.eval_batch_size:
                    answers = evaluate_batch_fn(args, model, questions)
                    preds_file.write("\n".join(answers) + "\n")
                    preds_file.flush()
                    questions = []
            if len(questions) > 0:
                answers = evaluate_batch_fn(args, model, questions)
                preds_file.write("\n".join(answers))
                preds_file.flush()

            score_fn(args, args.predictions_path, args.gold_data_path)


if __name__ == "__main__":
    args = get_args()
    main(args)
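
# Example invocation (a minimal sketch, not part of the original script; the data and
# output paths below are placeholders, and the checkpoint name assumes the public
# facebook/rag-sequence-nq weights):
#
#   python eval_rag.py \
#       --model_name_or_path facebook/rag-sequence-nq \
#       --model_type rag_sequence \
#       --eval_mode e2e \
#       --gold_data_mode qa \
#       --evaluation_set path/to/test.source \
#       --gold_data_path path/to/gold_data \
#       --predictions_path path/to/e2e_preds.txt \
#       --n_docs 5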