Page MenuHomec4science

modeling_highway_bert.py
No OneTemporary

File Metadata

Created
Sun, May 12, 18:16

modeling_highway_bert.py

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.models.bert.modeling_bert import (
BERT_INPUTS_DOCSTRING,
BERT_START_DOCSTRING,
BertEmbeddings,
BertLayer,
BertPooler,
BertPreTrainedModel,
)
def entropy(x):
"""Calculate entropy of a pre-softmax logit Tensor"""
exp_x = torch.exp(x)
A = torch.sum(exp_x, dim=1) # sum of exp(x_i)
B = torch.sum(x * exp_x, dim=1) # sum of x_i * exp(x_i)
return torch.log(A) - B / A
class DeeBertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
self.highway = nn.ModuleList([BertHighway(config) for _ in range(config.num_hidden_layers)])
self.early_exit_entropy = [-1 for _ in range(config.num_hidden_layers)]
def set_early_exit_entropy(self, x):
if (type(x) is float) or (type(x) is int):
for i in range(len(self.early_exit_entropy)):
self.early_exit_entropy[i] = x
else:
self.early_exit_entropy = x
def init_highway_pooler(self, pooler):
loaded_model = pooler.state_dict()
for highway in self.highway:
for name, param in highway.pooler.state_dict().items():
param.copy_(loaded_model[name])
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
all_hidden_states = ()
all_attentions = ()
all_highway_exits = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
)
hidden_states = layer_outputs[0]
if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
current_outputs = (hidden_states,)
if self.output_hidden_states:
current_outputs = current_outputs + (all_hidden_states,)
if self.output_attentions:
current_outputs = current_outputs + (all_attentions,)
highway_exit = self.highway[i](current_outputs)
# logits, pooled_output
if not self.training:
highway_logits = highway_exit[0]
highway_entropy = entropy(highway_logits)
highway_exit = highway_exit + (highway_entropy,) # logits, hidden_states(?), entropy
all_highway_exits = all_highway_exits + (highway_exit,)
if highway_entropy < self.early_exit_entropy[i]:
new_output = (highway_logits,) + current_outputs[1:] + (all_highway_exits,)
raise HighwayException(new_output, i + 1)
else:
all_highway_exits = all_highway_exits + (highway_exit,)
# Add last layer
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
outputs = outputs + (all_highway_exits,)
return outputs # last-layer hidden state, (all hidden states), (all attentions), all highway exits
@add_start_docstrings(
"The Bert Model transformer with early exiting (DeeBERT). ",
BERT_START_DOCSTRING,
)
class DeeBertModel(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.embeddings = BertEmbeddings(config)
self.encoder = DeeBertEncoder(config)
self.pooler = BertPooler(config)
self.init_weights()
def init_highway_pooler(self):
self.encoder.init_highway_pooler(self.pooler)
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during pre-training.
This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
highway_exits (:obj:`tuple(tuple(torch.Tensor))`:
Tuple of each early exit's results (total length: number of layers)
Each tuple is again, a tuple of length 2 - the first entry is logits and the second entry is hidden states.
"""
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_attention_mask.dim() == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if encoder_attention_mask.dim() == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
encoder_extended_attention_mask = encoder_extended_attention_mask.to(
dtype=next(self.parameters()).dtype
) # fp16 compatibility
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + encoder_outputs[
1:
] # add hidden_states and attentions if they are here
return outputs # sequence_output, pooled_output, (hidden_states), (attentions), highway exits
class HighwayException(Exception):
def __init__(self, message, exit_layer):
self.message = message
self.exit_layer = exit_layer # start from 1!
class BertHighway(nn.Module):
"""A module to provide a shortcut
from (the output of one non-final BertLayer in BertEncoder) to (cross-entropy computation in BertForSequenceClassification)
"""
def __init__(self, config):
super().__init__()
self.pooler = BertPooler(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, encoder_outputs):
# Pooler
pooler_input = encoder_outputs[0]
pooler_output = self.pooler(pooler_input)
# "return" pooler_output
# BertModel
bmodel_output = (pooler_input, pooler_output) + encoder_outputs[1:]
# "return" bmodel_output
# Dropout and classification
pooled_output = bmodel_output[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return logits, pooled_output
@add_start_docstrings(
"""Bert Model (with early exiting - DeeBERT) with a classifier on top,
also takes care of multi-layer training. """,
BERT_START_DOCSTRING,
)
class DeeBertForSequenceClassification(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.num_layers = config.num_hidden_layers
self.bert = DeeBertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
self.init_weights()
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_layer=-1,
train_highway=False,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
highway_exits (:obj:`tuple(tuple(torch.Tensor))`:
Tuple of each early exit's results (total length: number of layers)
Each tuple is again, a tuple of length 2 - the first entry is logits and the second entry is hidden states.
"""
exit_layer = self.num_layers
try:
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
# sequence_output, pooled_output, (hidden_states), (attentions), highway exits
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
except HighwayException as e:
outputs = e.message
exit_layer = e.exit_layer
logits = outputs[0]
if not self.training:
original_entropy = entropy(logits)
highway_entropy = []
highway_logits_all = []
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
# work with highway exits
highway_losses = []
for highway_exit in outputs[-1]:
highway_logits = highway_exit[0]
if not self.training:
highway_logits_all.append(highway_logits)
highway_entropy.append(highway_exit[2])
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
highway_loss = loss_fct(highway_logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
highway_loss = loss_fct(highway_logits.view(-1, self.num_labels), labels.view(-1))
highway_losses.append(highway_loss)
if train_highway:
outputs = (sum(highway_losses[:-1]),) + outputs
# exclude the final highway, of course
else:
outputs = (loss,) + outputs
if not self.training:
outputs = outputs + ((original_entropy, highway_entropy), exit_layer)
if output_layer >= 0:
outputs = (
(outputs[0],) + (highway_logits_all[output_layer],) + outputs[2:]
) # use the highway of the last layer
return outputs # (loss), logits, (hidden_states), (attentions), (highway_exits)

Event Timeline