# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import numpy as np

import onnx
from onnx.reference.op_run import OpRun


def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x_max = np.max(x, axis=axis, keepdims=True)
    tmp = np.exp(x - x_max)
    s = np.sum(tmp, axis=axis, keepdims=True)
    return tmp / s


def _softcap(X, softcap):
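    # Soft-capping squashes the pre-softmax scores into (-softcap, softcap)
    # via softcap * tanh(X / softcap); a non-positive softcap leaves X unchanged.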
    if softcap > 0:
        Y = X / softcap
        Y = np.tanh(Y)
        return Y * softcap
    return X


def _apply_causal(mask, past_sequence_length):
    """Applies a causal mask on the input `mask`:
    ``mask[i, j] = -inf if j > past_sequence_length + i else 0``.

    Because a softmax is applied after the mask is added, the -inf entries
    contribute exp(-inf) = 0 and the 0 entries contribute exp(0) = 1 before
    normalization. The modification is done inplace.
    """
    q_sequence_length, total_sequence_length = mask.shape[-2:]
    triu = np.triu(
        np.ones(
            (q_sequence_length, total_sequence_length - past_sequence_length),
            dtype=mask.dtype,
        ),
        k=1,
    )
    triu[triu == 1] = -np.inf
    mask[..., :, past_sequence_length:] += triu
    return mask
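

# Illustrative check (values chosen only for this example):
# _apply_causal(np.zeros((2, 3), dtype=np.float32), past_sequence_length=1)
# returns
#   [[0.0, 0.0, -inf],
#    [0.0, 0.0,  0.0]]
# i.e. query position i may attend to key positions j <= past_sequence_length + i.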


def _compute_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    attn_mask: np.ndarray | None = None,
    past_key: np.ndarray | None = None,
    past_value: np.ndarray | None = None,
    nonpad_kv_seqlen: np.ndarray | None = None,
    scale=None,
    is_causal=False,
    q_num_heads=None,
    kv_num_heads=None,
    softmax_precision=None,
    softcap=None,
    qk_matmul_output_mode=None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    assert len(Q.shape) == len(K.shape) == len(V.shape)
    # Set input tensors (Q, K, V) to the correct shape if input shape is 3D
    # NewShapeQ (batch_size, q_num_heads, q_sequence_length, head_size)
    # NewShapeK (batch_size, kv_num_heads, kv_sequence_length, head_size)
    # NewShapeV (batch_size, kv_num_heads, kv_sequence_length, v_head_size)
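    # Example with assumed sizes: a 3D Q of shape (2, 5, 64) with q_num_heads=8
    # becomes (2, 8, 5, 8) after the reshape and transpose below,
    # i.e. head_size = 64 / 8 = 8.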
    input_shape_len = len(Q.shape)
    batch_size = Q.shape[0]
    if len(Q.shape) == 3:
        hidden_size_q = Q.shape[2]
        hidden_size_k = K.shape[2]
        hidden_size_v = V.shape[2]
        assert q_num_heads is not None and kv_num_heads is not None

        head_size_q = int(hidden_size_q / q_num_heads)
        # First reshape to [batch_size, q_sequence_length, q_num_heads, head_size]
        intermediate_shape_q = [batch_size, Q.shape[1], q_num_heads, head_size_q]
        Q = np.reshape(Q, intermediate_shape_q)
        # Then transpose to [batch_size, q_num_heads, q_sequence_length, head_size]
        Q = np.transpose(Q, (0, 2, 1, 3))

        head_size_k = int(hidden_size_k / kv_num_heads)
        # First reshape to [batch_size, kv_sequence_length, kv_num_heads, head_size]
        intermediate_shape_k = [batch_size, K.shape[1], kv_num_heads, head_size_k]
        K = np.reshape(K, intermediate_shape_k)
        # Then transpose to [batch_size, kv_num_heads, kv_sequence_length, head_size]
        K = np.transpose(K, (0, 2, 1, 3))

        head_size_v = int(hidden_size_v / kv_num_heads)
        # First reshape to [batch_size, kv_sequence_length, kv_num_heads, head_size]
        intermediate_shape_v = [batch_size, V.shape[1], kv_num_heads, head_size_v]
        V = np.reshape(V, intermediate_shape_v)
        # Then transpose to [batch_size, kv_num_heads, kv_sequence_length, head_size]
        V = np.transpose(V, (0, 2, 1, 3))
    assert len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4

    # Calculate Scaling Factor if not provided
    if scale is None:
        q_head_size = Q.shape[3]
        scale = 1 / np.sqrt(q_head_size)
    scale = np.sqrt(scale)
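    # Illustration with assumed numbers: head_size = 64 and no explicit scale
    # gives scale = 1/8; sqrt(scale) ~= 0.3536 multiplies both Q and K below,
    # so Q @ K^T is scaled by 1/8 overall.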

    # Update key and value cache
    if past_key is not None:
        present_key = np.concatenate((past_key, K), axis=2)
    else:
        present_key = K
    if past_value is not None:
        present_value = np.concatenate((past_value, V), axis=2)
    else:
        present_value = V
    K = present_key
    V = present_value

    # Create attn_bias
    q_sequence_length = Q.shape[2]
    kv_sequence_length = K.shape[2]
    attn_bias = np.zeros((q_sequence_length, kv_sequence_length), dtype=Q.dtype)

    # The attn_mask can be shorter than kv_sequence_length along its last axis;
    # pad it with -inf (False for boolean masks) so the extra key positions stay masked.
    if attn_mask is not None:
        pad_width = kv_sequence_length - attn_mask.shape[-1]
        if pad_width > 0:
            pad_shape = [(0, 0)] * (attn_mask.ndim - 1) + [(0, pad_width)]
            pad_value = False if attn_mask.dtype == np.bool_ else -np.inf
            attn_mask = np.pad(
                attn_mask, pad_shape, mode="constant", constant_values=pad_value
            )
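    # For example (illustrative sizes): with kv_sequence_length = 6 and a mask of
    # width 4, two columns of -inf (or False) are appended so the two extra key
    # positions receive no attention weight after the softmax.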

    # First case: If is_causal is provided
    # If set to true, the attention masking is a lower triangular matrix when the mask
    # is a square matrix. The attention masking has the form of the upper left causal
    # bias due to the alignment when the mask is a non-square matrix.
    if is_causal:
        if attn_mask is None:
            temp_mask = np.zeros((q_sequence_length, kv_sequence_length), dtype=Q.dtype)
            attn_bias = _apply_causal(
                temp_mask,
                past_sequence_length=past_key.shape[2] if past_key is not None else 0,
            )
        else:
            if attn_mask.dtype == np.bool_:
                # Boolean masks become additive masks: True -> 0 (attend),
                # False -> -inf (masked); two steps so attended positions stay
                # exactly 0 rather than 0 * -inf = NaN.
                attn_mask = (1 - attn_mask).astype(Q.dtype)
                attn_mask[attn_mask == 1] = -np.inf
            attn_bias = _apply_causal(
                attn_mask.copy(),
                past_sequence_length=past_key.shape[2] if past_key is not None else 0,
            )
    elif attn_mask is not None:
        if attn_mask.dtype == np.bool_:
            attn_mask = (1 - attn_mask).astype(Q.dtype)
            attn_mask[attn_mask == 1] = -np.inf
        attn_bias = attn_bias + attn_mask

    if nonpad_kv_seqlen is not None:
        attn_bias = attn_bias.reshape(
            (1,) * (4 - attn_bias.ndim) + attn_bias.shape
        )  # broadcast to 4D
        padding_mask = np.arange(kv_sequence_length) < nonpad_kv_seqlen[:, np.newaxis]
        padding_mask = padding_mask.reshape(batch_size, 1, 1, kv_sequence_length)
        padding_mask = np.where(padding_mask, 0, -np.inf)
        attn_bias += padding_mask
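        # Illustrative values: nonpad_kv_seqlen = [3] with kv_sequence_length = 5
        # marks key positions 3 and 4 of that batch element as padding, so they
        # receive a -inf bias and zero attention weight.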

    # Group Query Attention is applied if the following are satisfied
    # 1) q_num_heads != kv_num_heads
    # 2) q_num_heads % kv_num_heads == 0
    # 3) kv_num_heads == k_num_heads == v_num_heads
    if q_num_heads is None:
        q_num_heads = Q.shape[1]
    if kv_num_heads is None:
        k_num_heads = K.shape[1]
        v_num_heads = V.shape[1]
    else:
        k_num_heads = kv_num_heads
        v_num_heads = kv_num_heads
    if (
        (q_num_heads != k_num_heads)
        and (q_num_heads % k_num_heads == 0)
        and (k_num_heads == v_num_heads)
    ):
        seq_reps = q_num_heads // k_num_heads
        # Interleave-repeat each KV head: [h0, h0, h1, h1, ...]
        K = np.repeat(K, repeats=seq_reps, axis=1)
        V = np.repeat(V, repeats=seq_reps, axis=1)
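        # Example (assumed head counts): q_num_heads = 8, kv_num_heads = 2 gives
        # seq_reps = 4, so the KV heads become [k0, k0, k0, k0, k1, k1, k1, k1]
        # and line up one-to-one with the 8 query heads.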

    # The following pattern is applied
    #     Q          K          V
    #     |          |          |
    #  Q*scale    K*scale       |
    #     |          |          |
    #     |      Transpose      |
    #     |          |          |
    #     ---MatMul---          |
    #          |                |
    # at_mask---Add             |
    #          |                |
    # softcap (if provided)     |
    #          |                |
    #       Softmax             |
    #          |                |
    #          -----MatMul------
    #                 |
    #                 Y
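    # qk_matmul_output_mode selects which intermediate is returned as the fourth
    # output (see the handling below): 0 or unset -> the scaled Q*K^T product,
    # 1 -> after adding attn_bias, 2 -> after softcap, 3 -> after softmax.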
    k_transpose = np.transpose(K, (0, 1, 3, 2))
    qk_matmul_output = np.matmul(Q * scale, k_transpose * scale)
    qk_with_bias = qk_matmul_output + attn_bias
    if qk_matmul_output_mode == 1:
        qk_matmul_output = qk_with_bias.copy()

    # Apply softcap
    if softcap is not None:
        qk_with_bias = _softcap(qk_with_bias, softcap)
    if qk_matmul_output_mode == 2:
        qk_matmul_output = qk_with_bias

    if softmax_precision is not None:
        qk_with_bias = qk_with_bias.astype(
            onnx.helper.tensor_dtype_to_np_dtype(softmax_precision)
        )
    qk_softmax = _softmax(qk_with_bias)
    if qk_matmul_output_mode == 3:
        qk_matmul_output = qk_softmax
    qk_matmul_output = qk_matmul_output.astype(Q.dtype)

    output = np.matmul(qk_softmax, V).astype(Q.dtype)
    if input_shape_len == 3:
        output = np.transpose(output, (0, 2, 1, 3))
        output = np.reshape(output, (output.shape[0], output.shape[1], -1))
    return output, present_key, present_value, qk_matmul_output


class Attention(OpRun):
    def _run(
        self,
        Q: np.ndarray,
        K: np.ndarray,
        V: np.ndarray,
        attn_mask: np.ndarray | None = None,
        past_key: np.ndarray | None = None,
        past_value: np.ndarray | None = None,
        nonpad_kv_seqlen: np.ndarray | None = None,
        scale=None,
        is_causal=False,
        q_num_heads=None,
        kv_num_heads=None,
        softmax_precision=None,
        softcap=None,
        qk_matmul_output_mode=None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        res = _compute_attention(
            Q,
            K,
            V,
            attn_mask=attn_mask,
            past_key=past_key,
            past_value=past_value,
            nonpad_kv_seqlen=nonpad_kv_seqlen,
            scale=scale,
            is_causal=is_causal,
            q_num_heads=q_num_heads,
            kv_num_heads=kv_num_heads,
            softmax_precision=softmax_precision,
            softcap=softcap,
            qk_matmul_output_mode=qk_matmul_output_mode,
        )
        return res
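

if __name__ == "__main__":
    # Minimal usage sketch (not part of the reference tests): call the helper
    # directly with small random tensors already in the 4D layout
    # (batch_size, num_heads, sequence_length, head_size); the shapes are
    # assumptions chosen only for illustration.
    rng = np.random.default_rng(0)
    Q = rng.standard_normal((1, 2, 3, 4)).astype(np.float32)
    K = rng.standard_normal((1, 2, 3, 4)).astype(np.float32)
    V = rng.standard_normal((1, 2, 3, 4)).astype(np.float32)
    Y, present_key, present_value, qk = _compute_attention(Q, K, V, is_causal=True)
    print(Y.shape)  # (1, 2, 3, 4)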