# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import pickle
import shutil
import tempfile
from unittest import TestCase
from unittest.mock import patch

import numpy as np
from datasets import Dataset

from transformers import is_faiss_available
from transformers.models.bart.configuration_bart import BartConfig
from transformers.models.bart.tokenization_bart import BartTokenizer
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.models.dpr.configuration_dpr import DPRConfig
from transformers.models.dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.models.rag.configuration_rag import RagConfig
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.testing_utils import (
    require_datasets,
    require_faiss,
    require_sentencepiece,
    require_tokenizers,
    require_torch,
)


if is_faiss_available():
    import faiss
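

# Unit tests for RagRetriever covering all three index flavors: the canonical HF
# index, a custom HF index (built in memory or loaded from disk), and the legacy
# DPR index format. Every test uses a tiny two-passage dataset with
# 8-dimensional embeddings so retrieval results are fully predictable.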
@require_faiss
@require_datasets
class RagRetrieverTest(TestCase):
    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        self.retrieval_vector_size = 8

        # DPR tok
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
        os.makedirs(dpr_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

        # BART tok
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}
        bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
        os.makedirs(bart_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))
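
    # Helpers that reload the tiny DPR and BART tokenizers written to the temp
    # directory in setUp().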
    def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))

    def get_bart_tokenizer(self) -> BartTokenizer:
        return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)
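
    # A minimal two-passage dataset with constant embeddings: doc "0" is all
    # ones, doc "1" is all twos, indexed with a flat inner-product FAISS index.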
    def get_dummy_dataset(self):
        dataset = Dataset.from_dict(
            {
                "id": ["0", "1"],
                "text": ["foo", "bar"],
                "title": ["Foo", "Bar"],
                "embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
            }
        )
        dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
        return dataset
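
    # The canonical index normally fetches its dataset via load_dataset; patching
    # that call keeps the test offline and backs the retriever with the dummy
    # dataset instead.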
    def get_dummy_canonical_hf_index_retriever(self):
        dataset = self.get_dummy_dataset()
        config = RagConfig(
            retrieval_vector_size=self.retrieval_vector_size,
            question_encoder=DPRConfig().to_dict(),
            generator=BartConfig().to_dict(),
        )
        with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
            mock_load_dataset.return_value = dataset
            retriever = RagRetriever(
                config,
                question_encoder_tokenizer=self.get_dpr_tokenizer(),
                generator_tokenizer=self.get_bart_tokenizer(),
            )
        return retriever
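
    # A custom index can be constructed two ways: from files on disk (the config
    # carries passages_path and index_path) or from an in-memory CustomHFIndex
    # wrapping the dataset directly.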
    def get_dummy_custom_hf_index_retriever(self, from_disk: bool):
        dataset = self.get_dummy_dataset()
        config = RagConfig(
            retrieval_vector_size=self.retrieval_vector_size,
            question_encoder=DPRConfig().to_dict(),
            generator=BartConfig().to_dict(),
            index_name="custom",
        )
        if from_disk:
            config.passages_path = os.path.join(self.tmpdirname, "dataset")
            config.index_path = os.path.join(self.tmpdirname, "index.faiss")
            dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
            dataset.drop_index("embeddings")
            dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
            del dataset
            retriever = RagRetriever(
                config,
                question_encoder_tokenizer=self.get_dpr_tokenizer(),
                generator_tokenizer=self.get_bart_tokenizer(),
            )
        else:
            retriever = RagRetriever(
                config,
                question_encoder_tokenizer=self.get_dpr_tokenizer(),
                generator_tokenizer=self.get_bart_tokenizer(),
                index=CustomHFIndex(config.retrieval_vector_size, dataset),
            )
        return retriever
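
    # The legacy DPR index format is a serialized FAISS index plus a pickled
    # id -> [text, title] passage mapping. Note the embeddings are built with
    # retrieval_vector_size + 1 dimensions; the legacy format stores one extra
    # auxiliary dimension per vector.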
    def get_dummy_legacy_index_retriever(self):
        dataset = Dataset.from_dict(
            {
                "id": ["0", "1"],
                "text": ["foo", "bar"],
                "title": ["Foo", "Bar"],
                "embeddings": [
                    np.ones(self.retrieval_vector_size + 1),
                    2 * np.ones(self.retrieval_vector_size + 1),
                ],
            }
        )
        dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)

        index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
        dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
        pickle.dump(dataset["id"], open(index_file_name + ".index_meta.dpr", "wb"))

        passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
        passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset}
        pickle.dump(passages, open(passages_file_name, "wb"))

        config = RagConfig(
            retrieval_vector_size=self.retrieval_vector_size,
            question_encoder=DPRConfig().to_dict(),
            generator=BartConfig().to_dict(),
            index_name="legacy",
            index_path=self.tmpdirname,
        )
        retriever = RagRetriever(
            config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
        )
        return retriever
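
    # The query batch is [ones, -ones]. Against doc embeddings ones and 2 * ones,
    # the ones query scores highest with doc "1" (inner product 16 vs 8) and the
    # -ones query with doc "0" (-8 vs -16), which fixes the expected doc_ids.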
    def test_canonical_hf_index_retriever_retrieve(self):
        n_docs = 1
        retriever = self.get_dummy_canonical_hf_index_retriever()
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )
        retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
        self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
        self.assertEqual(len(doc_dicts), 2)
        self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
        self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
        self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
        self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
        self.assertListEqual(doc_ids.tolist(), [[1], [0]])
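
    # save_pretrained / from_pretrained round trip: the reloaded retriever must
    # still be a RagRetriever and able to serve a retrieve() call.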
    def test_canonical_hf_index_retriever_save_and_from_pretrained(self):
        retriever = self.get_dummy_canonical_hf_index_retriever()
        with tempfile.TemporaryDirectory() as tmp_dirname:
            with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
                mock_load_dataset.return_value = self.get_dummy_dataset()
                retriever.save_pretrained(tmp_dirname)
                retriever = RagRetriever.from_pretrained(tmp_dirname)
                self.assertIsInstance(retriever, RagRetriever)
                hidden_states = np.array(
                    [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
                )
                out = retriever.retrieve(hidden_states, n_docs=1)
                self.assertTrue(out is not None)

    def test_custom_hf_index_retriever_retrieve(self):
        n_docs = 1
        retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )
        retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
        self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
        self.assertEqual(len(doc_dicts), 2)
        self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
        self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
        self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
        self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
        self.assertListEqual(doc_ids.tolist(), [[1], [0]])

    def test_custom_hf_index_retriever_save_and_from_pretrained(self):
        retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
        with tempfile.TemporaryDirectory() as tmp_dirname:
            retriever.save_pretrained(tmp_dirname)
            retriever = RagRetriever.from_pretrained(tmp_dirname)
            self.assertIsInstance(retriever, RagRetriever)
            hidden_states = np.array(
                [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
            )
            out = retriever.retrieve(hidden_states, n_docs=1)
            self.assertTrue(out is not None)

    def test_custom_hf_index_retriever_retrieve_from_disk(self):
        n_docs = 1
        retriever = self.get_dummy_custom_hf_index_retriever(from_disk=True)
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )
        retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
        self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
        self.assertEqual(len(doc_dicts), 2)
        self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
        self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
        self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
        self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
        self.assertListEqual(doc_ids.tolist(), [[1], [0]])

    def test_custom_hf_index_retriever_save_and_from_pretrained_from_disk(self):
        retriever = self.get_dummy_custom_hf_index_retriever(from_disk=True)
        with tempfile.TemporaryDirectory() as tmp_dirname:
            retriever.save_pretrained(tmp_dirname)
            retriever = RagRetriever.from_pretrained(tmp_dirname)
            self.assertIsInstance(retriever, RagRetriever)
            hidden_states = np.array(
                [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
            )
            out = retriever.retrieve(hidden_states, n_docs=1)
            self.assertTrue(out is not None)
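
    # Legacy-index doc dicts expose only "text" and "title"; there are no "id"
    # or "embeddings" keys, so the assertions below key on "text" instead.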
    def test_legacy_index_retriever_retrieve(self):
        n_docs = 1
        retriever = self.get_dummy_legacy_index_retriever()
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )
        retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
        self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
        self.assertEqual(len(doc_dicts), 2)
        self.assertEqual(sorted(doc_dicts[0]), ["text", "title"])
        self.assertEqual(len(doc_dicts[0]["text"]), n_docs)
        self.assertEqual(doc_dicts[0]["text"][0], "bar")  # max inner product is reached with second doc
        self.assertEqual(doc_dicts[1]["text"][0], "foo")  # max inner product is reached with first doc
        self.assertListEqual(doc_ids.tolist(), [[1], [0]])

    def test_legacy_hf_index_retriever_save_and_from_pretrained(self):
        retriever = self.get_dummy_legacy_index_retriever()
        with tempfile.TemporaryDirectory() as tmp_dirname:
            retriever.save_pretrained(tmp_dirname)
            retriever = RagRetriever.from_pretrained(tmp_dirname)
            self.assertIsInstance(retriever, RagRetriever)
            hidden_states = np.array(
                [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
            )
            out = retriever.retrieve(hidden_states, n_docs=1)
            self.assertTrue(out is not None)
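
    # Calling the retriever (__call__) goes a step beyond retrieve(): it also
    # tokenizes the retrieved contexts. Context ids and attention masks come back
    # as Python lists by default (embeddings as an np.ndarray), and as torch
    # tensors when return_tensors="pt" is passed.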
    @require_torch
    @require_tokenizers
    @require_sentencepiece
    def test_hf_index_retriever_call(self):
        import torch

        n_docs = 1
        retriever = self.get_dummy_canonical_hf_index_retriever()
        question_input_ids = [[5, 7], [10, 11]]
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )
        out = retriever(question_input_ids, hidden_states, prefix=retriever.config.generator.prefix, n_docs=n_docs)
        context_input_ids, context_attention_mask, retrieved_doc_embeds = (
            out["context_input_ids"],
            out["context_attention_mask"],
            out["retrieved_doc_embeds"],
        )
        self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
        self.assertIsInstance(context_input_ids, list)
        self.assertIsInstance(context_attention_mask, list)
        self.assertIsInstance(retrieved_doc_embeds, np.ndarray)

        out = retriever(
            question_input_ids,
            hidden_states,
            prefix=retriever.config.generator.prefix,
            n_docs=n_docs,
            return_tensors="pt",
        )
        context_input_ids, context_attention_mask, retrieved_doc_embeds, doc_ids = (  # noqa: F841
            out["context_input_ids"],
            out["context_attention_mask"],
            out["retrieved_doc_embeds"],
            out["doc_ids"],
        )
        self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
        self.assertIsInstance(context_input_ids, torch.Tensor)
        self.assertIsInstance(context_attention_mask, torch.Tensor)
        self.assertIsInstance(retrieved_doc_embeds, torch.Tensor)