# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch
import torch.nn as nn
from modelzoo.common.pytorch.model_utils.create_initializer import (
create_initializer,
)
from .AttentionLayer import MultiheadAttention
class MultiQueryAttention(MultiheadAttention):
    """Implements the Multi-Query Attention Layer from
    `Fast Transformer Decoding: One Write-Head is All You Need
    <https://arxiv.org/abs/1911.02150>`

    The key and value projections created by this class each produce
    ``num_kv_groups * head_dim`` features (``head_dim = inner_dim //
    num_heads``). When ``num_kv_groups > 1`` every key/value group is
    broadcast over ``num_heads // num_kv_groups`` heads before the parent
    attention computation is invoked.

    Args:
        embed_dim (int): Number of input units in each projection output
        num_heads (int): Number of attention heads.
        inner_dim (int): Number of output units in attention query/key/value
            projection. Defaults to ``embed_dim``.
        dropout (float): Dropout rate for key-query weights. Defaults to 0.0.
        batch_first (bool): If True, then the input and output tensors are
            provided as (batch, seq, feature), otherwise the format will be
            (seq, batch, feature). Default: True (batch, seq, feature).
        add_bias_kv (bool): If specified, adds bias to the key and value
            sequences at dim=0. Default: False.
        add_zero_attn (bool): If specified, adds a new batch of zeros to the
            key and value sequences at dim=1. Default: False
        kdim (int): Forwarded to ``MultiheadAttention``. This class uses the
            resulting ``self.kdim`` as the input feature size of the key
            projection.
        vdim (int): Forwarded to ``MultiheadAttention``. This class uses the
            resulting ``self.vdim`` as the input feature size of the value
            projection.
        use_projection_bias (bool): Whether to use bias in the key, query, and
            value projections.
        use_ffn_bias (bool): Whether to use bias in the output projection.
        attention_initializer (str): Projection kernel initializer. Defaults to
            ``xavier_uniform``.
        attention_q_initializer: Query projection kernel initializer. If not
            specified, the query will be initialized via
            ``attention_initializer``
        output_layer_initializer (str or initializer): If not None, use this
            initializer for the output transform layer. Defaults to None.
        bias_initializer (str): Bias initializer. Defaults to ``zeros``.
        attention_type (str): The attention variant to execute. Currently
            accepts ``dot_product`` and ``scaled_dot_product``. Defaults to
            ``scaled_dot_product``.
        scale_qk_dot_by_d: Forwarded unchanged to ``MultiheadAttention``.
            Defaults to False.
        softmax_dtype_fp32 (bool): Use an FP32 softmax implementation.
        scale_qk_dot_by_layer_idx: Forwarded unchanged to
            ``MultiheadAttention``. Defaults to False.
        device (optional): Device to create the model parameters on, can be a
            cuda device or CS device.
        num_kv_groups (int): Number of key/value groups. Must be at least 1
            and must divide ``num_heads``. Defaults to 1.

    Raises:
        ValueError: If ``num_kv_groups`` is smaller than 1.
        AssertionError: If ``num_heads`` is not a multiple of
            ``num_kv_groups``.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        inner_dim=None,
        dropout=0.0,
        batch_first=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        use_projection_bias=None,
        use_ffn_bias=False,
        attention_initializer="xavier_uniform",
        attention_q_initializer=None,
        output_layer_initializer=None,
        bias_initializer="zeros",
        attention_type="scaled_dot_product",
        scale_qk_dot_by_d=False,
        softmax_dtype_fp32=True,
        scale_qk_dot_by_layer_idx=False,
        device=None,
        # MQA specific
        num_kv_groups=1,
    ):
        super(MultiQueryAttention, self).__init__(
            embed_dim=embed_dim,
            num_heads=num_heads,
            inner_dim=inner_dim,
            dropout=dropout,
            batch_first=batch_first,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            vdim=vdim,
            kdim=kdim,
            use_projection_bias=use_projection_bias,
            use_ffn_bias=use_ffn_bias,
            attention_initializer=attention_initializer,
            attention_q_initializer=attention_q_initializer,
            output_layer_initializer=output_layer_initializer,
            bias_initializer=bias_initializer,
            attention_type=attention_type,
            scale_qk_dot_by_d=scale_qk_dot_by_d,
            softmax_dtype_fp32=softmax_dtype_fp32,
            scale_qk_dot_by_layer_idx=scale_qk_dot_by_layer_idx,
            device=device,
        )

        self.head_dim = self.inner_dim // self.num_heads
        self.num_kv_groups = num_kv_groups

        # Validate before deriving any per-group size: a zero value would
        # otherwise surface as a bare ZeroDivisionError, and a negative value
        # passes the modulo check below and only fails later inside nn.Linear.
        if self.num_kv_groups < 1:
            raise ValueError(
                f"num_kv_groups must be a positive integer but got "
                f"{self.num_kv_groups}"
            )
        # Kept as an assert (same message) so callers still see an
        # AssertionError for a non-multiple configuration.
        assert (
            self.num_heads % self.num_kv_groups == 0
        ), f"num_heads has to be a multiple of num_kv_groups but got {self.num_heads} and {self.num_kv_groups}"
        self.per_group_num_heads = self.num_heads // self.num_kv_groups

        # Replace the key/value projections with ones that emit only
        # num_kv_groups heads (a single shared head when num_kv_groups == 1).
        self.proj_k_dense_layer = nn.Linear(
            self.kdim,
            self.num_kv_groups * self.head_dim,
            bias=use_projection_bias,
            device=device,
        )
        self.proj_v_dense_layer = nn.Linear(
            self.vdim,
            self.num_kv_groups * self.head_dim,
            bias=use_projection_bias,
            device=device,
        )

        # reset newly initialized parameters
        self.__reset_parameters()

    def reset_parameters(self):
        """Re-initialize the parent's parameters, then the K/V projections
        owned by this class."""
        super().reset_parameters()
        self.__reset_parameters()

    def __reset_parameters(self):
        """Initialize the weights (and biases, if enabled) of the key and
        value projections created in ``__init__``.

        The double underscore name-mangles this method to this class, so it
        does not override a same-named private method on the parent.
        """
        # bias initialization
        bias_initializer = create_initializer(self.bias_initializer)
        if self.use_projection_bias:
            bias_initializer(self.proj_k_dense_layer.bias.data)
            bias_initializer(self.proj_v_dense_layer.bias.data)

        # k, v projections
        weight_initializer = create_initializer(self.initializer)
        weight_initializer(self.proj_k_dense_layer.weight.data)
        weight_initializer(self.proj_v_dense_layer.weight.data)

    def construct_key_vector(self, k, attn_mask=None, key_padding_mask=None):
        """Project ``k`` and split the result into key/value groups.

        Args:
            k: Key input; unpacked below as a 3-D tensor
                ``[batch_size, seq_length, features]``.
            attn_mask: Not used by this method.
            key_padding_mask: Not used by this method.

        Returns:
            Tensor of shape
            ``[batch_size, seq_length, num_kv_groups, head_dim]``. Note that
            the group axis is at dim 2 here; it is NOT moved to dim 1.
        """
        # linear projection
        k = self.proj_k_dense_layer(
            k
        )  # [batch_size, seq_length, self.num_kv_groups * self.head_dim]

        if self.num_kv_groups == 1:
            # A single group only needs a singleton group axis.
            return torch.unsqueeze(
                k, 2
            )  # [batch_size, seq_length, 1, kv_channels]

        batch_size, seq_length, _ = k.shape
        # [batch_size, seq_length, self.num_kv_groups, self.head_dim]
        k = k.reshape(batch_size, seq_length, self.num_kv_groups, self.head_dim)
        return k

    def construct_value_vector(self, v, attn_mask=None, key_padding_mask=None):
        """Project ``v`` and split the result into key/value groups.

        Args:
            v: Value input; unpacked below as a 3-D tensor
                ``[batch_size, seq_length, features]``.
            attn_mask: Not used by this method.
            key_padding_mask: Not used by this method.

        Returns:
            Tensor of shape
            ``[batch_size, num_kv_groups, seq_length, head_dim]``.
        """
        # linear projection
        v = self.proj_v_dense_layer(
            v
        )  # [batch_size, seq_length, self.num_kv_groups * self.head_dim]

        if self.num_kv_groups == 1:
            # A single group only needs a singleton group axis.
            return torch.unsqueeze(
                v, 1
            )  # [batch_size, 1, seq_length, kv_channels]

        batch_size, seq_length, _ = v.shape
        v = v.reshape(batch_size, seq_length, self.num_kv_groups, self.head_dim)
        v = v.transpose(2, 1)
        # [batch_size, self.num_kv_groups, seq_length, self.head_dim]
        return v

    def expand_kv_over_group_dim(self, x):
        """Broadcast each key/value group over the heads that share it.

        Args:
            x: Tensor of shape
                ``[batch_size, num_kv_groups, seq_length, head_dim]``.

        Returns:
            Tensor of shape ``[batch_size, num_heads, seq_length, head_dim]``
            in which each group is repeated ``per_group_num_heads`` times
            consecutively along dim 1.
        """
        # expand k/v over dimension
        batch_size, _, seq_length, _ = x.shape
        x = x.unsqueeze(
            2
        )  # [batch_size, self.num_kv_groups, 1, seq_length, self.head_dim]
        # expand over per_group_num_heads
        x = x.expand(
            batch_size,
            self.num_kv_groups,
            self.per_group_num_heads,
            seq_length,
            self.head_dim,
        )
        # Merge (group, per-group head) into a single heads axis.
        x = x.reshape(batch_size, self.num_heads, seq_length, self.head_dim)
        return x

    def calculate_attention_logits(self, q, k, layer_idx):
        """Expand ``k`` over the heads when grouped, then defer to the parent.

        ``q`` and ``layer_idx`` are forwarded unchanged.
        """
        # NOTE(review): construct_key_vector returns
        # [batch, seq, groups, head_dim] while expand_kv_over_group_dim
        # unpacks [batch, groups, seq, head_dim]; this presumably relies on
        # the parent transposing k before calling this method - confirm.
        if self.num_kv_groups > 1:
            k = self.expand_kv_over_group_dim(k)
        return super().calculate_attention_logits(q, k, layer_idx)

    def calculate_attention_output(self, attention_scores, v):
        """Expand ``v`` over the heads when grouped, then defer to the parent.

        ``attention_scores`` is forwarded unchanged.
        """
        if self.num_kv_groups > 1:
            v = self.expand_kv_over_group_dim(v)
        return super().calculate_attention_output(attention_scores, v)

def check_extra_params(params):
    """Ensure ``params`` has a ``num_kv_groups`` entry, defaulting it to 1.

    When the key is absent it is inserted with the value 1 and a warning is
    logged through the root ``logging`` module. When the key is present,
    whatever its value, ``params`` is left untouched and nothing is logged.

    Args:
        params: Mutable mapping of layer parameters; modified in place.
            Presumably parsed from a yaml config, per the warning text.

    Returns:
        None.
    """
    # NOTE(review): indentation was lost in this copy of the file; if this
    # function is meant to live inside a class it also needs @staticmethod,
    # since it takes no ``self`` - confirm against the original layout.
    if "num_kv_groups" not in params:
        params["num_kv_groups"] = 1
        logging.warning(
            "num_kv_groups is not set in the yaml, it is set to 1 by default. "
            "Please provide a value if this is not intended."
        )