DriverTrac/venv/lib/python3.12/site-packages/onnx/reference/ops/op_attention.py
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import numpy as np
import onnx
from onnx.reference.op_run import OpRun
def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x_max = np.max(x, axis=axis, keepdims=True)
    tmp = np.exp(x - x_max)
    s = np.sum(tmp, axis=axis, keepdims=True)
    return tmp / s


def _softcap(X, softcap):
    if softcap > 0:
        Y = X / softcap
        Y = np.tanh(Y)
        return Y * softcap
    return X
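# Illustration (not part of the original file): with softcap=s the helper returns
# s * tanh(X / s), which smoothly bounds the attention scores to the open interval
# (-s, s); e.g. softcap=30.0 keeps every score strictly between -30 and 30.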
def _apply_causal(mask, past_sequence_length):
    """Applies a causal mask on the input `mask`:
    ``mask[i, j] = -inf if j > past_sequence_length + i else 0``.

    Because the mask is added to the scores before the softmax, the -inf entries
    zero out the corresponding attention weights.
    The modification is done inplace.
    """
    q_sequence_length, total_sequence_length = mask.shape[-2:]
    triu = np.triu(
        np.ones(
            (q_sequence_length, total_sequence_length - past_sequence_length),
            dtype=mask.dtype,
        ),
        k=1,
    )
    triu[triu == 1] = -np.inf
    mask[..., :, past_sequence_length:] += triu
    return mask
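# Illustration (hypothetical sizes, not part of the original file): for
# q_sequence_length=2, total_sequence_length=3 and past_sequence_length=1,
# _apply_causal turns an all-zero mask into
#     [[0, 0, -inf],
#      [0, 0,    0]]
# so query i can attend to key positions 0 .. past_sequence_length + i.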
def _compute_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    attn_mask: np.ndarray | None = None,
    past_key: np.ndarray | None = None,
    past_value: np.ndarray | None = None,
    nonpad_kv_seqlen: np.ndarray | None = None,
    scale=None,
    is_causal=False,
    q_num_heads=None,
    kv_num_heads=None,
    softmax_precision=None,
    softcap=None,
    qk_matmul_output_mode=None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    assert len(Q.shape) == len(K.shape) == len(V.shape)
    # Set input tensors (Q, K, V) to the correct shape if input shape is 3D
    # NewShapeQ (batch_size, q_num_heads, q_sequence_length, head_size)
    # NewShapeK (batch_size, kv_num_heads, kv_sequence_length, head_size)
    # NewShapeV (value) has shape (batch_size, kv_num_heads, kv_sequence_length, v_head_size)
    input_shape_len = len(Q.shape)
    batch_size = Q.shape[0]
    if len(Q.shape) == 3:
        hidden_size_q = Q.shape[2]
        hidden_size_k = K.shape[2]
        hidden_size_v = V.shape[2]
        assert q_num_heads is not None and kv_num_heads is not None
        head_size_q = int(hidden_size_q / q_num_heads)
        # First reshape to [batch_size, q_sequence_length, q_num_heads, head_size]
        intermediate_shape_q = [batch_size, Q.shape[1], q_num_heads, head_size_q]
        Q = np.reshape(Q, intermediate_shape_q)
        # Then transpose to [batch_size, q_num_heads, q_sequence_length, head_size]
        Q = np.transpose(Q, (0, 2, 1, 3))
        head_size_k = int(hidden_size_k / kv_num_heads)
        # First reshape to [batch_size, kv_sequence_length, kv_num_heads, head_size]
        intermediate_shape_k = [batch_size, K.shape[1], kv_num_heads, head_size_k]
        K = np.reshape(K, intermediate_shape_k)
        # Then transpose to [batch_size, kv_num_heads, kv_sequence_length, head_size]
        K = np.transpose(K, (0, 2, 1, 3))
        head_size_v = int(hidden_size_v / kv_num_heads)
        # First reshape to [batch_size, kv_sequence_length, kv_num_heads, head_size]
        intermediate_shape_v = [batch_size, V.shape[1], kv_num_heads, head_size_v]
        V = np.reshape(V, intermediate_shape_v)
        # Then transpose to [batch_size, kv_num_heads, kv_sequence_length, head_size]
        V = np.transpose(V, (0, 2, 1, 3))
    assert len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4

    # Calculate Scaling Factor if not provided
    if scale is None:
        q_head_size = Q.shape[3]
        scale = 1 / np.sqrt(q_head_size)
    scale = np.sqrt(scale)
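    # Note (added comment, not in the original file): sqrt(scale) is applied to Q and
    # to K separately below, so the QK^T product carries the full factor `scale`.
    # For example, head_size=64 gives scale=1/8 and each side is multiplied by 1/sqrt(8).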
    # Update key and value cache
    if past_key is not None:
        present_key = np.concatenate((past_key, K), axis=2)
    else:
        present_key = K
    if past_value is not None:
        present_value = np.concatenate((past_value, V), axis=2)
    else:
        present_value = V
    K = present_key
    V = present_value

    # Create attn_bias
    q_sequence_length = Q.shape[2]
    kv_sequence_length = K.shape[2]
    attn_bias = np.zeros((q_sequence_length, kv_sequence_length), dtype=Q.dtype)

    # The attn_mask can be shorter than kv_sequence_length, so pad it with -inf (or False)
    if attn_mask is not None:
        pad_width = kv_sequence_length - attn_mask.shape[-1]
        if pad_width > 0:
            pad_shape = [(0, 0)] * (attn_mask.ndim - 1) + [(0, pad_width)]
            pad_value = False if attn_mask.dtype == np.bool_ else -np.inf
            attn_mask = np.pad(
                attn_mask, pad_shape, mode="constant", constant_values=pad_value
            )

    # First case: is_causal is provided.
    # If set to true, the attention masking is a lower triangular matrix when the mask
    # is a square matrix. The attention masking has the form of the upper left causal
    # bias due to the alignment when the mask is a non-square matrix.
    if is_causal:
        if attn_mask is None:
            temp_mask = np.zeros((q_sequence_length, kv_sequence_length), dtype=Q.dtype)
            attn_bias = _apply_causal(
                temp_mask,
                past_sequence_length=past_key.shape[2] if past_key is not None else 0,
            )
        else:
            if attn_mask.dtype == np.bool_:
                # Convert the boolean mask to an additive bias: True -> 0, False -> -inf.
                attn_mask = (1 - attn_mask).astype(Q.dtype)
                attn_mask[attn_mask == 1] = -np.inf
            attn_bias = _apply_causal(
                attn_mask.copy(),
                past_sequence_length=past_key.shape[2] if past_key is not None else 0,
            )
    elif attn_mask is not None:
        if attn_mask.dtype == np.bool_:
            attn_mask = (1 - attn_mask).astype(Q.dtype)
            attn_mask[attn_mask == 1] = -np.inf
        attn_bias = attn_bias + attn_mask
    if nonpad_kv_seqlen is not None:
        attn_bias = attn_bias.reshape(
            (1,) * (4 - attn_bias.ndim) + attn_bias.shape
        )  # broadcast to 4D
        padding_mask = np.arange(kv_sequence_length) < nonpad_kv_seqlen[:, np.newaxis]
        padding_mask = padding_mask.reshape(batch_size, 1, 1, kv_sequence_length)
        padding_mask = np.where(padding_mask, 0, -np.inf)
        attn_bias = attn_bias + padding_mask
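        # Example (hypothetical values, not in the original file): kv_sequence_length=4
        # and nonpad_kv_seqlen=[3] give a per-batch bias row of [0, 0, 0, -inf], which
        # removes the padded key position from the softmax.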
    # Group Query Attention is applied if the following are satisfied
    # 1) q_num_heads != kv_num_heads
    # 2) q_num_heads % kv_num_heads == 0
    # 3) kv_num_heads == k_num_heads == v_num_heads
    if q_num_heads is None:
        q_num_heads = Q.shape[1]
    if kv_num_heads is None:
        k_num_heads = K.shape[1]
        v_num_heads = V.shape[1]
    else:
        k_num_heads = kv_num_heads
        v_num_heads = kv_num_heads
    if (
        (q_num_heads != k_num_heads)
        and (q_num_heads % k_num_heads == 0)
        and (k_num_heads == v_num_heads)
    ):
        seq_reps = q_num_heads // k_num_heads
        # Interleave-repeat each KV head: [h0, h0, h1, h1, ...]
        K = np.repeat(K, repeats=seq_reps, axis=1)
        V = np.repeat(V, repeats=seq_reps, axis=1)
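        # For example (hypothetical sizes, not in the original file): q_num_heads=4 and
        # kv_num_heads=2 give seq_reps=2, so the KV head order becomes [h0, h0, h1, h1]
        # and query head i reads from original KV head i // seq_reps.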
    # The following pattern is applied:
    #       Q          K          V
    #       |          |          |
    #    Q*scale    K*scale       |
    #       |          |          |
    #       |      Transpose      |
    #       |          |          |
    #       ---MatMul---          |
    #            |                |
    #   at_mask--Add              |
    #            |                |
    #  softcap (if provided)      |
    #            |                |
    #         Softmax             |
    #            |                |
    #            -----MatMul-------
    #                   |
    #                   Y
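    # Note (added comment, not in the original file): qk_matmul_output_mode selects which
    # intermediate is returned as the fourth output: 0/None = scaled QK^T, 1 = after the
    # bias is added, 2 = after softcap, 3 = after softmax.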
    k_transpose = np.transpose(K, (0, 1, 3, 2))
    qk_matmul_output = np.matmul(Q * scale, k_transpose * scale)
    qk_with_bias = qk_matmul_output + attn_bias
    if qk_matmul_output_mode == 1:
        qk_matmul_output = qk_with_bias.copy()
    # Apply softcap
    if softcap is not None:
        qk_with_bias = _softcap(qk_with_bias, softcap)
    if qk_matmul_output_mode == 2:
        qk_matmul_output = qk_with_bias
    if softmax_precision is not None:
        qk_with_bias = qk_with_bias.astype(
            onnx.helper.tensor_dtype_to_np_dtype(softmax_precision)
        )
    qk_softmax = _softmax(qk_with_bias)
    if qk_matmul_output_mode == 3:
        qk_matmul_output = qk_softmax
    qk_matmul_output = qk_matmul_output.astype(Q.dtype)
    output = np.matmul(qk_softmax, V).astype(Q.dtype)
    if input_shape_len == 3:
        output = np.transpose(output, (0, 2, 1, 3))
        output = np.reshape(output, (output.shape[0], output.shape[1], -1))
    return output, present_key, present_value, qk_matmul_output
class Attention(OpRun):
    def _run(
        self,
        Q: np.ndarray,
        K: np.ndarray,
        V: np.ndarray,
        attn_mask: np.ndarray | None = None,
        past_key: np.ndarray | None = None,
        past_value: np.ndarray | None = None,
        nonpad_kv_seqlen: np.ndarray | None = None,
        scale=None,
        is_causal=False,
        q_num_heads=None,
        kv_num_heads=None,
        softmax_precision=None,
        softcap=None,
        qk_matmul_output_mode=None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        res = _compute_attention(
            Q,
            K,
            V,
            attn_mask=attn_mask,
            past_key=past_key,
            past_value=past_value,
            nonpad_kv_seqlen=nonpad_kv_seqlen,
            scale=scale,
            is_causal=is_causal,
            q_num_heads=q_num_heads,
            kv_num_heads=kv_num_heads,
            softmax_precision=softmax_precision,
            softcap=softcap,
            qk_matmul_output_mode=qk_matmul_output_mode,
        )
        return res
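# Minimal smoke-test sketch (not part of the ONNX reference sources): it calls the
# private helper _compute_attention directly with small random tensors of assumed
# sizes and only prints the resulting shapes, demonstrating grouped-query attention
# (4 query heads sharing 2 KV heads) combined with a causal mask.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    batch, q_heads, kv_heads, q_len, kv_len, head_size = 2, 4, 2, 3, 5, 8
    Q = rng.standard_normal((batch, q_heads, q_len, head_size)).astype(np.float32)
    K = rng.standard_normal((batch, kv_heads, kv_len, head_size)).astype(np.float32)
    V = rng.standard_normal((batch, kv_heads, kv_len, head_size)).astype(np.float32)
    Y, present_key, present_value, qk = _compute_attention(Q, K, V, is_causal=True)
    print(Y.shape)            # (2, 4, 3, 8)
    print(present_key.shape)  # (2, 2, 5, 8) -- KV cache keeps the original head count
    print(qk.shape)           # (2, 4, 3, 5) -- scaled QK^T (default qk_matmul_output_mode)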