utils_rag.py
import itertools
import json
import linecache
import os
import pickle
import re
import socket
import string
from collections import Counter
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List

import git
import torch
from torch.utils.data import Dataset

from transformers import BartTokenizer, RagTokenizer, T5Tokenizer

def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"):
    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {}
    tokenizer.padding_side = padding_side
    return tokenizer(
        [line],
        max_length=max_length,
        padding="max_length" if pad_to_max_length else None,
        truncation=True,
        return_tensors=return_tensors,
        add_special_tokens=True,
        **extra_kw,
    )
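
# Illustrative usage (added sketch, not part of the original module): encode_line
# tokenizes a single string as a batch of one, padded/truncated to max_length.
# The checkpoint name below is only an assumption for the example.
#
#   from transformers import BartTokenizer
#   tok = BartTokenizer.from_pretrained("facebook/bart-base")
#   enc = encode_line(tok, "What is a RAG model?", max_length=32, padding_side="right")
#   enc["input_ids"].shape       # torch.Size([1, 32]), padded to max_length
#   enc["attention_mask"].shape  # torch.Size([1, 32])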

def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
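
# Illustrative example (added): trim_batch drops every column that contains only the
# pad token, shortening batches whose longest sequence is below max_length.
#
#   ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])  # assume pad_token_id == 0
#   trim_batch(ids, pad_token_id=0)
#   # tensor([[5, 6, 7],
#   #         [8, 9, 0]])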

class Seq2SeqDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        src_lang=None,
        tgt_lang=None,
        prefix="",
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.src_lens = self.get_char_lens(self.src_file)
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix
        if n_obs is not None:
            self.src_lens = self.src_lens[:n_obs]
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

    def __len__(self):
        return len(self.src_lens)

    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"

        # Need to add eos token manually for T5
        if isinstance(self.tokenizer, T5Tokenizer):
            source_line += self.tokenizer.eos_token
            tgt_line += self.tokenizer.eos_token

        # Pad source and target to the right
        source_tokenizer = (
            self.tokenizer.question_encoder if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
        )
        target_tokenizer = self.tokenizer.generator if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer

        source_inputs = encode_line(source_tokenizer, source_line, self.max_source_length, "right")
        target_inputs = encode_line(target_tokenizer, tgt_line, self.max_target_length, "right")

        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "decoder_input_ids": target_ids,
        }

    @staticmethod
    def get_char_lens(data_file):
        return [len(x) for x in Path(data_file).open().readlines()]

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
        tgt_pad_token_id = (
            self.tokenizer.generator.pad_token_id
            if isinstance(self.tokenizer, RagTokenizer)
            else self.tokenizer.pad_token_id
        )
        src_pad_token_id = (
            self.tokenizer.question_encoder.pad_token_id
            if isinstance(self.tokenizer, RagTokenizer)
            else self.tokenizer.pad_token_id
        )
        y = trim_batch(target_ids, tgt_pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, src_pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "decoder_input_ids": y,
        }
        return batch
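
# Usage sketch (added, hypothetical setup, not part of the original module). It assumes
# data_dir contains train.source / train.target files with one example per line, and uses
# an arbitrary BART checkpoint.
#
#   from torch.utils.data import DataLoader
#   from transformers import BartTokenizer
#
#   tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
#   train_ds = Seq2SeqDataset(tokenizer, "data_dir", max_source_length=128, max_target_length=32)
#   loader = DataLoader(train_ds, batch_size=8, collate_fn=train_ds.collate_fn)
#   batch = next(iter(loader))  # keys: input_ids, attention_mask, decoder_input_ids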

logger = getLogger(__name__)


def flatten_list(summary_ids: List[List]):
    return [x for x in itertools.chain.from_iterable(summary_ids)]


def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json"""
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))


def save_json(content, path, indent=4, **json_dump_kwargs):
    with open(path, "w") as f:
        json.dump(content, f, indent=indent, **json_dump_kwargs)


def load_json(path):
    with open(path) as f:
        return json.load(f)


def get_git_info():
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": str(repo.active_branch),
        "hostname": str(socket.gethostname()),
    }
    return repo_infos


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
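
# Worked example (added): token-level F1 over normalize_answer'd strings, the usual
# open-domain QA metric.
#
#   f1_score("The cat sat", "a cat sat down")
#   # prediction tokens ["cat", "sat"], reference tokens ["cat", "sat", "down"]
#   # overlap = 2 -> precision = 2/2, recall = 2/3, F1 = 0.8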

def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def calculate_exact_match(output_lns: List[str], reference_lns: List[str]) -> Dict:
    assert len(output_lns) == len(reference_lns)
    em = 0
    for hypo, pred in zip(output_lns, reference_lns):
        em += exact_match_score(hypo, pred)
    if len(output_lns) > 0:
        em /= len(output_lns)
    return {"em": em}
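
# Worked example (added): exact match is computed after normalization, so casing,
# punctuation and articles are ignored.
#
#   calculate_exact_match(["Paris!", "in 1923"], ["the paris", "1923"])
#   # "Paris!" vs "the paris" -> both normalize to "paris"   (match)
#   # "in 1923" vs "1923"     -> "in 1923" != "1923"         (no match)
#   # => {"em": 0.5}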

def is_rag_model(model_prefix):
    return model_prefix.startswith("rag")


def set_extra_model_params(extra_params, hparams, config):
    equivalent_param = {p: p for p in extra_params}
    # T5 models don't have `dropout` param, they have `dropout_rate` instead
    equivalent_param["dropout"] = "dropout_rate"
    for p in extra_params:
        if getattr(hparams, p, None):
            if not hasattr(config, p) and not hasattr(config, equivalent_param[p]):
                logger.info("config doesn't have a `{}` attribute".format(p))
                delattr(hparams, p)
                continue
            set_p = p if hasattr(config, p) else equivalent_param[p]
            setattr(config, set_p, getattr(hparams, p))
            delattr(hparams, p)
    return hparams, config
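
# Usage sketch (added, hedged illustration): set_extra_model_params copies hparams values
# onto the config under whichever attribute name the config actually exposes (e.g.
# dropout -> dropout_rate for T5-style configs) and removes them from hparams.
# SimpleNamespace stands in for the argparse/Lightning hparams object; the assumption
# here is that T5Config exposes `dropout_rate` but no `dropout` attribute.
#
#   from types import SimpleNamespace
#   from transformers import T5Config
#
#   hparams = SimpleNamespace(dropout=0.1)
#   hparams, config = set_extra_model_params(["dropout"], hparams, T5Config())
#   config.dropout_rate          # 0.1
#   hasattr(hparams, "dropout")  # False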