# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import numpy as np

import onnx
from onnx.reference.op_run import OpRun


def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x_max = np.max(x, axis=axis, keepdims=True)
    tmp = np.exp(x - x_max)
    s = np.sum(tmp, axis=axis, keepdims=True)
    return tmp / s


def _softcap(X, softcap):
    if softcap > 0:
        Y = X / softcap
        Y = np.tanh(Y)
        return Y * softcap
    return X


def _apply_causal(mask, past_sequence_length):
    """Applies a causal mask on the input ``mask``:
    ``mask[i, j] = -inf if j > past_sequence_length + i else 0``.

    Because a softmax is applied on the masked scores, the -inf entries become 0
    (exp(-inf) = 0) and the 0 entries become 1 (exp(0) = 1) before normalization.
    The modification is done in place.
    """
    q_sequence_length, total_sequence_length = mask.shape[-2:]
    triu = np.triu(
        np.ones(
            (q_sequence_length, total_sequence_length - past_sequence_length),
            dtype=mask.dtype,
        ),
        k=1,
    )
    triu[triu == 1] = -np.inf
    mask[..., :, past_sequence_length:] += triu
    return mask


def _compute_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    attn_mask: np.ndarray | None = None,
    past_key: np.ndarray | None = None,
    past_value: np.ndarray | None = None,
    nonpad_kv_seqlen: np.ndarray | None = None,
    scale=None,
    is_causal=False,
    q_num_heads=None,
    kv_num_heads=None,
    softmax_precision=None,
    softcap=None,
    qk_matmul_output_mode=None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    assert len(Q.shape) == len(K.shape) == len(V.shape)

    # Reshape the input tensors (Q, K, V) to 4D if the input shape is 3D:
    # NewShapeQ (batch_size, q_num_heads, q_sequence_length, head_size)
    # NewShapeK (batch_size, kv_num_heads, kv_sequence_length, head_size)
    # NewShapeV (batch_size, kv_num_heads, kv_sequence_length, v_head_size)
    input_shape_len = len(Q.shape)
    batch_size = Q.shape[0]
    if len(Q.shape) == 3:
        hidden_size_q = Q.shape[2]
        hidden_size_k = K.shape[2]
        hidden_size_v = V.shape[2]
        assert q_num_heads is not None and kv_num_heads is not None

        head_size_q = int(hidden_size_q / q_num_heads)
        # First reshape to [batch_size, q_sequence_length, q_num_heads, head_size]
        intermediate_shape_q = [batch_size, Q.shape[1], q_num_heads, head_size_q]
        Q = np.reshape(Q, intermediate_shape_q)
        # Then transpose to [batch_size, q_num_heads, q_sequence_length, head_size]
        Q = np.transpose(Q, (0, 2, 1, 3))

        head_size_k = int(hidden_size_k / kv_num_heads)
        # First reshape to [batch_size, kv_sequence_length, kv_num_heads, head_size]
        intermediate_shape_k = [batch_size, K.shape[1], kv_num_heads, head_size_k]
        K = np.reshape(K, intermediate_shape_k)
        # Then transpose to [batch_size, kv_num_heads, kv_sequence_length, head_size]
        K = np.transpose(K, (0, 2, 1, 3))

        head_size_v = int(hidden_size_v / kv_num_heads)
        # First reshape to [batch_size, kv_sequence_length, kv_num_heads, v_head_size]
        intermediate_shape_v = [batch_size, V.shape[1], kv_num_heads, head_size_v]
        V = np.reshape(V, intermediate_shape_v)
        # Then transpose to [batch_size, kv_num_heads, kv_sequence_length, v_head_size]
        V = np.transpose(V, (0, 2, 1, 3))

    assert len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4

    # Calculate the scaling factor if it is not provided
    if scale is None:
        q_head_size = Q.shape[3]
        scale = 1 / np.sqrt(q_head_size)
    # Both Q and K are multiplied by sqrt(scale) below, so Q @ K^T is scaled by `scale`.
    scale = np.sqrt(scale)

    # Update the key and value cache
    if past_key is not None:
        present_key = np.concatenate((past_key, K), axis=2)
    else:
        present_key = K
    if past_value is not None:
        present_value = np.concatenate((past_value, V), axis=2)
    else:
        present_value = V
    K = present_key
    V = present_value

    # Create attn_bias
    q_sequence_length = Q.shape[2]
    kv_sequence_length = K.shape[2]
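    # attn_bias accumulates every masking contribution (explicit attn_mask, the
    # causal mask, and key padding from nonpad_kv_seqlen) as additive terms in
    # {0, -inf}; after the softmax, positions biased to -inf get zero weight.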
    attn_bias = np.zeros((q_sequence_length, kv_sequence_length), dtype=Q.dtype)

    # attn_mask may cover fewer keys than kv_sequence_length; pad the missing
    # positions as masked out (False for boolean masks, -inf for float masks).
    if attn_mask is not None:
        pad_width = kv_sequence_length - attn_mask.shape[-1]
        if pad_width > 0:
            pad_shape = [(0, 0)] * (attn_mask.ndim - 1) + [(0, pad_width)]
            pad_value = False if attn_mask.dtype == np.bool_ else -np.inf
            attn_mask = np.pad(
                attn_mask, pad_shape, mode="constant", constant_values=pad_value
            )

    # First case: is_causal is set.
    # When set to true, the attention masking is a lower-triangular matrix if the
    # mask is square. For a non-square mask, the masking takes the form of the
    # upper-left causal bias due to the alignment.
    if is_causal:
        if attn_mask is None:
            temp_mask = np.zeros(
                (q_sequence_length, kv_sequence_length), dtype=Q.dtype
            )
            attn_bias = _apply_causal(
                temp_mask,
                past_sequence_length=past_key.shape[2] if past_key is not None else 0,
            )
        else:
            if attn_mask.dtype == np.bool_:
                # Convert the boolean mask to an additive bias: True -> 0, False -> -inf.
                attn_mask = (1 - attn_mask).astype(Q.dtype)
                attn_mask[attn_mask == 1] = -np.inf
            attn_bias = _apply_causal(
                attn_mask.copy(),
                past_sequence_length=past_key.shape[2] if past_key is not None else 0,
            )
    elif attn_mask is not None:
        if attn_mask.dtype == np.bool_:
            attn_mask = (1 - attn_mask).astype(Q.dtype)
            attn_mask[attn_mask == 1] = -np.inf
        attn_bias = attn_bias + attn_mask

    if nonpad_kv_seqlen is not None:
        attn_bias = attn_bias.reshape(
            (1,) * (4 - attn_bias.ndim) + attn_bias.shape
        )  # broadcast to 4D
        padding_mask = np.arange(kv_sequence_length) < nonpad_kv_seqlen[:, np.newaxis]
        padding_mask = padding_mask.reshape(batch_size, 1, 1, kv_sequence_length)
        padding_mask = np.where(padding_mask, 0, -np.inf)
        attn_bias = attn_bias + padding_mask

    # Group Query Attention is applied if the following are satisfied:
    # 1) q_num_heads != kv_num_heads
    # 2) q_num_heads % kv_num_heads == 0
    # 3) k_num_heads == v_num_heads
    if q_num_heads is None:
        q_num_heads = Q.shape[1]
    if kv_num_heads is None:
        k_num_heads = K.shape[1]
        v_num_heads = V.shape[1]
    else:
        k_num_heads = kv_num_heads
        v_num_heads = kv_num_heads
    if (
        (q_num_heads != k_num_heads)
        and (q_num_heads % k_num_heads == 0)
        and (k_num_heads == v_num_heads)
    ):
        seq_reps = q_num_heads // k_num_heads
        # Interleave-repeat each KV head: [h0, h0, h1, h1, ...]
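        # For example, with q_num_heads=8 and kv_num_heads=2, seq_reps is 4 and
        # each of the 2 KV heads is repeated 4 times along axis 1, so K and V
        # expose 8 heads and line up with Q for the matmuls below.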
        K = np.repeat(K, repeats=seq_reps, axis=1)
        V = np.repeat(V, repeats=seq_reps, axis=1)

    #  The following pattern is applied:
    #       Q          K          V
    #       |          |          |
    #    Q*scale    K*scale       |
    #       |          |          |
    #       |      Transpose      |
    #       |          |          |
    #       ---MatMul---          |
    #             |               |
    #  attn_mask--Add             |
    #             |               |
    #   softcap (if provided)     |
    #             |               |
    #          Softmax            |
    #             |               |
    #             -----MatMul------
    #                   |
    #                   Y
    k_transpose = np.transpose(K, (0, 1, 3, 2))
    qk_matmul_output = np.matmul(Q * scale, k_transpose * scale)
    qk_with_bias = qk_matmul_output + attn_bias
    # qk_matmul_output_mode selects which intermediate result is returned as the
    # fourth output: 0/None = QK matmul, 1 = after the bias add, 2 = after softcap,
    # 3 = after softmax.
    if qk_matmul_output_mode == 1:
        qk_matmul_output = qk_with_bias.copy()

    # Apply softcap
    if softcap is not None:
        qk_with_bias = _softcap(qk_with_bias, softcap)
    if qk_matmul_output_mode == 2:
        qk_matmul_output = qk_with_bias

    if softmax_precision is not None:
        qk_with_bias = qk_with_bias.astype(
            onnx.helper.tensor_dtype_to_np_dtype(softmax_precision)
        )
    qk_softmax = _softmax(qk_with_bias)
    if qk_matmul_output_mode == 3:
        qk_matmul_output = qk_softmax
    qk_matmul_output = qk_matmul_output.astype(Q.dtype)

    output = np.matmul(qk_softmax, V).astype(Q.dtype)

    if input_shape_len == 3:
        output = np.transpose(output, (0, 2, 1, 3))
        output = np.reshape(output, (output.shape[0], output.shape[1], -1))

    return output, present_key, present_value, qk_matmul_output


class Attention(OpRun):
    def _run(
        self,
        Q: np.ndarray,
        K: np.ndarray,
        V: np.ndarray,
        attn_mask: np.ndarray | None = None,
        past_key: np.ndarray | None = None,
        past_value: np.ndarray | None = None,
        nonpad_kv_seqlen: np.ndarray | None = None,
        scale=None,
        is_causal=False,
        q_num_heads=None,
        kv_num_heads=None,
        softmax_precision=None,
        softcap=None,
        qk_matmul_output_mode=None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        res = _compute_attention(
            Q,
            K,
            V,
            attn_mask=attn_mask,
            past_key=past_key,
            past_value=past_value,
            nonpad_kv_seqlen=nonpad_kv_seqlen,
            scale=scale,
            is_causal=is_causal,
            q_num_heads=q_num_heads,
            kv_num_heads=kv_num_heads,
            softmax_precision=softmax_precision,
            softcap=softcap,
            qk_matmul_output_mode=qk_matmul_output_mode,
        )
        return res
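

# Illustrative usage sketch (not part of the reference operator): it calls
# _compute_attention directly on random 4D tensors shaped
# (batch_size, num_heads, sequence_length, head_size) with a causal mask and
# prints the shapes of the four outputs. The shapes and the __main__ guard are
# assumptions chosen only for this example.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Q = rng.standard_normal((2, 4, 3, 8)).astype(np.float32)  # queries
    K = rng.standard_normal((2, 4, 5, 8)).astype(np.float32)  # keys
    V = rng.standard_normal((2, 4, 5, 8)).astype(np.float32)  # values
    Y, present_key, present_value, qk = _compute_attention(Q, K, V, is_causal=True)
    print(Y.shape, present_key.shape, present_value.shape, qk.shape)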