# modeling_bertabs.py
# MIT License
# Copyright (c) 2019 Yang Liu and the HuggingFace team
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import copy
import math

import numpy as np
import torch
from torch import nn
from torch.nn.init import xavier_uniform_

from configuration_bertabs import BertAbsConfig
from transformers import BertConfig, BertModel, PreTrainedModel


MAX_SIZE = 5000

BERTABS_FINETUNED_MODEL_ARCHIVE_LIST = [
    "remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization",
]


class BertAbsPreTrainedModel(PreTrainedModel):
    config_class = BertAbsConfig
    load_tf_weights = False
    base_model_prefix = "bert"


class BertAbs(BertAbsPreTrainedModel):
    def __init__(self, args, checkpoint=None, bert_extractive_checkpoint=None):
        super().__init__(args)
        self.args = args
        self.bert = Bert()

        # If pre-trained weights are passed for Bert, load these.
        load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
        if load_bert_pretrained_extractive:
            self.bert.model.load_state_dict(
                dict([(n[11:], p) for n, p in bert_extractive_checkpoint.items() if n.startswith("bert.model")]),
                strict=True,
            )

        self.vocab_size = self.bert.model.config.vocab_size

        if args.max_pos > 512:
            my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
                None, :
            ].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
        tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(
            self.args.dec_layers,
            self.args.dec_hidden_size,
            heads=self.args.dec_heads,
            d_ff=self.args.dec_ff_size,
            dropout=self.args.dec_dropout,
            embeddings=tgt_embeddings,
            vocab_size=self.vocab_size,
        )

        gen_func = nn.LogSoftmax(dim=-1)
        self.generator = nn.Sequential(nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func)
        self.generator[0].weight = self.decoder.embeddings.weight

        load_from_checkpoints = False if checkpoint is None else True
        if load_from_checkpoints:
            self.load_state_dict(checkpoint)

    def init_weights(self):
        for module in self.decoder.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        for p in self.generator.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                p.data.zero_()

    def forward(
        self,
        encoder_input_ids,
        decoder_input_ids,
        token_type_ids,
        encoder_attention_mask,
        decoder_attention_mask,
    ):
        encoder_output = self.bert(
            input_ids=encoder_input_ids,
            token_type_ids=token_type_ids,
            attention_mask=encoder_attention_mask,
        )
        encoder_hidden_states = encoder_output[0]
        dec_state = self.decoder.init_decoder_state(encoder_input_ids, encoder_hidden_states)
        decoder_outputs, _ = self.decoder(decoder_input_ids[:, :-1], encoder_hidden_states, dec_state)
        return decoder_outputs


class Bert(nn.Module):
    """This class is not really necessary and should probably disappear."""

    def __init__(self):
        super().__init__()
        config = BertConfig.from_pretrained("bert-base-uncased")
        self.model = BertModel(config)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
        self.eval()
        with torch.no_grad():
            encoder_outputs, _ = self.model(
                input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, **kwargs
            )
        return encoder_outputs


class TransformerDecoder(nn.Module):
    """
    The Transformer decoder from "Attention is All You Need".

    Args:
        num_layers (int): number of decoder layers.
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        dropout (float): dropout parameters
        embeddings (:obj:`onmt.modules.Embeddings`):
            embeddings to use, should have positional encodings
        attn_type (str): if using a separate copy attention
    """

    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings, vocab_size):
        super().__init__()

        # Basic attributes.
        self.decoder_type = "transformer"
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(dropout, self.embeddings.embedding_dim)

        # Build TransformerDecoder.
        self.transformer_layers = nn.ModuleList(
            [TransformerDecoderLayer(d_model, heads, d_ff, dropout) for _ in range(num_layers)]
        )

        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    # forward(input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask)
    # def forward(self, input_ids, state, attention_mask=None, memory_lengths=None,
    # step=None, cache=None, encoder_attention_mask=None, encoder_hidden_states=None, memory_masks=None):
    def forward(
        self,
        input_ids,
        encoder_hidden_states=None,
        state=None,
        attention_mask=None,
        memory_lengths=None,
        step=None,
        cache=None,
        encoder_attention_mask=None,
    ):
        """
        See :obj:`onmt.modules.RNNDecoderBase.forward()`
        memory_bank = encoder_hidden_states
        """
        # Name conversion
        tgt = input_ids
        memory_bank = encoder_hidden_states
        memory_mask = encoder_attention_mask

        # src_words = state.src
        src_words = state.src
        src_batch, src_len = src_words.size()

        padding_idx = self.embeddings.padding_idx

        # Decoder padding mask
        tgt_words = tgt
        tgt_batch, tgt_len = tgt_words.size()
        tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len)

        # Encoder padding mask
        if memory_mask is not None:
            src_len = memory_mask.size(-1)
            src_pad_mask = memory_mask.expand(src_batch, tgt_len, src_len)
        else:
            src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1).expand(src_batch, tgt_len, src_len)

        # Pass through the embeddings
        emb = self.embeddings(input_ids)
        output = self.pos_emb(emb, step)
        assert emb.dim() == 3  # batch x len x embedding_dim

        if state.cache is None:
            saved_inputs = []

        for i in range(self.num_layers):
            prev_layer_input = None
            if state.cache is None:
                if state.previous_input is not None:
                    prev_layer_input = state.previous_layer_inputs[i]

            output, all_input = self.transformer_layers[i](
                output,
                memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                previous_input=prev_layer_input,
                layer_cache=state.cache["layer_{}".format(i)] if state.cache is not None else None,
                step=step,
            )
            if state.cache is None:
                saved_inputs.append(all_input)

        if state.cache is None:
            saved_inputs = torch.stack(saved_inputs)

        output = self.layer_norm(output)

        if state.cache is None:
            state = state.update_state(tgt, saved_inputs)

        # Decoders in transformers return a tuple. Beam search will fail
        # if we don't follow this convention.
        return output, state  # , state

    def init_decoder_state(self, src, memory_bank, with_cache=False):
        """Init decoder state"""
        state = TransformerDecoderState(src)
        if with_cache:
            state._init_cache(memory_bank, self.num_layers)
        return state


class PositionalEncoding(nn.Module):
    def __init__(self, dropout, dim, max_len=5000):
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0)
        super().__init__()
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step=None):
        emb = emb * math.sqrt(self.dim)
        if step:
            emb = emb + self.pe[:, step][:, None, :]
        else:
            emb = emb + self.pe[:, : emb.size(1)]
        emb = self.dropout(emb)
        return emb

    def get_emb(self, emb):
        return self.pe[:, : emb.size(1)]
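

# Hedged usage sketch (not from the original file): exercising PositionalEncoding
# standalone. The batch size, sequence length, and dimension below are arbitrary
# illustrative assumptions.
def _demo_positional_encoding():
    pe = PositionalEncoding(dropout=0.0, dim=16)
    emb = torch.zeros(2, 5, 16)  # [batch, seq_len, dim]
    out = pe(emb)  # scales by sqrt(dim), adds pe[:, :5], applies dropout
    assert out.shape == (2, 5, 16)
    # With a zero input, position 0 exposes the raw table: sin(0) = 0 at even
    # indices, cos(0) = 1 at odd indices.
    assert torch.equal(out[0, 0, 1::2], torch.ones(8))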


class TransformerDecoderLayer(nn.Module):
    """
    Args:
        d_model (int): the dimension of keys/values/queries in
            MultiHeadedAttention, also the input size of
            the first-layer of the PositionwiseFeedForward.
        heads (int): the number of heads for MultiHeadedAttention.
        d_ff (int): the second-layer of the PositionwiseFeedForward.
        dropout (float): dropout probability(0-1.0).
        self_attn_type (string): type of self-attention scaled-dot, average
    """

    def __init__(self, d_model, heads, d_ff, dropout):
        super().__init__()

        self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)

        self.context_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        mask = self._get_attn_subsequent_mask(MAX_SIZE)
        # Register self.mask as a saved_state in TransformerDecoderLayer, so
        # it gets TransformerDecoderLayer's cuda behavior automatically.
        self.register_buffer("mask", mask)

    def forward(
        self,
        inputs,
        memory_bank,
        src_pad_mask,
        tgt_pad_mask,
        previous_input=None,
        layer_cache=None,
        step=None,
    ):
        """
        Args:
            inputs (`FloatTensor`): `[batch_size x 1 x model_dim]`
            memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]`
            src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]`
            tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]`

        Returns:
            (`FloatTensor`, `FloatTensor`, `FloatTensor`):

            * output `[batch_size x 1 x model_dim]`
            * attn `[batch_size x 1 x src_len]`
            * all_input `[batch_size x current_step x model_dim]`
        """
        dec_mask = torch.gt(tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0)
        input_norm = self.layer_norm_1(inputs)
        all_input = input_norm
        if previous_input is not None:
            all_input = torch.cat((previous_input, input_norm), dim=1)
            dec_mask = None

        query = self.self_attn(
            all_input,
            all_input,
            input_norm,
            mask=dec_mask,
            layer_cache=layer_cache,
            type="self",
        )

        query = self.drop(query) + inputs

        query_norm = self.layer_norm_2(query)
        mid = self.context_attn(
            memory_bank,
            memory_bank,
            query_norm,
            mask=src_pad_mask,
            layer_cache=layer_cache,
            type="context",
        )
        output = self.feed_forward(self.drop(mid) + query)

        return output, all_input
        # return output

    def _get_attn_subsequent_mask(self, size):
        """
        Get an attention mask to avoid using the subsequent info.

        Args:
            size: int

        Returns:
            (`LongTensor`):

            * subsequent_mask `[1 x size x size]`
        """
        attn_shape = (1, size, size)
        subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8")
        subsequent_mask = torch.from_numpy(subsequent_mask)
        return subsequent_mask
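

# Hedged sketch (not from the original file): what _get_attn_subsequent_mask
# returns for a small size. Ones above the diagonal mean "masked", i.e.
# position i may not attend to positions j > i.
def _demo_subsequent_mask():
    layer = TransformerDecoderLayer(d_model=8, heads=2, d_ff=16, dropout=0.0)
    mask = layer._get_attn_subsequent_mask(4)[0]
    expected = torch.tensor(
        [[0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]],
        dtype=torch.uint8,
    )
    assert torch.equal(mask, expected)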


class MultiHeadedAttention(nn.Module):
    """
    Multi-Head Attention module from
    "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.

    Similar to standard `dot` attention but uses
    multiple attention distributions simultaneously
    to select relevant items.

    .. mermaid::

       graph BT
          A[key]
          B[value]
          C[query]
          O[output]
          subgraph Attn
            D[Attn 1]
            E[Attn 2]
            F[Attn N]
          end
          A --> D
          C --> D
          A --> E
          C --> E
          A --> F
          C --> F
          D --> O
          E --> O
          F --> O
          B --> O

    Also includes several additional tricks.

    Args:
        head_count (int): number of parallel heads
        model_dim (int): the dimension of keys/values/queries,
            must be divisible by head_count
        dropout (float): dropout parameter
    """

    def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True):
        assert model_dim % head_count == 0
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim

        super().__init__()
        self.head_count = head_count

        self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_values = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_query = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.use_final_linear = use_final_linear
        if self.use_final_linear:
            self.final_linear = nn.Linear(model_dim, model_dim)

    def forward(
        self,
        key,
        value,
        query,
        mask=None,
        layer_cache=None,
        type=None,
        predefined_graph_1=None,
    ):
        """
        Compute the context vector and the attention vectors.

        Args:
            key (`FloatTensor`): set of `key_len`
                key vectors `[batch, key_len, dim]`
            value (`FloatTensor`): set of `key_len`
                value vectors `[batch, key_len, dim]`
            query (`FloatTensor`): set of `query_len`
                query vectors `[batch, query_len, dim]`
            mask: binary mask indicating which keys have
                non-zero attention `[batch, query_len, key_len]`

        Returns:
            (`FloatTensor`, `FloatTensor`):

            * output context vectors `[batch, query_len, dim]`
            * one of the attention vectors `[batch, query_len, key_len]`
        """
        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count

        def shape(x):
            """projection"""
            return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2)

        def unshape(x):
            """compute context"""
            return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if type == "self":
                query, key, value = (
                    self.linear_query(query),
                    self.linear_keys(query),
                    self.linear_values(query),
                )

                key = shape(key)
                value = shape(value)

                if layer_cache is not None:
                    device = key.device
                    if layer_cache["self_keys"] is not None:
                        key = torch.cat((layer_cache["self_keys"].to(device), key), dim=2)
                    if layer_cache["self_values"] is not None:
                        value = torch.cat((layer_cache["self_values"].to(device), value), dim=2)
                    layer_cache["self_keys"] = key
                    layer_cache["self_values"] = value
            elif type == "context":
                query = self.linear_query(query)
                if layer_cache is not None:
                    if layer_cache["memory_keys"] is None:
                        key, value = self.linear_keys(key), self.linear_values(value)
                        key = shape(key)
                        value = shape(value)
                    else:
                        key, value = (
                            layer_cache["memory_keys"],
                            layer_cache["memory_values"],
                        )
                    layer_cache["memory_keys"] = key
                    layer_cache["memory_values"] = value
                else:
                    key, value = self.linear_keys(key), self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

        query = shape(query)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        scores = torch.matmul(query, key.transpose(2, 3))

        if mask is not None:
            mask = mask.unsqueeze(1).expand_as(scores)
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores)

        if predefined_graph_1 is not None:
            attn_masked = attn[:, -1] * predefined_graph_1
            attn_masked = attn_masked / (torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9)

            attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1)

        drop_attn = self.dropout(attn)
        if self.use_final_linear:
            context = unshape(torch.matmul(drop_attn, value))
            output = self.final_linear(context)
            return output
        else:
            context = torch.matmul(drop_attn, value)
            return context
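

# Hedged usage sketch (not from the original file): a cache-free call to
# MultiHeadedAttention. All tensor sizes are illustrative assumptions.
def _demo_multi_headed_attention():
    attn = MultiHeadedAttention(head_count=2, model_dim=8, dropout=0.0)
    key = value = torch.randn(3, 5, 8)  # [batch, key_len, dim]
    query = torch.randn(3, 4, 8)  # [batch, query_len, dim]
    out = attn(key, value, query)  # context vectors, one per query position
    assert out.shape == (3, 4, 8)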


class DecoderState(object):
    """Interface for grouping together the current state of a recurrent
    decoder. In the simplest case just represents the hidden state of
    the model. But can also be used for implementing various forms of
    input_feeding and non-recurrent models.

    Modules need to implement this to utilize beam search decoding.
    """

    def detach(self):
        """Need to document this"""
        self.hidden = tuple([_.detach() for _ in self.hidden])
        self.input_feed = self.input_feed.detach()

    def beam_update(self, idx, positions, beam_size):
        """Need to document this"""
        for e in self._all:
            sizes = e.size()
            br = sizes[1]
            if len(sizes) == 3:
                sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[:, :, idx]
            else:
                sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2], sizes[3])[:, :, idx]

            sent_states.data.copy_(sent_states.data.index_select(1, positions))

    def map_batch_fn(self, fn):
        raise NotImplementedError()


class TransformerDecoderState(DecoderState):
    """Transformer Decoder state base class"""

    def __init__(self, src):
        """
        Args:
            src (FloatTensor): a sequence of source words tensors
                    with optional feature tensors, of size (len x batch).
        """
        self.src = src
        self.previous_input = None
        self.previous_layer_inputs = None
        self.cache = None

    @property
    def _all(self):
        """
        Contains attributes that need to be updated in self.beam_update().
        """
        if self.previous_input is not None and self.previous_layer_inputs is not None:
            return (self.previous_input, self.previous_layer_inputs, self.src)
        else:
            return (self.src,)

    def detach(self):
        if self.previous_input is not None:
            self.previous_input = self.previous_input.detach()
        if self.previous_layer_inputs is not None:
            self.previous_layer_inputs = self.previous_layer_inputs.detach()
        self.src = self.src.detach()

    def update_state(self, new_input, previous_layer_inputs):
        state = TransformerDecoderState(self.src)
        state.previous_input = new_input
        state.previous_layer_inputs = previous_layer_inputs
        return state

    def _init_cache(self, memory_bank, num_layers):
        self.cache = {}

        for l in range(num_layers):
            layer_cache = {"memory_keys": None, "memory_values": None}
            layer_cache["self_keys"] = None
            layer_cache["self_values"] = None
            self.cache["layer_{}".format(l)] = layer_cache

    def repeat_beam_size_times(self, beam_size):
        """Repeat beam_size times along batch dimension."""
        self.src = self.src.data.repeat(1, beam_size, 1)

    def map_batch_fn(self, fn):
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        self.src = fn(self.src, 0)
        if self.cache is not None:
            _recursive_map(self.cache)
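

# Hedged sketch (not from the original file): the per-layer cache layout that
# _init_cache builds. MultiHeadedAttention fills the entries in during
# incremental decoding.
def _demo_decoder_state_cache():
    state = TransformerDecoderState(src=torch.zeros(2, 7, dtype=torch.long))
    state._init_cache(memory_bank=None, num_layers=2)
    assert set(state.cache) == {"layer_0", "layer_1"}
    assert state.cache["layer_0"] == {
        "memory_keys": None,
        "memory_values": None,
        "self_keys": None,
        "self_values": None,
    }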


def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
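

# Hedged note (not from the original file): gelu above is the tanh
# approximation of the Gaussian Error Linear Unit used in the BERT codebase.
def _demo_gelu():
    x = torch.tensor([-1.0, 0.0, 1.0])
    y = gelu(x)
    assert y[1].item() == 0.0  # GELU(0) = 0
    assert abs(y[2].item() - 0.8412) < 1e-3  # GELU(1) is roughly 0.8412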


class PositionwiseFeedForward(nn.Module):
    """A two-layer Feed-Forward-Network with residual layer norm.

    Args:
        d_model (int): the size of input for the first-layer of the FFN.
        d_ff (int): the hidden layer size of the second-layer
            of the FFN.
        dropout (float): dropout probability in :math:`[0, 1)`.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.actv = gelu
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x))))
        output = self.dropout_2(self.w_2(inter))
        return output + x
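

# Hedged sketch (not from the original file): the FFN maps d_model back to
# d_model, so the residual connection preserves the input shape. Sizes are
# illustrative assumptions.
def _demo_positionwise_ffn():
    ffn = PositionwiseFeedForward(d_model=8, d_ff=32, dropout=0.0)
    x = torch.randn(2, 5, 8)  # [batch, seq_len, d_model]
    assert ffn(x).shape == x.shape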


#
# TRANSLATOR
# The following code is used to generate summaries using the
# pre-trained weights and beam search.
#


def build_predictor(args, tokenizer, symbols, model, logger=None):
    # we should be able to refactor the global scorer a lot
    scorer = GNMTGlobalScorer(args.alpha, length_penalty="wu")
    translator = Translator(args, model, tokenizer, symbols, global_scorer=scorer, logger=logger)
    return translator


class GNMTGlobalScorer(object):
    """
    NMT re-ranking score from
    "Google's Neural Machine Translation System" :cite:`wu2016google`

    Args:
        alpha (float): length parameter
        beta (float): coverage parameter
    """

    def __init__(self, alpha, length_penalty):
        self.alpha = alpha
        penalty_builder = PenaltyBuilder(length_penalty)
        self.length_penalty = penalty_builder.length_penalty()

    def score(self, beam, logprobs):
        """
        Rescores a prediction based on penalty functions
        """
        normalized_probs = self.length_penalty(beam, logprobs, self.alpha)
        return normalized_probs


class PenaltyBuilder(object):
    """
    Returns the Length and Coverage Penalty function for Beam Search.

    Args:
        length_pen (str): option name of length pen
        cov_pen (str): option name of cov pen
    """

    def __init__(self, length_pen):
        self.length_pen = length_pen

    def length_penalty(self):
        if self.length_pen == "wu":
            return self.length_wu
        elif self.length_pen == "avg":
            return self.length_average
        else:
            return self.length_none

    """
    Below are all the different penalty terms implemented so far
    """

    def length_wu(self, beam, logprobs, alpha=0.0):
        """
        NMT length re-ranking score from
        "Google's Neural Machine Translation System" :cite:`wu2016google`.
        """
        modifier = ((5 + len(beam.next_ys)) ** alpha) / ((5 + 1) ** alpha)
        return logprobs / modifier

    def length_average(self, beam, logprobs, alpha=0.0):
        """
        Returns the average probability of tokens in a sequence.
        """
        return logprobs / len(beam.next_ys)

    def length_none(self, beam, logprobs, alpha=0.0, beta=0.0):
        """
        Returns unmodified scores.
        """
        return logprobs
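

# Hedged worked example (not from the original file): length_wu divides the
# running log-probability by (5 + len)^alpha / 6^alpha. The minimal beam
# stand-in below, with only a `next_ys` list, is an assumption made purely
# for illustration.
def _demo_length_wu():
    class _FakeBeam:
        next_ys = [0] * 7  # pretend seven tokens have been generated

    penalty = PenaltyBuilder("wu").length_penalty()
    rescored = penalty(_FakeBeam(), torch.tensor([-6.0]), alpha=1.0)
    # modifier = (5 + 7) / (5 + 1) = 2.0, so -6.0 / 2.0 == -3.0
    assert rescored.item() == -3.0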


class Translator(object):
    """
    Uses a model to translate a batch of sentences.

    Args:
        model (:obj:`onmt.modules.NMTModel`):
            NMT model to use for translation
        fields (dict of Fields): data fields
        beam_size (int): size of beam to use
        n_best (int): number of translations produced
        max_length (int): maximum length output to produce
        global_scores (:obj:`GlobalScorer`):
            object to rescore final translations
        copy_attn (bool): use copy attention during translation
        beam_trace (bool): trace beam search for debugging
        logger(logging.Logger): logger.
    """

    def __init__(self, args, model, vocab, symbols, global_scorer=None, logger=None):
        self.logger = logger

        self.args = args
        self.model = model
        self.generator = self.model.generator
        self.vocab = vocab
        self.symbols = symbols
        self.start_token = symbols["BOS"]
        self.end_token = symbols["EOS"]

        self.global_scorer = global_scorer
        self.beam_size = args.beam_size
        self.min_length = args.min_length
        self.max_length = args.max_length

    def translate(self, batch, step, attn_debug=False):
        """Generates summaries from one batch of data."""
        self.model.eval()
        with torch.no_grad():
            batch_data = self.translate_batch(batch)
            translations = self.from_batch(batch_data)
        return translations

    def translate_batch(self, batch, fast=False):
        """
        Translate a batch of sentences.

        Mostly a wrapper around :obj:`Beam`.

        Args:
            batch (:obj:`Batch`): a batch from a dataset object
            fast (bool): enables fast beam search (may not support all features)
        """
        with torch.no_grad():
            return self._fast_translate_batch(batch, self.max_length, min_length=self.min_length)

    # Where the beam search lives
    # I have no idea why it is being called from the method above
    def _fast_translate_batch(self, batch, max_length, min_length=0):
        """Beam Search using the encoder inputs contained in `batch`."""
        # The batch object is funny
        # Instead of just looking at the size of the arguments we encapsulate
        # a size argument.
        # Where is it defined?
        beam_size = self.beam_size
        batch_size = batch.batch_size
        src = batch.src
        segs = batch.segs
        mask_src = batch.mask_src

        src_features = self.model.bert(src, segs, mask_src)
        dec_states = self.model.decoder.init_decoder_state(src, src_features, with_cache=True)
        device = src_features.device

        # Tile states and memory beam_size times.
        dec_states.map_batch_fn(lambda state, dim: tile(state, beam_size, dim=dim))
        src_features = tile(src_features, beam_size, dim=0)
        batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
        beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device)
        alive_seq = torch.full([batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device)

        # Give full probability to the first beam on the first step.
        topk_log_probs = torch.tensor([0.0] + [float("-inf")] * (beam_size - 1), device=device).repeat(batch_size)

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["gold_score"] = [0] * batch_size
        results["batch"] = batch

        for step in range(max_length):
            decoder_input = alive_seq[:, -1].view(1, -1)

            # Decoder forward.
            decoder_input = decoder_input.transpose(0, 1)

            dec_out, dec_states = self.model.decoder(decoder_input, src_features, dec_states, step=step)

            # Generator forward.
            log_probs = self.generator(dec_out.transpose(0, 1).squeeze(0))
            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, self.end_token] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty

            if self.args.block_trigram:
                cur_len = alive_seq.size(1)
                if cur_len > 3:
                    for i in range(alive_seq.size(0)):
                        fail = False
                        words = [int(w) for w in alive_seq[i]]
                        words = [self.vocab.ids_to_tokens[w] for w in words]
                        words = " ".join(words).replace(" ##", "").split()
                        if len(words) <= 3:
                            continue
                        trigrams = [(words[i - 1], words[i], words[i + 1]) for i in range(1, len(words) - 1)]
                        trigram = tuple(trigrams[-1])
                        if trigram in trigrams[:-1]:
                            fail = True
                        if fail:
                            curr_scores[i] = -10e20

            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids.div(vocab_size, rounding_mode="floor")
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1)
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat([alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1)

            is_finished = topk_ids.eq(self.end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)
            # End condition is top beam is finished.
            end_condition = is_finished[:, 0].eq(1)
            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    if end_condition[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append((topk_scores[i, j], predictions[i, j, 1:]))
                    # If the batch reached the end, save the n_best hypotheses.
                    if end_condition[i]:
                        best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True)
                        score, pred = best_hyp[0]

                        results["scores"][b].append(score)
                        results["predictions"][b].append(pred)
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished).view(-1, alive_seq.size(-1))
            # Reorder states.
            select_indices = batch_index.view(-1)
            src_features = src_features.index_select(0, select_indices)
            dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices))

        return results

    def from_batch(self, translation_batch):
        batch = translation_batch["batch"]
        assert len(translation_batch["gold_score"]) == len(translation_batch["predictions"])
        batch_size = batch.batch_size

        preds, _, _, tgt_str, src = (
            translation_batch["predictions"],
            translation_batch["scores"],
            translation_batch["gold_score"],
            batch.tgt_str,
            batch.src,
        )

        translations = []
        for b in range(batch_size):
            pred_sents = self.vocab.convert_ids_to_tokens([int(n) for n in preds[b][0]])
            pred_sents = " ".join(pred_sents).replace(" ##", "")
            gold_sent = " ".join(tgt_str[b].split())
            raw_src = [self.vocab.ids_to_tokens[int(t)] for t in src[b]][:500]
            raw_src = " ".join(raw_src)
            translation = (pred_sents, gold_sent, raw_src)
            translations.append(translation)

        return translations


def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.
    """
    perm = list(range(len(x.size())))
    if dim != 0:
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = x.view(batch, -1).transpose(0, 1).repeat(count, 1).transpose(0, 1).contiguous().view(*out_size)
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x
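

# Hedged sketch (not from the original file): tile repeats each batch entry
# `count` times and keeps the copies of the same entry adjacent, which is the
# layout beam search expects.
def _demo_tile():
    x = torch.tensor([[1, 2], [3, 4]])
    tiled = tile(x, 2, dim=0)
    expected = torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4]])
    assert torch.equal(tiled, expected)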


#
# Optimizer for training. We keep this here in case we want to add
# a finetuning script.
#


class BertSumOptimizer(object):
    """Specific optimizer for BertSum.

    As described in [1], the authors fine-tune BertSum for abstractive
    summarization using two Adam Optimizers with different warm-up steps and
    learning rate. They also use a custom learning rate scheduler.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    """

    def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.lr = lr
        self.warmup_steps = warmup_steps

        self.optimizers = {
            "encoder": torch.optim.Adam(
                model.encoder.parameters(),
                lr=lr["encoder"],
                betas=(beta_1, beta_2),
                eps=eps,
            ),
            "decoder": torch.optim.Adam(
                model.decoder.parameters(),
                lr=lr["decoder"],
                betas=(beta_1, beta_2),
                eps=eps,
            ),
        }

        self._step = 0
        self.current_learning_rates = {}

    def _update_rate(self, stack):
        return self.lr[stack] * min(self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5))

    def zero_grad(self):
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()

    def step(self):
        self._step += 1
        for stack, optimizer in self.optimizers.items():
            new_rate = self._update_rate(stack)
            for param_group in optimizer.param_groups:
                param_group["lr"] = new_rate
            optimizer.step()
            self.current_learning_rates[stack] = new_rate
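

# Hedged usage sketch (not from the original file): the schedule is the
# Noam-style rule lr * min(step^-0.5, step * warmup^-1.5), applied per stack.
# The toy model below only provides the `.encoder` / `.decoder` submodules the
# optimizer expects; it is a stand-in, not the real BertAbs interface.
def _demo_bertsum_optimizer():
    class _ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(4, 4)
            self.decoder = nn.Linear(4, 4)

    lr = {"encoder": 0.002, "decoder": 0.2}
    warmup_steps = {"encoder": 20000, "decoder": 10000}
    optimizer = BertSumOptimizer(_ToyModel(), lr, warmup_steps)
    optimizer.step()  # step 1: deep in warm-up, so rates are far below lr
    assert optimizer.current_learning_rates["encoder"] < lr["encoder"]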