lm_seqs_dataset.py

# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Dataset to distilled models
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
import numpy as np
import torch
from torch.utils.data import Dataset

from utils import logger

class LmSeqsDataset(Dataset):
    """Custom Dataset wrapping language modeling sequences.

    Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths.

    Input:
    ------
        params: `NameSpace` parameters
        data: `List[np.array[int]]`
    """

    def __init__(self, params, data):
        self.params = params

        self.token_ids = np.array(data)
        self.lengths = np.array([len(t) for t in data])

        self.check()
        self.remove_long_sequences()
        self.remove_empty_sequences()
        self.remove_unknown_sequences()
        self.check()

        self.print_statistics()

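    # Minimal construction sketch (hedged: `tokenizer` and the `args` namespace
    # below are hypothetical names, not part of this file). `params` is expected
    # to expose `max_model_input_size`, `mlm`, `special_tok_ids` and `is_master`:
    #
    #     data = [np.array(tokenizer.encode(line)) for line in corpus]
    #     train_dataset = LmSeqsDataset(params=args, data=data)
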
    def __getitem__(self, index):
        return (self.token_ids[index], self.lengths[index])

    def __len__(self):
        return len(self.lengths)

    def check(self):
        """
        Some sanity checks
        """
        assert len(self.token_ids) == len(self.lengths)
        assert all(self.lengths[i] == len(self.token_ids[i]) for i in range(len(self.lengths)))

    def remove_long_sequences(self):
        """
        Sequences that are too long are split by chunk of max_model_input_size.
        """
        max_len = self.params.max_model_input_size
        indices = self.lengths > max_len
        logger.info(f"Splitting {sum(indices)} too long sequences.")

        def divide_chunks(l, n):
            return [l[i : i + n] for i in range(0, len(l), n)]

        new_tok_ids = []
        new_lengths = []
        if self.params.mlm:
            cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"]
        else:
            cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"]

        for seq_, len_ in zip(self.token_ids, self.lengths):
            assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
            if len_ <= max_len:
                new_tok_ids.append(seq_)
                new_lengths.append(len_)
            else:
                sub_seqs = []
                for sub_s in divide_chunks(seq_, max_len - 2):
                    if sub_s[0] != cls_id:
                        sub_s = np.insert(sub_s, 0, cls_id)
                    if sub_s[-1] != sep_id:
                        sub_s = np.insert(sub_s, len(sub_s), sep_id)
                    assert len(sub_s) <= max_len
                    assert (sub_s[0] == cls_id) and (sub_s[-1] == sep_id), sub_s
                    sub_seqs.append(sub_s)

                new_tok_ids.extend(sub_seqs)
                new_lengths.extend([len(l) for l in sub_seqs])

        self.token_ids = np.array(new_tok_ids)
        self.lengths = np.array(new_lengths)

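    # Worked example (exact ids depend on the tokenizer): with
    # max_model_input_size = 512, a 1000-token sequence is cut into chunks of
    # at most 510 tokens; the cls/bos id is prepended and the sep/eos id
    # appended to any chunk missing them, so every resulting sub-sequence
    # still looks like [cls_id, ..., sep_id] and fits within 512 tokens.
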
    def remove_empty_sequences(self):
        """
        Too short sequences are simply removed. This could be tuned.
        """
        init_size = len(self)
        indices = self.lengths > 11
        self.token_ids = self.token_ids[indices]
        self.lengths = self.lengths[indices]
        new_size = len(self)
        logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.")

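    # E.g. with lengths [5, 20, 300], `self.lengths > 11` yields the boolean
    # mask [False, True, True]: the 5-token sequence is dropped and the other
    # two are kept.
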
    def remove_unknown_sequences(self):
        """
        Remove sequences with a (too) high level of unknown tokens.
        """
        if "unk_token" not in self.params.special_tok_ids:
            return
        else:
            unk_token_id = self.params.special_tok_ids["unk_token"]
        init_size = len(self)
        unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids])
        indices = (unk_occs / self.lengths) < 0.5
        self.token_ids = self.token_ids[indices]
        self.lengths = self.lengths[indices]
        new_size = len(self)
        logger.info(f"Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).")

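    # Worked example: a 20-token sequence containing 12 unknown tokens has an
    # unk ratio of 12 / 20 = 0.6 >= 0.5 and is removed; with 9 unknown tokens
    # the ratio is 0.45 < 0.5 and the sequence is kept.
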
    def print_statistics(self):
        """
        Print some statistics on the corpus. Only the master process.
        """
        if not self.params.is_master:
            return
        logger.info(f"{len(self)} sequences")
        # data_len = sum(self.lengths)
        # nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
        # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')

        # unk_idx = self.params.special_tok_ids['unk_token']
        # nb_unknown = sum([(t==unk_idx).sum() for t in self.token_ids])
        # logger.info(f'{nb_unknown} unknown tokens (covering {100*nb_unknown/data_len:.2f}% of the data)')

    def batch_sequences(self, batch):
        """
        Do the padding and transform into torch.tensor.
        """
        token_ids = [t[0] for t in batch]
        lengths = [t[1] for t in batch]
        assert len(token_ids) == len(lengths)

        # Max for paddings
        max_seq_len_ = max(lengths)

        # Pad token ids
        if self.params.mlm:
            pad_idx = self.params.special_tok_ids["pad_token"]
        else:
            pad_idx = self.params.special_tok_ids["unk_token"]
        tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
        assert len(tk_) == len(token_ids)
        assert all(len(t) == max_seq_len_ for t in tk_)

        tk_t = torch.tensor(tk_)  # (bs, max_seq_len_)
        lg_t = torch.tensor(lengths)  # (bs)
        return tk_t, lg_t
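
# Usage sketch (an assumption about the surrounding training script, not part
# of this file): `batch_sequences` has the shape of a DataLoader `collate_fn`,
# so batching would look like
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(train_dataset, batch_size=32, collate_fn=train_dataset.batch_sequences)
#     token_ids, lengths = next(iter(loader))  # token_ids: (bs, max_seq_len_), lengths: (bs,)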