Page MenuHomec4science

modeling_flax_performer_utils.py
No OneTemporary

File Metadata

Created
Wed, May 8, 02:34

modeling_flax_performer_utils.py

# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
IMPORTANT:
This code was copied from
https://github.com/google-research/google-research/blob/master/performer/fast_self_attention/fast_self_attention.py on
6/11/2020. This code is very new and may change soon, so check the original implementation regularly and update
this copy accordingly.
Core Fast Attention Module for Flax. Implementation of the approximate fast softmax and generalized attention mechanism
leveraging structured random feature maps [RFM] techniques and low rank decomposition of the attention matrix.
"""
# pylint: disable=invalid-name, missing-function-docstring, line-too-long
import abc
import functools
from collections.abc import Iterable # pylint: disable=g-importing-member
import numpy as onp
from absl import logging
import jax
import jax.numpy as jnp
from jax import lax, random
def nonnegative_softmax_kernel_feature_creator(
    data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True, eps=0.0001
):
    """
    Builds strictly positive random features used by fast softmax attention.

    Args:
      data: input tensor for which features are computed; its last axis is
        contracted against the last axis of `projection_matrix`.
      projection_matrix: random matrix used to compute features.
      attention_dims_t: tuple of attention dimensions (unused by this creator).
      batch_dims_t: tuple of batch dimensions of `data`.
      precision: precision argument forwarded to `lax.dot_general`.
      is_query: when true, the stabilizing maximum is taken per position over
        the feature axis; otherwise a single maximum over the whole projected
        tensor is used.
      normalize_data: whether `data` is rescaled by `d ** -0.25`, with `d` the
        size of its last axis.
      eps: numerical stabilizer added to every exponential.

    Returns:
      Random features for fast softmax attention.
    """
    del attention_dims_t  # Part of the shared creator signature only.
    last_axis = data.ndim - 1

    # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
    # w_norm = w * scale for w in {q,k}.
    scale = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) if normalize_data else 1.0
    feature_ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])

    # Replicate the projection matrix across the batch dimensions so that it
    # can take part in a batched dot product.
    tiled_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
    tiled_projection = jnp.zeros(tiled_shape) + projection_matrix

    contracting_dims = ((last_axis,), (tiled_projection.ndim - 1,))
    projected = lax.dot_general(
        scale * data,
        tiled_projection,
        (contracting_dims, (batch_dims_t, batch_dims_t)),
        precision=precision,
    )

    # Half of the squared norm of the (rescaled) data, kept as a trailing
    # singleton axis so it broadcasts against the projected features.
    half_sq_norm = jnp.sum(jnp.square(data), axis=last_axis)
    half_sq_norm = (half_sq_norm / 2.0) * scale * scale
    half_sq_norm = jnp.expand_dims(half_sq_norm, axis=last_axis)

    # Subtract a maximum before exponentiating for numerical stability.
    if is_query:
        feature_axis = (len(projected.shape) - 1,)
        stabilizer = jnp.max(projected, axis=feature_axis, keepdims=True)
    else:
        stabilizer = jnp.max(projected)

    return feature_ratio * (jnp.exp(projected - half_sq_norm - stabilizer) + eps)
def sincos_softmax_kernel_feature_creator(
    data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data=True
):
    """
    Builds sin-cos random features used by fast softmax attention.

    Args:
      data: input tensor for which features are computed; its last axis is
        contracted against the last axis of `projection_matrix`.
      projection_matrix: random matrix used to compute features.
      attention_dims_t: tuple of attention dimensions; the renormalizing
        maximum is taken over these axes.
      batch_dims_t: tuple of batch dimensions of `data`.
      precision: precision argument forwarded to `lax.dot_general`.
      normalize_data: whether `data` is rescaled by `d ** -0.25`, with `d` the
        size of its last axis.

    Returns:
      Random features for fast softmax attention. The last axis holds the
      cosine features followed by the sine features.
    """
    last_axis = data.ndim - 1

    # We have: exp(qk^T/sqrt{d}) = exp(|q|^2/2sqrt{d}) * exp(|k|^2/2sqrt{d}) *
    # exp(-(|q*c-k*c|^2)/2), where c = 1.0 / sqrt{sqrt{d}}.
    scale = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) if normalize_data else 1.0
    feature_ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])

    # Replicate the projection matrix across the batch dimensions so that it
    # can take part in a batched dot product.
    tiled_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
    tiled_projection = jnp.zeros(tiled_shape) + projection_matrix

    contracting_dims = ((last_axis,), (tiled_projection.ndim - 1,))
    projected = lax.dot_general(
        scale * data,
        tiled_projection,
        (contracting_dims, (batch_dims_t, batch_dims_t)),
        precision=precision,
    )
    trig_features = jnp.concatenate(
        (feature_ratio * jnp.cos(projected), feature_ratio * jnp.sin(projected)), axis=-1
    )

    # Constructing D_data: half of the squared norm of the (rescaled) data.
    half_sq_norm = jnp.sum(jnp.square(data), axis=last_axis)
    half_sq_norm = (half_sq_norm / 2.0) * scale * scale
    half_sq_norm = jnp.expand_dims(half_sq_norm, axis=last_axis)

    # Additional renormalization for numerical stability: shift by the maximum
    # over the attention dimensions before exponentiating.
    half_sq_norm -= jnp.max(half_sq_norm, attention_dims_t, keepdims=True)

    return trig_features * jnp.exp(half_sq_norm)
def generalized_kernel_feature_creator(
    data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data
):
    """
    Builds kernel features used by fast generalized attention.

    Args:
      data: input tensor for which features are computed.
      projection_matrix: matrix used to compute features, or None to apply
        `kernel_fn` directly to the (optionally rescaled) data.
      batch_dims_t: tuple of batch dimensions of `data`.
      precision: precision argument forwarded to `lax.dot_general`.
      kernel_fn: kernel function applied elementwise to the projections.
      kernel_epsilon: additive positive term added to every feature for
        numerical stability.
      normalize_data: whether `data` is rescaled by `d ** -0.25`, with `d` the
        size of its last axis.

    Returns:
      Random features for fast generalized attention.
    """
    scale = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) if normalize_data else 1.0

    # Without a projection matrix the kernel acts on the data itself.
    if projection_matrix is None:
        return kernel_fn(scale * data) + kernel_epsilon

    # Replicate the projection matrix across the batch dimensions so that it
    # can take part in a batched dot product.
    tiled_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
    tiled_projection = jnp.zeros(tiled_shape) + projection_matrix

    contracting_dims = ((data.ndim - 1,), (tiled_projection.ndim - 1,))
    projected = lax.dot_general(
        scale * data,
        tiled_projection,
        (contracting_dims, (batch_dims_t, batch_dims_t)),
        precision=precision,
    )
    return kernel_fn(projected) + kernel_epsilon
def make_fast_softmax_attention(
    qkv_dim,
    renormalize_attention=True,
    numerical_stabilizer=0.000001,
    nb_features=256,
    ortho_features=True,
    ortho_scaling=0.0,
    redraw_features=True,
    unidirectional=False,
    nonnegative_features=True,
    lax_scan_unroll=1,
):
    """
    Construct a fast softmax attention method.

    Args:
      qkv_dim: number of columns of the random projection matrix.
      renormalize_attention: forwarded to `FastAttentionviaLowRankDecomposition`.
      numerical_stabilizer: forwarded to `FastAttentionviaLowRankDecomposition`
        and also used as `eps` of the nonnegative feature creator.
      nb_features: number of rows of the random projection matrix.
      ortho_features: use `GaussianOrthogonalRandomMatrix` when true, otherwise
        `GaussianUnstructuredRandomMatrix`.
      ortho_scaling: `scaling` argument of `GaussianOrthogonalRandomMatrix`.
      redraw_features: forwarded to `FastAttentionviaLowRankDecomposition`.
      unidirectional: forwarded to `FastAttentionviaLowRankDecomposition`.
      nonnegative_features: use the nonnegative feature creator when true,
        otherwise the sin-cos feature creator.
      lax_scan_unroll: forwarded to `FastAttentionviaLowRankDecomposition`.

    Returns:
      The bound `dot_product_attention` method of the constructed attention
      object.
    """
    logging.info(
        "Fast softmax attention: %s features and orthogonal=%s, renormalize=%s",
        nb_features,
        ortho_features,
        renormalize_attention,
    )

    # The creator is completed later with the `key` keyword argument.
    if not ortho_features:
        matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim)
    else:
        matrix_creator = functools.partial(
            GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=ortho_scaling
        )

    def kernel_feature_creator(
        data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True
    ):
        if nonnegative_features:
            return nonnegative_softmax_kernel_feature_creator(
                data,
                projection_matrix,
                attention_dims_t,
                batch_dims_t,
                precision,
                is_query,
                normalize_data,
                numerical_stabilizer,
            )
        # The sin-cos features do not distinguish queries from keys.
        return sincos_softmax_kernel_feature_creator(
            data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data
        )

    fast_attention = FastAttentionviaLowRankDecomposition(
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention=renormalize_attention,
        numerical_stabilizer=numerical_stabilizer,
        redraw_features=redraw_features,
        unidirectional=unidirectional,
        lax_scan_unroll=lax_scan_unroll,
    )
    return fast_attention.dot_product_attention
def make_fast_generalized_attention(
    qkv_dim,
    renormalize_attention=True,
    numerical_stabilizer=0.0,
    nb_features=256,
    features_type="deterministic",
    kernel_fn=jax.nn.relu,
    kernel_epsilon=0.001,
    redraw_features=False,
    unidirectional=False,
    lax_scan_unroll=1,
):
    """
    Construct a fast generalized attention method.

    Args:
      qkv_dim: number of columns of the random projection matrix.
      renormalize_attention: forwarded to `FastAttentionviaLowRankDecomposition`.
      numerical_stabilizer: forwarded to `FastAttentionviaLowRankDecomposition`.
      nb_features: number of rows of the random projection matrix.
      features_type: "ortho" (Gaussian orthogonal matrix), "iid" (unstructured
        Gaussian matrix) or "deterministic" (no projection matrix).
      kernel_fn: kernel function applied to the projected data.
      kernel_epsilon: additive term added to every feature.
      redraw_features: forwarded to `FastAttentionviaLowRankDecomposition`.
      unidirectional: forwarded to `FastAttentionviaLowRankDecomposition`.
      lax_scan_unroll: forwarded to `FastAttentionviaLowRankDecomposition`.

    Returns:
      The bound `dot_product_attention` method of the constructed attention
      object.

    Raises:
      ValueError: if `features_type` is not one of the supported values.
    """
    logging.info("Fast generalized attention.: %s features and renormalize=%s", nb_features, renormalize_attention)

    if features_type == "deterministic":
        matrix_creator = None
    elif features_type == "iid":
        matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim)
    elif features_type == "ortho":
        matrix_creator = functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=False)
    else:
        raise ValueError("Unknown feature value type")

    def kernel_feature_creator(
        data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=False
    ):
        # The generalized features need neither the attention dimensions nor
        # the query/key distinction.
        del attention_dims_t, is_query
        return generalized_kernel_feature_creator(
            data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data
        )

    fast_attention = FastAttentionviaLowRankDecomposition(
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention=renormalize_attention,
        numerical_stabilizer=numerical_stabilizer,
        redraw_features=redraw_features,
        unidirectional=unidirectional,
        lax_scan_unroll=lax_scan_unroll,
    )
    return fast_attention.dot_product_attention
class RandomMatrix(metaclass=abc.ABCMeta):
    r"""
    Abstract class providing a method for constructing 2D random arrays. Class is responsible for constructing 2D
    random arrays.

    The metaclass is declared with the Python 3 `metaclass=` keyword; the Python 2 style `__metaclass__` class
    attribute is ignored by Python 3, which left `@abc.abstractmethod` unenforced.
    """

    @abc.abstractmethod
    def get_2d_array(self):
        """Return the 2D random array. Must be implemented by subclasses."""
        raise NotImplementedError("Abstract method")
class GaussianUnstructuredRandomMatrix(RandomMatrix):
    r"""
    Class providing a method to create a 2D array whose entries are drawn with `jax.random.normal`.
    """

    def __init__(self, nb_rows, nb_columns, key):
        # `key` is the JAX PRNG key used for sampling.
        self.key = key
        self.nb_rows, self.nb_columns = nb_rows, nb_columns

    def get_2d_array(self):
        """Sample and return an array of shape (nb_rows, nb_columns)."""
        shape = (self.nb_rows, self.nb_columns)
        return random.normal(self.key, shape)
class GaussianOrthogonalRandomMatrix(RandomMatrix):
    r"""
    Class providing a method to create Gaussian orthogonal matrix. Class is responsible for constructing 2D Gaussian
    orthogonal arrays.
    """

    def __init__(self, nb_rows, nb_columns, key, scaling=0):
        # `scaling` selects how rows are rescaled in `get_2d_array`:
        #   0 -> norms of rows of a freshly sampled Gaussian matrix,
        #   1 -> the constant sqrt(nb_columns).
        self.nb_rows = nb_rows
        self.nb_columns = nb_columns
        self.key = key
        self.scaling = scaling

    def _orthogonal_block(self, rng):
        """Return a square (nb_columns x nb_columns) block: the transposed Q factor of a Gaussian matrix."""
        unstructured_block = random.normal(rng, (self.nb_columns, self.nb_columns))
        q, _ = jnp.linalg.qr(unstructured_block)
        return jnp.transpose(q)

    def get_2d_array(self):
        """
        Build the (nb_rows, nb_columns) matrix by stacking orthogonal blocks and rescaling the rows.

        Raises:
          ValueError: if `scaling` is neither 0 nor 1.
        """
        nb_full_blocks = self.nb_rows // self.nb_columns
        block_list = []
        rng = self.key
        for _ in range(nb_full_blocks):
            rng, rng_input = jax.random.split(rng)
            block_list.append(self._orthogonal_block(rng_input))

        # A final, truncated block supplies the rows that do not fill a whole block.
        remaining_rows = self.nb_rows - nb_full_blocks * self.nb_columns
        if remaining_rows > 0:
            rng, rng_input = jax.random.split(rng)
            block_list.append(self._orthogonal_block(rng_input)[0:remaining_rows])
        final_matrix = jnp.vstack(block_list)

        if self.scaling == 0:
            # NOTE(review): this samples with `self.key`, the same key the block keys were split from.
            # Kept as-is so that matrices drawn for a given key stay unchanged.
            multiplier = jnp.linalg.norm(random.normal(self.key, (self.nb_rows, self.nb_columns)), axis=1)
        elif self.scaling == 1:
            multiplier = jnp.sqrt(float(self.nb_columns)) * jnp.ones((self.nb_rows,))
        else:
            # The attribute is `scaling`; the former `self._scaling` lookup raised AttributeError here.
            raise ValueError("Scaling must be one of {0, 1}. Was %s" % (self.scaling,))

        return jnp.matmul(jnp.diag(multiplier), final_matrix)
class FastAttention(metaclass=abc.ABCMeta):
    r"""
    Abstract class providing a method for fast attention. Class is responsible for providing a method
    <dot_product_attention> for fast approximate attention.

    The metaclass is declared with the Python 3 `metaclass=` keyword; the Python 2 style `__metaclass__` class
    attribute is ignored by Python 3, which left `@abc.abstractmethod` unenforced.
    """

    @abc.abstractmethod
    def dot_product_attention(
        self,
        query,
        key,
        value,
        dtype=jnp.float32,
        bias=None,
        axis=None,
        broadcast_dropout=True,
        dropout_rng=None,
        dropout_rate=0.0,
        deterministic=False,
        precision=None,
    ):
        """
        Computes dot-product attention given query, key, and value. This is the core function for applying fast
        approximate dot-product attention. It calculates the attention weights given query and key and combines the
        values using the attention weights. This function supports multi-dimensional inputs.

        Args:
          query: queries for calculating attention with shape of [batch_size, dim1,
            dim2, ..., dimN, num_heads, mem_channels].
          key: keys for calculating attention with shape of [batch_size, dim1, dim2,
            ..., dimN, num_heads, mem_channels].
          value: values to be used in attention with shape of [batch_size, dim1,
            dim2,..., dimN, num_heads, value_channels].
          dtype: the dtype of the computation (default: float32)
          bias: bias for the attention weights. This can be used for incorporating
            autoregressive mask, padding mask, proximity bias.
          axis: axes over which the attention is applied.
          broadcast_dropout: bool: use a broadcasted dropout along batch dims.
          dropout_rng: JAX PRNGKey: to be used for dropout.
          dropout_rate: dropout rate.
          deterministic: bool, deterministic or not (to apply dropout).
          precision: numerical precision of the computation see `jax.lax.Precision`
            for details

        Returns:
          Output of shape [bs, dim1, dim2, ..., dimN, num_heads, value_channels].
        """
        raise NotImplementedError("Abstract method")
def _numerator(z_slice_shape, precision, unroll=1):
    """
    Build the prefix-sum numerator function with a hand-written VJP.

    Args:
      z_slice_shape: shape of the running sum of key/value outer products
        (the scan carry).
      precision: precision argument forwarded to every `jnp.einsum`.
      unroll: `unroll` argument forwarded to `lax.scan`.

    Returns:
      A function `(qs, ks, vs) -> W`, scanned over the leading axis of its
      inputs, where each output slice is the query slice contracted with the
      running sum of key/value outer products up to and including that step.
    """

    def fwd(qs, ks, vs):
        def step(kv_sum, q_k_v):
            q_t, k_t, v_t = q_k_v
            # Accumulate the outer product of the current key and value.
            kv_sum = kv_sum + jnp.einsum("...m,...d->...md", k_t, v_t, precision=precision)
            out_t = jnp.einsum("...m,...md->...d", q_t, kv_sum, precision=precision)
            return kv_sum, out_t

        final_sum, outputs = lax.scan(step, jnp.zeros(z_slice_shape), (qs, ks, vs), unroll=unroll)
        # Only the final running sum is saved; the backward pass rebuilds the
        # earlier ones by subtraction.
        return outputs, (final_sum, qs, ks, vs)

    def bwd(residuals, outputs_ct):
        def step(carry, q_k_v_ct):
            kv_sum, kv_sum_ct = carry
            q_t, k_t, v_t, out_ct = q_k_v_ct
            q_ct = jnp.einsum("...d,...md->...m", out_ct, kv_sum, precision=precision)
            kv_sum_ct = kv_sum_ct + jnp.einsum("...d,...m->...md", out_ct, q_t, precision=precision)
            k_ct = jnp.einsum("...md,...d->...m", kv_sum_ct, v_t, precision=precision)
            v_ct = jnp.einsum("...md,...m->...d", kv_sum_ct, k_t, precision=precision)
            # Undo this step's contribution to recover the previous running sum.
            kv_sum = kv_sum - jnp.einsum("...m,...d->...md", k_t, v_t, precision=precision)
            return (kv_sum, kv_sum_ct), (q_ct, k_ct, v_ct)

        final_sum, qs, ks, vs = residuals
        initial_carry = (final_sum, jnp.zeros_like(final_sum))
        _, (qs_ct, ks_ct, vs_ct) = lax.scan(
            step, initial_carry, (qs, ks, vs, outputs_ct), reverse=True, unroll=unroll
        )
        return qs_ct, ks_ct, vs_ct

    @jax.custom_vjp
    def _numerator_impl(qs, ks, vs):
        outputs, _ = fwd(qs, ks, vs)
        return outputs

    _numerator_impl.defvjp(fwd, bwd)
    return _numerator_impl
def _denominator(t_slice_shape, precision, unroll=1):
    """
    Build the prefix-sum denominator function with a hand-written VJP.

    Args:
      t_slice_shape: shape of the running sum of keys (the scan carry).
      precision: precision argument forwarded to every `jnp.einsum`.
      unroll: `unroll` argument forwarded to `lax.scan`.

    Returns:
      A function `(qs, ks) -> R`, scanned over the leading axis of its inputs,
      where each output slice is the inner product of the query slice with the
      running sum of keys up to and including that step.
    """

    def fwd(qs, ks):
        def step(k_sum, q_k):
            q_t, k_t = q_k
            k_sum = k_sum + k_t
            r_t = jnp.einsum("...m,...m->...", q_t, k_sum, precision=precision)
            return k_sum, r_t

        final_sum, outputs = lax.scan(step, jnp.zeros(t_slice_shape), (qs, ks), unroll=unroll)
        # Only the final running sum is saved; the backward pass rebuilds the
        # earlier ones by subtraction.
        return outputs, (qs, ks, final_sum)

    def bwd(residuals, outputs_ct):
        def step(carry, q_k_ct):
            k_sum, k_sum_ct = carry
            q_t, k_t, r_ct = q_k_ct
            q_ct = jnp.einsum("...,...m->...m", r_ct, k_sum, precision=precision)
            k_sum_ct = k_sum_ct + jnp.einsum("...,...m->...m", r_ct, q_t, precision=precision)
            k_ct = k_sum_ct
            # Undo this step's contribution to recover the previous running sum.
            k_sum = k_sum - k_t
            return (k_sum, k_sum_ct), (q_ct, k_ct)

        qs, ks, final_sum = residuals
        initial_carry = (final_sum, jnp.zeros_like(final_sum))
        _, (qs_ct, ks_ct) = lax.scan(step, initial_carry, (qs, ks, outputs_ct), reverse=True, unroll=unroll)
        return (qs_ct, ks_ct)

    @jax.custom_vjp
    def _denominator_impl(qs, ks):
        outputs, _ = fwd(qs, ks)
        return outputs

    _denominator_impl.defvjp(fwd, bwd)
    return _denominator_impl
class FastAttentionviaLowRankDecomposition(FastAttention):
    r"""
    Class providing a method for fast attention via low rank decomposition. Class is responsible for providing a method
    <dot_product_attention> for fast dot-product attention with the use of low rank decomposition (e.g. with random
    feature maps).
    """

    def __init__(
        self,
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention,
        numerical_stabilizer,
        redraw_features,
        unidirectional,
        lax_scan_unroll=1,
    ):  # For optimal GPU performance, set to 16.
        """
        Args:
          matrix_creator: callable accepting a `key` keyword argument and returning an object with a
            `get_2d_array()` method, or None when no projection matrix is used.
          kernel_feature_creator: callable invoked as
            `(data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query)`.
          renormalize_attention: whether the result is divided by the partition function.
          numerical_stabilizer: threshold/offset used to keep the partition function away from zero.
          redraw_features: whether the projection matrix is redrawn on every attention call.
          unidirectional: whether the causal (prefix-sum) variant is used.
          lax_scan_unroll: `unroll` argument of the `lax.scan` calls of the causal variant.
        """
        rng = random.PRNGKey(0)
        # `matrix_creator` must be set before `draw_weights` is called.
        self.matrix_creator = matrix_creator
        self.projection_matrix = self.draw_weights(rng)
        self.kernel_feature_creator = kernel_feature_creator
        self.renormalize_attention = renormalize_attention
        self.numerical_stabilizer = numerical_stabilizer
        self.redraw_features = redraw_features
        self.unidirectional = unidirectional
        self.lax_scan_unroll = lax_scan_unroll

    def draw_weights(self, key):
        """Draw a projection matrix from `key`, or return None when there is no matrix creator."""
        if self.matrix_creator is None:
            return None
        matrixrng, _ = random.split(key)
        return self.matrix_creator(key=matrixrng).get_2d_array()

    def dot_product_attention(
        self,
        query,
        key,
        value,
        dtype=jnp.float32,
        bias=None,
        axis=None,
        broadcast_dropout=True,
        dropout_rng=None,
        dropout_rate=0.0,
        deterministic=False,
        precision=None,
    ):
        """
        Computes fast approximate dot-product attention from kernel features of the queries and keys.

        `dtype`, `bias`, `broadcast_dropout`, `dropout_rng`, `dropout_rate` and `deterministic` are part of the
        shared interface but are not used by this implementation.

        Args:
          query: queries, laid out as [batch, <attention dims>, num_heads, channels].
          key: keys, with the same number of dimensions and last-axis size as `query`.
          value: values, with the same shape as `key` except for the last axis.
          axis: attention axis or axes (int or iterable of ints); defaults to every axis between the batch axis
            and the last two axes.
          precision: precision argument forwarded to the dot products.

        Returns:
          The attention result in the original axis order of the inputs.

        Raises:
          ValueError: if an attention axis is not between the batch axis and the last two axes.
        """
        assert key.shape[:-1] == value.shape[:-1]
        assert query.shape[0:1] == key.shape[0:1] and query.shape[-1] == key.shape[-1]
        if axis is None:
            axis = tuple(range(1, key.ndim - 2))
        if not isinstance(axis, Iterable):
            axis = (axis,)
        # Permutations below are built by tuple concatenation, so lists (or any other iterable) are normalized.
        axis = tuple(axis)
        assert key.ndim == query.ndim
        assert key.ndim == value.ndim
        for ax in axis:
            if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
                raise ValueError("Attention axis must be between the batch " "axis and the last-two axes.")
        n = key.ndim

        # Constructing projection tensor.
        if self.redraw_features:
            # The seed is derived from the sum of the queries.
            # TODO(kchoro): Get rid of the constant below.
            query_seed = lax.convert_element_type(jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32)
            rng = random.PRNGKey(query_seed)
            self.projection_matrix = self.draw_weights(rng)

        # batch_dims is <bs, <non-attention dims>, num_heads>
        batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))
        # q, k & v -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
        qk_perm = batch_dims + axis + (n - 1,)
        key = key.transpose(qk_perm)
        query = query.transpose(qk_perm)
        value = value.transpose(qk_perm)

        nb_batch = len(batch_dims)
        batch_dims_t = tuple(range(nb_batch))
        attention_dims_t = tuple(range(nb_batch, nb_batch + len(axis)))

        # Constructing tensors Q^{'} and K^{'}.
        query_prime = self.kernel_feature_creator(
            query, self.projection_matrix, attention_dims_t, batch_dims_t, precision, True
        )
        key_prime = self.kernel_feature_creator(
            key, self.projection_matrix, attention_dims_t, batch_dims_t, precision, False
        )

        if self.unidirectional:
            # Constructing W = (Q^{'}(K^{'})^{T})_{masked}V by scanning over the first attention dimension.
            index = attention_dims_t[0]
            z_slice_shape = key_prime.shape[0:nb_batch] + (key_prime.shape[-1],) + (value.shape[-1],)
            numerator_fn = _numerator(z_slice_shape, precision, self.lax_scan_unroll)
            W = numerator_fn(
                jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0), jnp.moveaxis(value, index, 0)
            )
            W = jnp.moveaxis(W, 0, index)
        else:
            contract_query = (nb_batch + len(axis),)
            contract_z = (nb_batch,)
            # Constructing Z = (K^{'})^{T}V
            # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
            Z = lax.dot_general(
                key_prime,
                value,
                ((attention_dims_t, attention_dims_t), (batch_dims_t, batch_dims_t)),
                precision=precision,
            )
            # Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V
            # q (bs, <non-attention dims>, num_heads, <attention dims>, channels_m)
            # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
            # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
            W = lax.dot_general(
                query_prime, Z, ((contract_query, contract_z), (batch_dims_t, batch_dims_t)), precision=precision
            )

        perm_inv = _invert_perm(qk_perm)
        if not self.renormalize_attention:
            # Not-normalized attention (either direction).
            return W.transpose(perm_inv)

        if self.unidirectional:
            # Unidirectional, normalized attention: partition function via a prefix sum over the keys.
            index = attention_dims_t[0]
            t_slice_shape = key_prime.shape[0:nb_batch] + (key_prime.shape[-1],)
            denominator_fn = _denominator(t_slice_shape, precision, self.lax_scan_unroll)
            R = denominator_fn(jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0))
            R = jnp.moveaxis(R, 0, index)
        else:
            # Bidirectional, normalized attention.
            # All-ones tensor of shape (bs, <non-attention dims>, num_heads, <attention dims>).
            thick_all_ones = jnp.ones(key.shape[0:-1])
            contract_key = attention_dims_t
            contract_thick_all_ones = tuple(range(thick_all_ones.ndim - len(axis), thick_all_ones.ndim))
            # Construct T = (K^{'})^{T} 1_L
            # k (bs, <non-attention dims>, num_heads, <attention dims>, channels)
            T = lax.dot_general(
                key_prime,
                thick_all_ones,
                ((contract_key, contract_thick_all_ones), (batch_dims_t, batch_dims_t)),
                precision=precision,
            )
            # Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L
            # q_p (bs, <non-attention dims>, num_heads, <attention dims>, channs_m)
            # T (bs, <non-attention dims>, num_heads, channels_m)
            R = lax.dot_general(
                query_prime,
                T,
                (((query_prime.ndim - 1,), (T.ndim - 1,)), (batch_dims_t, tuple(range(T.ndim - 1)))),
                precision=precision,
            )

        # Push entries whose magnitude is at most the stabilizer away from zero before inverting.
        R = R + 2 * self.numerical_stabilizer * (jnp.abs(R) <= self.numerical_stabilizer)
        R = jnp.reciprocal(R)
        R = jnp.expand_dims(R, len(R.shape))
        # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
        # R (bs, <non-attention dims>, num_heads, <attention dims>, extra_channel)
        result = W * R
        # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
        return result.transpose(perm_inv)
def _invert_perm(perm):
    """Return, as a tuple, the permutation that undoes `perm`."""
    inverse = [0] * len(perm)
    # Entry `destination` of the inverse holds the position it occupies in `perm`.
    for destination, source in zip(perm, range(len(perm))):
        inverse[destination] = source
    return tuple(inverse)

Event Timeline