# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
IMPORTANT:
This code was copied from
https://github.com/google-research/google-research/blob/master/performer/fast_self_attention/fast_self_attention.py on
6/11/2020. This is very new code, so it is likely to change soon; make sure to check the original code and
update accordingly.
Core Fast Attention Module for Flax. Implementation of the approximate fast softmax and generalized attention mechanism
leveraging structured random feature maps [RFM] techniques and low rank decomposition of the attention matrix.
"""
# pylint: disable=invalid-name, missing-function-docstring, line-too-long
import abc
import functools
from collections.abc import Iterable  # pylint: disable=g-importing-member

import numpy as onp
from absl import logging

import jax
import jax.numpy as jnp
from jax import lax, random


def nonnegative_softmax_kernel_feature_creator(
    data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True, eps=0.0001
):
"""
Constructs nonnegative kernel features for fast softmax attention
Args:
data: input for which features are computes
projection_matrix: random matrix used to compute features
attention_dims_t: tuple of attention dimensions
batch_dims_t: tuple of batch dimensions
precision: precision parameter
is_query: predicate indicating whether input data corresponds to queries or
keys
normalize_data: predicate indicating whether data should be normalized,
eps: numerical stabilizer
Returns:
Random features for fast softmax attention.
"""
    del attention_dims_t
    if normalize_data:
        # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
        # w_norm = w * data_normalizer for w in {q,k}.
        data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
    else:
        data_normalizer = 1.0
    ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
    data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
    data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix

    data_dash = lax.dot_general(
        data_normalizer * data,
        data_thick_random_matrix,
        (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)),
        precision=precision,
    )

    diag_data = jnp.square(data)
    diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
    diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
    diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)

    if is_query:
        last_dims_t = (len(data_dash.shape) - 1,)
        data_dash = ratio * (
            jnp.exp(data_dash - diag_data - jnp.max(data_dash, axis=last_dims_t, keepdims=True)) + eps
        )
    else:
        data_dash = ratio * (jnp.exp(data_dash - diag_data - jnp.max(data_dash)) + eps)

    return data_dash
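

# Illustrative usage sketch (not part of the original fast_self_attention code).
# It shows how the feature creator above is typically called once the inputs have
# been transposed to (batch, heads, length, channels); the sizes and the (m, d)
# projection shape below are assumptions chosen for the example only.
def _example_nonnegative_features():
    d, m = 16, 128
    rng_x, rng_w = random.split(random.PRNGKey(0))
    data = random.normal(rng_x, (1, 2, 4, d))  # (bs, heads, length, channels)
    projection = random.normal(rng_w, (m, d))  # (nb_features, channels)
    # batch_dims_t covers (bs, heads); the attention (length) axis is 2.
    return nonnegative_softmax_kernel_feature_creator(
        data, projection, attention_dims_t=(2,), batch_dims_t=(0, 1), precision=None, is_query=True
    )  # -> strictly positive features of shape (1, 2, 4, m)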


def sincos_softmax_kernel_feature_creator(
    data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data=True
):
"""
Constructs kernel sin-cos features for fast softmax attention
Args:
data: input for which features are computes
projection_matrix: random matrix used to compute features
attention_dims_t: tuple of attention dimensions
batch_dims_t: tuple of batch dimensions
precision: precision parameter
normalize_data: predicate indicating whether data should be normalized
Returns:
Random features for fast softmax attention.
"""
    if normalize_data:
        # We have: exp(qk^T/sqrt{d}) = exp(|q|^2/2sqrt{d}) * exp(|k|^2/2sqrt{d}) *
        # exp(-(|q*c-k*c|^2)/2), where c = 1.0 / sqrt{sqrt{d}}.
        data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
    else:
        data_normalizer = 1.0
    ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
    data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
    data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix

    data_dash = lax.dot_general(
        data_normalizer * data,
        data_thick_random_matrix,
        (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)),
        precision=precision,
    )
    data_dash_cos = ratio * jnp.cos(data_dash)
    data_dash_sin = ratio * jnp.sin(data_dash)
    data_dash = jnp.concatenate((data_dash_cos, data_dash_sin), axis=-1)

    # Constructing D_data and data^{'}
    diag_data = jnp.square(data)
    diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
    diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
    diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)
    # Additional renormalization for numerical stability
    data_renormalizer = jnp.max(diag_data, attention_dims_t, keepdims=True)
    diag_data -= data_renormalizer
    diag_data = jnp.exp(diag_data)
    data_prime = data_dash * diag_data
    return data_prime


def generalized_kernel_feature_creator(
    data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data
):
"""
Constructs kernel features for fast generalized attention
Args:
data: input for which features are computes
projection_matrix: matrix used to compute features
batch_dims_t: tuple of batch dimensions
precision: precision parameter
kernel_fn: kernel function used
kernel_epsilon: additive positive term added to every feature for numerical
stability
normalize_data: predicate indicating whether data should be normalized
Returns:
Random features for fast generalized attention.
"""
    if normalize_data:
        data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
    else:
        data_normalizer = 1.0
    if projection_matrix is None:
        return kernel_fn(data_normalizer * data) + kernel_epsilon
    else:
        data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
        data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix

        data_dash = lax.dot_general(
            data_normalizer * data,
            data_thick_random_matrix,
            (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)),
            precision=precision,
        )
    data_prime = kernel_fn(data_dash) + kernel_epsilon
    return data_prime


def make_fast_softmax_attention(
    qkv_dim,
    renormalize_attention=True,
    numerical_stabilizer=0.000001,
    nb_features=256,
    ortho_features=True,
    ortho_scaling=0.0,
    redraw_features=True,
    unidirectional=False,
    nonnegative_features=True,
    lax_scan_unroll=1,
):
    """Construct a fast softmax attention method."""
    logging.info(
        "Fast softmax attention: %s features and orthogonal=%s, renormalize=%s",
        nb_features,
        ortho_features,
        renormalize_attention,
    )
    if ortho_features:
        matrix_creator = functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=ortho_scaling)
    else:
        matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim)

    if nonnegative_features:

        def kernel_feature_creator(
            data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True
        ):
            return nonnegative_softmax_kernel_feature_creator(
                data,
                projection_matrix,
                attention_dims_t,
                batch_dims_t,
                precision,
                is_query,
                normalize_data,
                numerical_stabilizer,
            )

    else:

        def kernel_feature_creator(
            data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True
        ):
            del is_query
            return sincos_softmax_kernel_feature_creator(
                data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data
            )

    attention_fn = FastAttentionviaLowRankDecomposition(
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention=renormalize_attention,
        numerical_stabilizer=numerical_stabilizer,
        redraw_features=redraw_features,
        unidirectional=unidirectional,
        lax_scan_unroll=lax_scan_unroll,
    ).dot_product_attention
    return attention_fn
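

# Minimal usage sketch (not from the original source; shapes and sizes below are
# assumptions). The factory above returns a callable with the same signature as a
# standard Flax dot-product attention function, so it can be passed wherever such
# an attention_fn is expected.
def _example_fast_softmax_attention():
    attention_fn = make_fast_softmax_attention(qkv_dim=64, nb_features=256, unidirectional=False)
    rng_q, rng_k, rng_v = random.split(random.PRNGKey(0), 3)
    q = random.normal(rng_q, (2, 16, 4, 64))  # (batch, length, heads, qkv_dim)
    k = random.normal(rng_k, (2, 16, 4, 64))
    v = random.normal(rng_v, (2, 16, 4, 64))
    return attention_fn(q, k, v)  # approximate softmax attention, shape (2, 16, 4, 64)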


def make_fast_generalized_attention(
    qkv_dim,
    renormalize_attention=True,
    numerical_stabilizer=0.0,
    nb_features=256,
    features_type="deterministic",
    kernel_fn=jax.nn.relu,
    kernel_epsilon=0.001,
    redraw_features=False,
    unidirectional=False,
    lax_scan_unroll=1,
):
"""Construct a fast generalized attention menthod."""
    logging.info("Fast generalized attention.: %s features and renormalize=%s", nb_features, renormalize_attention)
    if features_type == "ortho":
        matrix_creator = functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=False)
    elif features_type == "iid":
        matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim)
    elif features_type == "deterministic":
        matrix_creator = None
    else:
        raise ValueError("Unknown feature value type")

    def kernel_feature_creator(
        data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=False
    ):
        del attention_dims_t
        del is_query
        return generalized_kernel_feature_creator(
            data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data
        )

    attention_fn = FastAttentionviaLowRankDecomposition(
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention=renormalize_attention,
        numerical_stabilizer=numerical_stabilizer,
        redraw_features=redraw_features,
        unidirectional=unidirectional,
        lax_scan_unroll=lax_scan_unroll,
    ).dot_product_attention
    return attention_fn
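

# Minimal usage sketch (not from the original source; the kernel choice and shapes
# are assumptions). With features_type="deterministic" no random projection is
# drawn and the kernel function is applied to the data directly.
def _example_fast_generalized_attention():
    attention_fn = make_fast_generalized_attention(
        qkv_dim=64, features_type="deterministic", kernel_fn=jax.nn.relu, renormalize_attention=True
    )
    rng_q, rng_k, rng_v = random.split(random.PRNGKey(1), 3)
    q = random.normal(rng_q, (2, 16, 4, 64))  # (batch, length, heads, qkv_dim)
    k = random.normal(rng_k, (2, 16, 4, 64))
    v = random.normal(rng_v, (2, 16, 4, 64))
    return attention_fn(q, k, v)  # generalized (ReLU-kernel) attention, shape (2, 16, 4, 64)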


class RandomMatrix(object):
    r"""
    Abstract class providing a method for constructing 2D random arrays. Class is responsible for constructing 2D
    random arrays.
    """

    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def get_2d_array(self):
        raise NotImplementedError("Abstract method")


class GaussianUnstructuredRandomMatrix(RandomMatrix):
    def __init__(self, nb_rows, nb_columns, key):
        self.nb_rows = nb_rows
        self.nb_columns = nb_columns
        self.key = key

    def get_2d_array(self):
        return random.normal(self.key, (self.nb_rows, self.nb_columns))


class GaussianOrthogonalRandomMatrix(RandomMatrix):
    r"""
    Class providing a method to create Gaussian orthogonal matrix. Class is responsible for constructing 2D Gaussian
    orthogonal arrays.
    """

    def __init__(self, nb_rows, nb_columns, key, scaling=0):
        self.nb_rows = nb_rows
        self.nb_columns = nb_columns
        self.key = key
        self.scaling = scaling

    def get_2d_array(self):
        nb_full_blocks = int(self.nb_rows / self.nb_columns)
        block_list = []
        rng = self.key
        for _ in range(nb_full_blocks):
            rng, rng_input = jax.random.split(rng)
            unstructured_block = random.normal(rng_input, (self.nb_columns, self.nb_columns))
            q, _ = jnp.linalg.qr(unstructured_block)
            q = jnp.transpose(q)
            block_list.append(q)
        remaining_rows = self.nb_rows - nb_full_blocks * self.nb_columns
        if remaining_rows > 0:
            rng, rng_input = jax.random.split(rng)
            unstructured_block = random.normal(rng_input, (self.nb_columns, self.nb_columns))
            q, _ = jnp.linalg.qr(unstructured_block)
            q = jnp.transpose(q)
            block_list.append(q[0:remaining_rows])
        final_matrix = jnp.vstack(block_list)

        if self.scaling == 0:
            multiplier = jnp.linalg.norm(random.normal(self.key, (self.nb_rows, self.nb_columns)), axis=1)
        elif self.scaling == 1:
            multiplier = jnp.sqrt(float(self.nb_columns)) * jnp.ones((self.nb_rows))
        else:
            raise ValueError("Scaling must be one of {0, 1}. Was %s" % self.scaling)

        return jnp.matmul(jnp.diag(multiplier), final_matrix)
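

# Illustrative check (not part of the original code; sizes are assumptions): each
# nb_columns x nb_columns block of the matrix returned by get_2d_array has
# orthogonal rows before the per-row rescaling is applied, which is the property
# the Performer feature maps rely on to reduce estimator variance.
def _example_orthogonal_matrix():
    sampler = GaussianOrthogonalRandomMatrix(nb_rows=128, nb_columns=64, key=random.PRNGKey(0), scaling=0)
    matrix = sampler.get_2d_array()  # shape (128, 64)
    # Undo the row rescaling on the first block and verify its rows are orthonormal.
    block = matrix[:64] / jnp.linalg.norm(matrix[:64], axis=1, keepdims=True)
    return jnp.allclose(block @ block.T, jnp.eye(64), atol=1e-4)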


class FastAttention(object):
    r"""
    Abstract class providing a method for fast attention. Class is responsible for providing a method
    <dot_product_attention> for fast approximate attention.
    """

    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def dot_product_attention(
        self,
        query,
        key,
        value,
        dtype=jnp.float32,
        bias=None,
        axis=None,
        broadcast_dropout=True,
        dropout_rng=None,
        dropout_rate=0.0,
        deterministic=False,
        precision=None,
    ):
"""
Computes dot-product attention given query, key, and value. This is the core function for applying fast
approximate dot-product attention. It calculates the attention weights given query and key and combines the
values using the attention weights. This function supports multi-dimensional inputs
Args:
query: queries for calculating attention with shape of [batch_size, dim1,
dim2, ..., dimN, num_heads, mem_channels].
key: keys for calculating attention with shape of [batch_size, dim1, dim2,
..., dimN, num_heads, mem_channels].
value: values to be used in attention with shape of [batch_size, dim1,
dim2,..., dimN, num_heads, value_channels].
dtype: the dtype of the computation (default: float32)
bias: bias for the attention weights. This can be used for incorporating
autoregressive mask, padding mask, proximity bias.
axis: axises over which the attention is applied.
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey: to be used for dropout.
dropout_rate: dropout rate.
deterministic: bool, deterministic or not (to apply dropout).
precision: numerical precision of the computation see `jax.lax.Precision`
for details
Returns:
Output of shape [bs, dim1, dim2, ..., dimN,, num_heads, value_channels].
"""
        raise NotImplementedError("Abstract method")


def _numerator(z_slice_shape, precision, unroll=1):
    def fwd(qs, ks, vs):
        def body(p, qkv):
            (q, k, v) = qkv
            p += jnp.einsum("...m,...d->...md", k, v, precision=precision)
            X_slice = jnp.einsum("...m,...md->...d", q, p, precision=precision)
            return p, X_slice

        init_value = jnp.zeros(z_slice_shape)
        p, W = lax.scan(body, init_value, (qs, ks, vs), unroll=unroll)
        return W, (p, qs, ks, vs)

    def bwd(pqkv, W_ct):
        def body(carry, qkv_xct):
            p, p_ct = carry
            q, k, v, x_ct = qkv_xct
            q_ct = jnp.einsum("...d,...md->...m", x_ct, p, precision=precision)
            p_ct += jnp.einsum("...d,...m->...md", x_ct, q, precision=precision)
            k_ct = jnp.einsum("...md,...d->...m", p_ct, v, precision=precision)
            v_ct = jnp.einsum("...md,...m->...d", p_ct, k, precision=precision)
            p -= jnp.einsum("...m,...d->...md", k, v, precision=precision)
            return (p, p_ct), (q_ct, k_ct, v_ct)

        p, qs, ks, vs = pqkv
        _, (qs_ct, ks_ct, vs_ct) = lax.scan(
            body, (p, jnp.zeros_like(p)), (qs, ks, vs, W_ct), reverse=True, unroll=unroll
        )
        return qs_ct, ks_ct, vs_ct

    @jax.custom_vjp
    def _numerator_impl(qs, ks, vs):
        W, _ = fwd(qs, ks, vs)
        return W

    _numerator_impl.defvjp(fwd, bwd)

    return _numerator_impl
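

# Hedged illustration (not from the original source; shapes are assumptions). The
# helper above implements the causal ("masked") numerator of FAVOR+ as a prefix-sum
# scan with a memory-efficient custom VJP, so the sequence axis must be moved to the
# front before calling the returned function.
def _example_numerator():
    length, batch, m, d = 16, 2, 8, 4
    rng_q, rng_k, rng_v = random.split(random.PRNGKey(0), 3)
    qs = random.normal(rng_q, (length, batch, m))  # Q' features, sequence-major
    ks = random.normal(rng_k, (length, batch, m))  # K' features, sequence-major
    vs = random.normal(rng_v, (length, batch, d))  # values, sequence-major
    numerator_fn = _numerator(z_slice_shape=(batch, m, d), precision=None, unroll=1)
    return numerator_fn(qs, ks, vs)  # causal numerator, shape (length, batch, d)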


def _denominator(t_slice_shape, precision, unroll=1):
    def fwd(qs, ks):
        def body(p, qk):
            q, k = qk
            p += k
            x = jnp.einsum("...m,...m->...", q, p, precision=precision)
            return p, x

        p = jnp.zeros(t_slice_shape)
        p, R = lax.scan(body, p, (qs, ks), unroll=unroll)
        return R, (qs, ks, p)

    def bwd(qkp, R_ct):
        def body(carry, qkx):
            p, p_ct = carry
            q, k, x_ct = qkx
            q_ct = jnp.einsum("...,...m->...m", x_ct, p, precision=precision)
            p_ct += jnp.einsum("...,...m->...m", x_ct, q, precision=precision)
            k_ct = p_ct
            p -= k
            return (p, p_ct), (q_ct, k_ct)

        qs, ks, p = qkp
        _, (qs_ct, ks_ct) = lax.scan(body, (p, jnp.zeros_like(p)), (qs, ks, R_ct), reverse=True, unroll=unroll)
        return (qs_ct, ks_ct)

    @jax.custom_vjp
    def _denominator_impl(qs, ks):
        R, _ = fwd(qs, ks)
        return R

    _denominator_impl.defvjp(fwd, bwd)

    return _denominator_impl


class FastAttentionviaLowRankDecomposition(FastAttention):
    r"""
    Class providing a method for fast attention via low rank decomposition. Class is responsible for providing a method
    <dot_product_attention> for fast dot-product attention with the use of low rank decomposition (e.g. with random
    feature maps).
    """

    def __init__(
        self,
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention,
        numerical_stabilizer,
        redraw_features,
        unidirectional,
        lax_scan_unroll=1,
    ):  # For optimal GPU performance, set to 16.
        rng = random.PRNGKey(0)
        self.matrix_creator = matrix_creator
        self.projection_matrix = self.draw_weights(rng)
        self.kernel_feature_creator = kernel_feature_creator
        self.renormalize_attention = renormalize_attention
        self.numerical_stabilizer = numerical_stabilizer
        self.redraw_features = redraw_features
        self.unidirectional = unidirectional
        self.lax_scan_unroll = lax_scan_unroll

    def draw_weights(self, key):
        if self.matrix_creator is None:
            return None
        matrixrng, _ = random.split(key)
        projection_matrix = self.matrix_creator(key=matrixrng).get_2d_array()
        return projection_matrix

    def dot_product_attention(
        self,
        query,
        key,
        value,
        dtype=jnp.float32,
        bias=None,
        axis=None,
        broadcast_dropout=True,
        dropout_rng=None,
        dropout_rate=0.0,
        deterministic=False,
        precision=None,
    ):
        assert key.shape[:-1] == value.shape[:-1]
        assert query.shape[0:1] == key.shape[0:1] and query.shape[-1] == key.shape[-1]
        if axis is None:
            axis = tuple(range(1, key.ndim - 2))
        if not isinstance(axis, Iterable):
            axis = (axis,)
        assert key.ndim == query.ndim
        assert key.ndim == value.ndim
        for ax in axis:
            if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
                raise ValueError("Attention axis must be between the batch axis and the last-two axes.")
        n = key.ndim

        # Constructing projection tensor.
        if self.redraw_features:
            # TODO(kchoro): Get rid of the constant below.
            query_seed = lax.convert_element_type(jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32)
            rng = random.PRNGKey(query_seed)
            self.projection_matrix = self.draw_weights(rng)

        # batch_dims is <bs, <non-attention dims>, num_heads>
        batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))
        # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
        qk_perm = batch_dims + axis + (n - 1,)
        k_extra_perm = axis + batch_dims + (n - 1,)
        key_extra = key.transpose(k_extra_perm)
        key = key.transpose(qk_perm)
        query = query.transpose(qk_perm)
        # v -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
        v_perm = batch_dims + axis + (n - 1,)
        value = value.transpose(v_perm)
        batch_dims_t = tuple(range(len(batch_dims)))
        attention_dims_t = tuple(range(len(batch_dims), len(batch_dims) + len(axis)))

        # Constructing tensors Q^{'} and K^{'}.
        query_prime = self.kernel_feature_creator(
            query, self.projection_matrix, attention_dims_t, batch_dims_t, precision, True
        )
        key_prime = self.kernel_feature_creator(
            key, self.projection_matrix, attention_dims_t, batch_dims_t, precision, False
        )

        if self.unidirectional:
            index = attention_dims_t[0]
            z_slice_shape = key_prime.shape[0 : len(batch_dims_t)] + (key_prime.shape[-1],) + (value.shape[-1],)

            numerator_fn = _numerator(z_slice_shape, precision, self.lax_scan_unroll)
            W = numerator_fn(
                jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0), jnp.moveaxis(value, index, 0)
            )

            # Constructing W = (Q^{'}(K^{'})^{T})_{masked}V
            W = jnp.moveaxis(W, 0, index)

            if not self.renormalize_attention:
                # Unidirectional, not-normalized attention.
                perm_inv = _invert_perm(qk_perm)
                result = W.transpose(perm_inv)
                return result
            else:
                # Unidirectional, normalized attention.
                thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(key_extra.shape[0 : len(axis)])

                index = attention_dims_t[0]
                t_slice_shape = key_prime.shape[0 : len(batch_dims_t)] + (key_prime.shape[-1],)
                denominator_fn = _denominator(t_slice_shape, precision, self.lax_scan_unroll)
                R = denominator_fn(jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0))

                R = jnp.moveaxis(R, 0, index)
        else:
            contract_query = tuple(range(len(batch_dims) + len(axis), len(batch_dims) + len(axis) + 1))
            contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1))
            # Constructing Z = (K^{'})^{T}V
            # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
            Z = lax.dot_general(
                key_prime,
                value,
                ((attention_dims_t, attention_dims_t), (batch_dims_t, batch_dims_t)),
                precision=precision,
            )
            # Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V
            # q (bs, <non-attention dims>, num_heads, <attention dims>, channels_m)
            # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
            # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
            W = lax.dot_general(
                query_prime, Z, ((contract_query, contract_z), (batch_dims_t, batch_dims_t)), precision=precision
            )
            if not self.renormalize_attention:
                # Bidirectional, not-normalized attention.
                perm_inv = _invert_perm(qk_perm)
                result = W.transpose(perm_inv)
                return result
            else:
                # Bidirectional, normalized attention.
                thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(key_extra.shape[0 : len(axis)])
                contract_key = tuple(range(len(batch_dims), len(batch_dims) + len(axis)))
                contract_thick_all_ones = tuple(range(thick_all_ones.ndim - len(axis), thick_all_ones.ndim))
                # Construct T = (K^{'})^{T} 1_L
                # k (bs, <non-attention dims>, num_heads, <attention dims>, channels)
                T = lax.dot_general(
                    key_prime,
                    thick_all_ones,
                    ((contract_key, contract_thick_all_ones), (batch_dims_t, batch_dims_t)),
                    precision=precision,
                )

                # Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L
                # q_p (bs, <non-attention dims>, num_heads, <attention dims>, channs_m)
                # T   (bs, <non-attention dims>, num_heads, channels_m)
                R = lax.dot_general(
                    query_prime,
                    T,
                    (((query_prime.ndim - 1,), (T.ndim - 1,)), (batch_dims_t, range(0, len(T.shape) - 1))),
                    precision=precision,
                )

        R = R + 2 * self.numerical_stabilizer * (jnp.abs(R) <= self.numerical_stabilizer)
        R = jnp.reciprocal(R)
        R = jnp.expand_dims(R, len(R.shape))
        # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
        # R (bs, <non-attention dims>, num_heads, <attention dims>, extra_channel)
        result = W * R
        # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
        perm_inv = _invert_perm(qk_perm)
        result = result.transpose(perm_inv)
        return result


def _invert_perm(perm):
    perm_inv = [0] * len(perm)
    for i, j in enumerate(perm):
        perm_inv[j] = i
    return tuple(perm_inv)
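

# Optional smoke test (a hedged sketch, not in the original file). It compares the
# FAVOR+ approximation against exact softmax attention on a tiny input; with 256
# random features the two should agree to within a modest tolerance, but the error
# is stochastic, so only the mean gap is printed rather than asserted.
if __name__ == "__main__":
    batch, length, heads, dim = 1, 8, 2, 32
    rng_q, rng_k, rng_v = random.split(random.PRNGKey(0), 3)
    q = random.normal(rng_q, (batch, length, heads, dim))
    k = random.normal(rng_k, (batch, length, heads, dim))
    v = random.normal(rng_v, (batch, length, heads, dim))

    # Exact softmax attention for reference.
    scores = jnp.einsum("blhd,bmhd->bhlm", q, k) / jnp.sqrt(dim)
    exact = jnp.einsum("bhlm,bmhd->blhd", jax.nn.softmax(scores, axis=-1), v)

    # Performer approximation via the factory defined above.
    fast_attention_fn = make_fast_softmax_attention(qkv_dim=dim, nb_features=256)
    approx = fast_attention_fn(q, k, v)

    print("mean |exact - approx| =", jnp.mean(jnp.abs(exact - approx)))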