DriverTrac/venv/lib/python3.12/site-packages/transformers/integrations/sdpa_attention.py


from typing import Optional

import torch

from ..utils import is_torch_npu_available, is_torch_xpu_available, logging
from ..utils.import_utils import is_torch_greater_or_equal


logger = logging.get_logger(__name__)

_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
_is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)
_is_torch_xpu_available = is_torch_xpu_available()
_is_torch_npu_available = is_torch_npu_available()


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
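
# Editor's note (illustrative sketch, not part of the library): for a GQA model whose 4 KV heads
# are shared across 12 attention heads, `repeat_kv` expands the KV tensor so its head count
# matches the queries. Shapes below are chosen purely for demonstration.
#
#     kv = torch.randn(2, 4, 128, 64)   # (batch, num_key_value_heads, seqlen, head_dim)
#     out = repeat_kv(kv, n_rep=3)
#     out.shape                          # torch.Size([2, 12, 128, 64])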


def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -> bool:
    # GQA can only be used under the following conditions
    # 1. cuda
    #    - torch version >= 2.5
    #    - attention_mask is None (otherwise it will fall back to the math kernel)
    #    - key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
    # 2. xpu
    #    - torch version >= 2.8
    #    - key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
    # 3. npu
    #    - gqa is not currently supported on npu
    if _is_torch_xpu_available:
        return _is_torch_greater_or_equal_than_2_8 and not isinstance(key, torch.fx.Proxy)
    if _is_torch_npu_available:
        return False
    return _is_torch_greater_or_equal_than_2_5 and attention_mask is None and not isinstance(key, torch.fx.Proxy)
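
# Editor's note (illustrative, not part of the library): when this returns True the caller passes
# `enable_gqa=True` to SDPA, which broadcasts the smaller number of KV heads inside the kernel
# instead of materializing `repeat_kv` copies. For example, on CUDA with torch >= 2.5, no
# attention mask, and real tensors (not fx proxies), this returns True.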


def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
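    # Editor's note: summary added for readability; shapes inferred from the code below.
    # query/key/value arrive as (batch, num_heads, seq_len, head_dim); the result is returned
    # transposed to (batch, query_len, num_heads, head_dim) together with None, since SDPA does
    # not expose attention weights.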
    if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
        logger.warning_once(
            "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
            " Please set your attention to `eager` if you want any of these features."
        )

    sdpa_kwargs = {}
    if hasattr(module, "num_key_value_groups"):
        if not use_gqa_in_sdpa(attention_mask, key):
            key = repeat_kv(key, module.num_key_value_groups)
            value = repeat_kv(value, module.num_key_value_groups)
        else:
            sdpa_kwargs = {"enable_gqa": True}
    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
    if is_causal is None:
        # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
        # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
        is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)

        # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
        # We convert it to a bool for the SDPA kernel that only accepts bools.
        if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
            is_causal = is_causal.item()
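
    # Editor's note (illustrative): during single-token autoregressive decoding query.shape[2] == 1,
    # so `is_causal` resolves to False and masking comes from the cached keys/values plus any mask;
    # during prefill with no mask it resolves to True and SDPA applies its own causal masking.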

    # When `is_causal = False` and the `attention_mask` is not of boolean type, the Ascend NPU's SDPA interface cannot utilize the FlashAttentionScore
    # operator and falls back to small-operator concatenation. To invoke FlashAttentionScore, the attention_mask must be converted to boolean type.
    # This adaptation ensures the `attention_mask` meets the requirement for using FlashAttentionScore.
    if _is_torch_npu_available:
        if attention_mask is not None and attention_mask.dtype != torch.bool:
            # Convert to boolean type so SDPA is forced to call FlashAttentionScore, improving performance.
            attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
        **sdpa_kwargs,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None
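

# Editor's sketch (illustrative only, not part of the library): how an attention module might call
# this forward. The module here is a hypothetical stand-in exposing the attributes the function
# reads (`num_key_value_groups`, `is_causal`); real transformers attention layers provide these.
#
#     import types
#     attn = types.SimpleNamespace(num_key_value_groups=3, is_causal=True)
#     q = torch.randn(2, 12, 128, 64)    # (batch, num_attention_heads, seq, head_dim)
#     k = torch.randn(2, 4, 128, 64)     # (batch, num_key_value_heads, seq, head_dim)
#     v = torch.randn(2, 4, 128, 64)
#     out, _ = sdpa_attention_forward(attn, q, k, v, attention_mask=None)
#     out.shape                           # torch.Size([2, 128, 12, 64])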