from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any, Optional

import torch

from .configuration_utils import PretrainedConfig
from .utils import (
    is_hqq_available,
    is_quanto_greater,
    is_torch_greater_or_equal,
    is_torchdynamo_compiling,
    logging,
)


if is_hqq_available():
    from hqq.core.quantize import Quantizer as HQQQuantizer

_is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)


logger = logging.get_logger(__name__)

class CacheLayerMixin(ABC):
|
|
"""Base, abstract class for a single layer's cache."""
|
|
|
|
is_compileable = False
|
|
|
|
def __init__(self):
|
|
self.keys: Optional[torch.Tensor] = None
|
|
self.values: Optional[torch.Tensor] = None
|
|
self.is_initialized = False
|
|
|
|
def __repr__(self):
|
|
return f"{self.__class__.__name__}"
|
|
|
|
@abstractmethod
|
|
def lazy_initialization(self, key_states: torch.Tensor): ...
|
|
|
|
@abstractmethod
|
|
def update(
|
|
self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None
|
|
) -> tuple[torch.Tensor, torch.Tensor]: ...
|
|
|
|
@abstractmethod
|
|
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: ...
|
|
|
|
@abstractmethod
|
|
def get_seq_length(self) -> int: ...
|
|
|
|
@abstractmethod
|
|
def get_max_cache_shape(self) -> int: ...
|
|
|
|
def offload(self):
|
|
"""Offload this layer's data to CPU device."""
|
|
if self.is_initialized:
|
|
self.keys = self.keys.to("cpu", non_blocking=True)
|
|
self.values = self.values.to("cpu", non_blocking=True)
|
|
|
|
def prefetch(self):
"""In case of layer offloading, this allows moving the data back to the layer's device ahead of time."""
|
|
if self.is_initialized and self.keys.device != self.device:
|
|
self.keys = self.keys.to(self.device, non_blocking=True)
|
|
self.values = self.values.to(self.device, non_blocking=True)
|
|
|
|
def reset(self) -> None:
|
|
"""Resets the cache values while preserving the objects"""
|
|
if self.is_initialized:
|
|
self.keys.zero_()
|
|
self.values.zero_()
|
|
# This attribute is set on several Layers
|
|
if hasattr(self, "cumulative_length"):
|
|
self.cumulative_length = 0
|
|
|
|
def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
|
|
"""Reorders this layer's cache for beam search."""
|
|
if self.get_seq_length() > 0:
|
|
self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
|
|
self.values = self.values.index_select(0, beam_idx.to(self.values.device))
|
|
|
|
|
|
class DynamicLayer(CacheLayerMixin):
|
|
"""
|
|
A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
|
|
It stores the key and value states as tensors of shape `[batch_size, num_heads, seq_len, head_dim]`.
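
Example (a minimal sketch with random tensors; the shapes below are purely illustrative):

```python
>>> import torch
>>> layer = DynamicLayer()
>>> key = value = torch.randn(1, 2, 5, 8)  # [batch_size, num_heads, seq_len, head_dim]
>>> keys, values = layer.update(key, value)  # lazily allocates, then concatenates along the sequence dim
>>> layer.get_seq_length()
5
```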
|
|
"""
|
|
|
|
is_sliding = False
|
|
|
|
def lazy_initialization(self, key_states: torch.Tensor):
|
|
self.dtype, self.device = key_states.dtype, key_states.device
|
|
self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
|
|
self.values = torch.tensor([], dtype=self.dtype, device=self.device)
|
|
self.is_initialized = True
|
|
|
|
def update(
|
|
self,
|
|
key_states: torch.Tensor,
|
|
value_states: torch.Tensor,
|
|
cache_kwargs: Optional[dict[str, Any]] = None,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Update the key and value caches in-place, and return the necessary keys and value states.
|
|
|
|
Args:
|
|
key_states (`torch.Tensor`): The new key states to cache.
|
|
value_states (`torch.Tensor`): The new value states to cache.
|
|
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
|
|
|
|
Returns:
|
|
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
|
|
"""
|
|
# Lazy initialization
|
|
if not self.is_initialized:
|
|
self.lazy_initialization(key_states)
|
|
|
|
self.keys = torch.cat([self.keys, key_states], dim=-2)
|
|
self.values = torch.cat([self.values, value_states], dim=-2)
|
|
return self.keys, self.values
|
|
|
|
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
|
|
"""Return the length and offset of the cache, used to generate the mask"""
|
|
kv_offset = 0
|
|
query_length = cache_position.shape[0]
|
|
kv_length = self.get_seq_length() + query_length
|
|
return kv_length, kv_offset
|
|
|
|
def get_seq_length(self) -> int:
|
|
"""Returns the sequence length of the cached states."""
|
|
if not self.is_initialized or self.keys.numel() == 0:
|
|
return 0
|
|
return self.keys.shape[-2]
|
|
|
|
def get_max_cache_shape(self) -> int:
|
|
"""Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
|
|
return -1
|
|
|
|
def crop(self, max_length: int) -> None:
|
|
"""
|
|
Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative
|
|
to remove `max_length` tokens.
|
|
"""
|
|
if max_length < 0:
|
|
max_length = self.get_seq_length() - abs(max_length)
|
|
|
|
if self.get_seq_length() <= max_length:
|
|
return
|
|
|
|
self.keys = self.keys[..., :max_length, :]
|
|
self.values = self.values[..., :max_length, :]
|
|
|
|
def batch_repeat_interleave(self, repeats: int) -> None:
|
|
"""Repeat the cache `repeats` times in the batch dimension."""
|
|
if self.get_seq_length() > 0:
|
|
self.keys = self.keys.repeat_interleave(repeats, dim=0)
|
|
self.values = self.values.repeat_interleave(repeats, dim=0)
|
|
|
|
def batch_select_indices(self, indices: torch.Tensor) -> None:
|
|
"""Only keep the `indices` in the batch dimension of the cache."""
|
|
if self.get_seq_length() > 0:
|
|
self.keys = self.keys[indices, ...]
|
|
self.values = self.values[indices, ...]
|
|
|
|
|
|
class DynamicSlidingWindowLayer(DynamicLayer):
|
|
"""
|
|
A cache layer that grows dynamically as more tokens are generated, up until the sliding window size.
|
|
It stores the key and value states as tensors of shape `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
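
Example (a minimal sketch with random tensors; `sliding_window=4` and the shapes are illustrative):

```python
>>> import torch
>>> layer = DynamicSlidingWindowLayer(sliding_window=4)
>>> key = value = torch.randn(1, 2, 6, 8)  # 6 new tokens
>>> keys, values = layer.update(key, value)  # the full 6 tokens are returned for this forward pass
>>> layer.keys.shape[-2]  # but only the last `sliding_window - 1` tokens remain cached
3
```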
|
|
"""
|
|
|
|
is_sliding = True
|
|
|
|
def __init__(self, sliding_window: int):
|
|
super().__init__()
|
|
self.sliding_window = sliding_window
|
|
self.cumulative_length = 0
|
|
|
|
def update(
|
|
self,
|
|
key_states: torch.Tensor,
|
|
value_states: torch.Tensor,
|
|
cache_kwargs: Optional[dict[str, Any]] = None,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Update the key and value caches in-place, and return the necessary keys and value states.
|
|
|
|
Args:
|
|
key_states (`torch.Tensor`): The new key states to cache.
|
|
value_states (`torch.Tensor`): The new value states to cache.
|
|
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
|
|
|
|
Returns:
|
|
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
|
|
"""
|
|
# Lazy initialization
|
|
if not self.is_initialized:
|
|
self.lazy_initialization(key_states)
|
|
|
|
self.cumulative_length += key_states.shape[-2]
|
|
|
|
# Compute the full states
|
|
full_key_states = torch.cat([self.keys, key_states], dim=-2)
|
|
full_value_states = torch.cat([self.values, value_states], dim=-2)
|
|
# Only cache the last `self.sliding_window - 1` tokens (or all of them if lower than that)
|
|
self.keys = full_key_states[:, :, -self.sliding_window + 1 :, :]
|
|
self.values = full_value_states[:, :, -self.sliding_window + 1 :, :]
|
|
|
|
# Return the full states
|
|
return full_key_states, full_value_states
|
|
|
|
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
|
|
"""Return the length and offset of the cache, used to generate the attention mask"""
|
|
query_length = cache_position.shape[0]
|
|
is_full = self.cumulative_length >= self.sliding_window
|
|
|
|
kv_offset = max(self.cumulative_length - self.sliding_window + 1, 0)
|
|
if is_full:
|
|
kv_length = self.sliding_window - 1 + query_length
|
|
else:
|
|
kv_length = self.cumulative_length + query_length
|
|
|
|
return kv_length, kv_offset
|
|
|
|
def get_seq_length(self) -> int:
|
|
"""Returns the sequence length of the cached states."""
|
|
return self.cumulative_length
|
|
|
|
def get_max_cache_shape(self) -> int:
|
|
"""Return the maximum cache shape of the cache"""
|
|
return self.sliding_window
|
|
|
|
def crop(self, max_length: int) -> None:
|
|
"""
|
|
Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
|
|
negative to remove `max_length` tokens.
|
|
"""
|
|
if self.get_seq_length() >= self.sliding_window:
|
|
raise ValueError(
"Cannot `crop` a `DynamicSlidingWindowLayer` after it has seen more tokens than its "
"sliding window (otherwise some states are lost)"
)
|
|
super().crop(max_length)
|
|
self.cumulative_length = self.keys.shape[-2]
|
|
|
|
|
|
class StaticLayer(CacheLayerMixin):
|
|
"""
|
|
A static cache layer that stores the key and value states as static tensors of shape `[batch_size, num_heads, max_cache_len, head_dim]`.
|
|
It lazily allocates its full backing tensors, and then mutates them in-place. Built for `torch.compile` support.
|
|
|
|
Args:
|
|
max_cache_len (`int`):
|
|
Maximum number of tokens that can be stored, used for tensor preallocation.
|
|
"""
|
|
|
|
is_compileable = True
|
|
is_sliding = False
|
|
|
|
def __init__(self, max_cache_len: int):
|
|
super().__init__()
|
|
self.max_cache_len = max_cache_len
|
|
|
|
def lazy_initialization(self, key_states: torch.Tensor):
|
|
"""
|
|
Lazy initialization of the keys and values tensors. This allows getting all properties (dtype, device,
num_heads in case of TP, etc...) at runtime directly, which is extremely practical as it avoids moving
devices, dtypes etc. later on for each `update` (which could break the static dynamo addresses as well).

If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
function ahead of time (this is required for `torch.export` for example). Note that for `compile`, as we
internally don't compile the prefill, this is guaranteed to have been called already when compiling.
Compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
i.e. `mode="reduce-overhead"`, is known to fail). It will in general work correctly, but prefill should
not be compiled anyway, for performance reasons!
|
|
"""
|
|
self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
|
|
self.dtype, self.device = key_states.dtype, key_states.device
|
|
|
|
self.keys = torch.zeros(
|
|
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
|
|
dtype=self.dtype,
|
|
device=self.device,
|
|
)
|
|
self.values = torch.zeros(
|
|
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
|
|
dtype=self.dtype,
|
|
device=self.device,
|
|
)
|
|
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
|
|
# breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
|
|
# As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
|
|
# prefill explicitly, but this should be avoided!)
|
|
if not is_torchdynamo_compiling():
|
|
torch._dynamo.mark_static_address(self.keys)
|
|
torch._dynamo.mark_static_address(self.values)
|
|
|
|
self.is_initialized = True
|
|
|
|
def update(
|
|
self,
|
|
key_states: torch.Tensor,
|
|
value_states: torch.Tensor,
|
|
cache_kwargs: Optional[dict[str, Any]] = None,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Update the key and value caches in-place, and return the necessary keys and value states.
|
|
|
|
Args:
|
|
key_states (`torch.Tensor`): The new key states to cache.
|
|
value_states (`torch.Tensor`): The new value states to cache.
|
|
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
|
|
|
|
Returns:
|
|
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
|
|
"""
|
|
# Lazy initialization
|
|
if not self.is_initialized:
|
|
self.lazy_initialization(key_states)
|
|
|
|
# Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
|
|
# in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
|
|
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
|
|
cache_position = (
|
|
cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
|
|
)
|
|
|
|
# Update the cache
|
|
try:
|
|
self.keys.index_copy_(2, cache_position, key_states)
|
|
self.values.index_copy_(2, cache_position, value_states)
|
|
except NotImplementedError:
|
|
# Fallback for devices like MPS where index_copy_ might not be supported.
|
|
self.keys[:, :, cache_position] = key_states
|
|
self.values[:, :, cache_position] = value_states
|
|
return self.keys, self.values
|
|
|
|
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
|
|
"""Return the length and offset of the cache, used to generate the attention mask"""
|
|
kv_offset = 0
|
|
kv_length = self.max_cache_len
|
|
return kv_length, kv_offset
|
|
|
|
def get_seq_length(self) -> int:
|
|
"""Returns the sequence length of the cached states."""
|
|
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
|
|
# limit the check to the first batch member and head dimension.
|
|
return (self.keys[0, 0].any(dim=-1)).sum() if self.is_initialized else 0
|
|
|
|
def get_max_cache_shape(self) -> int:
|
|
"""Return the maximum cache shape of the cache"""
|
|
return self.max_cache_len
|
|
|
|
|
|
class StaticSlidingWindowLayer(StaticLayer):
|
|
"""
|
|
A static cache layer that stores the key and value states as static tensors of shape
|
|
`[batch_size, num_heads, min(max_cache_len, sliding_window), head_dim]`. It lazily allocates its full backing
|
|
tensors, and then mutates them in-place. Built for `torch.compile` support.
|
|
|
|
Args:
|
|
max_cache_len (`int`):
|
|
Maximum number of tokens that can be stored, used for tensor preallocation.
|
|
sliding_window (`int`):
|
|
The size of the sliding window.
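
Example (a minimal sketch; the sizes are illustrative):

```python
>>> layer = StaticSlidingWindowLayer(max_cache_len=16, sliding_window=4)
>>> layer.get_max_cache_shape()  # the backing tensors are capped at min(max_cache_len, sliding_window)
4
```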
|
|
"""
|
|
|
|
is_sliding = True
|
|
|
|
def __init__(self, max_cache_len: int, sliding_window: int):
|
|
effective_max_cache_len = min(sliding_window, max_cache_len)
|
|
super().__init__(max_cache_len=effective_max_cache_len)
|
|
self.cumulative_length = 0
|
|
|
|
def update(
|
|
self,
|
|
key_states: torch.Tensor,
|
|
value_states: torch.Tensor,
|
|
cache_kwargs: Optional[dict[str, Any]] = None,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Update the key and value caches in-place, and return the necessary keys and value states.
|
|
|
|
Args:
|
|
key_states (`torch.Tensor`): The new key states to cache.
|
|
value_states (`torch.Tensor`): The new value states to cache.
|
|
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
|
|
|
|
Returns:
|
|
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
|
|
"""
|
|
# Lazy initialization
|
|
if not self.is_initialized:
|
|
self.lazy_initialization(key_states)
|
|
|
|
# Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
|
|
# in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
|
|
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
|
|
cache_position = (
|
|
cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
|
|
)
|
|
|
|
cumulative_length = self.cumulative_length
|
|
is_full = cumulative_length >= self.max_cache_len
|
|
# Update it now that we saved the value above
|
|
self.cumulative_length += key_states.shape[-2]
|
|
|
|
if is_full:
|
|
# In general, we should use a much simpler `cat` here as well, independently of the states size. However,
|
|
# dynamo is currently bugged when doing it - see https://github.com/pytorch/pytorch/issues/159855 for more details
|
|
if key_states.shape[-2] == 1:
|
|
# Roll all values to the left by 1 position
|
|
new_keys = self.keys.roll(-1, dims=-2)
|
|
new_values = self.values.roll(-1, dims=-2)
|
|
# Overwrite the last position with new states
|
|
# (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855)
|
|
index = torch.tensor([-1], dtype=int, device=self.device)
|
|
new_keys[:, :, index] = key_states
|
|
new_values[:, :, index] = value_states
|
|
|
|
# Copy back into `self` (do not just assign again) in order to keep the static dynamo address
|
|
self.keys.copy_(new_keys)
|
|
self.values.copy_(new_values)
|
|
# Very important to return the `self` tensors here, as they have the static dynamo address
|
|
return self.keys, self.values
|
|
# Already full but using more than 1 new token (e.g. prefill caching, chat continuation, etc...)
|
|
else:
|
|
full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
|
|
full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
|
|
# Not yet full, but becoming full on this update
|
|
elif cumulative_length + key_states.shape[2] > self.max_cache_len:
|
|
# Fast prefill path, no need to cat() in this case, as the cache is currently empty
|
|
if cumulative_length == 0:
|
|
full_key_states = key_states
|
|
full_value_states = value_states
|
|
else:
|
|
full_key_states = torch.cat((self.keys[:, :, :cumulative_length, :], key_states), dim=-2)
|
|
full_value_states = torch.cat((self.values[:, :, :cumulative_length, :], value_states), dim=-2)
|
|
else:
|
|
try:
|
|
self.keys.index_copy_(2, cache_position, key_states)
|
|
self.values.index_copy_(2, cache_position, value_states)
|
|
except NotImplementedError:
|
|
self.keys[:, :, cache_position] = key_states
|
|
self.values[:, :, cache_position] = value_states
|
|
|
|
# Very important to return the `self` tensors here, as they have the static dynamo address
|
|
return self.keys, self.values
|
|
|
|
# We only cache the last `sliding_window` tokens
|
|
self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :])
|
|
self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :])
|
|
# we should return the whole states instead of `self.keys/values` here, as otherwise we lose some context
|
|
return full_key_states, full_value_states
|
|
|
|
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
|
|
"""Return the length and offset of the cache, used to generate the attention mask"""
|
|
query_length = cache_position.shape[0]
|
|
sliding_window = self.max_cache_len
|
|
is_full = self.cumulative_length >= self.max_cache_len
|
|
|
|
kv_offset = max(self.cumulative_length - sliding_window + 1, 0)
|
|
# The cache is already full
|
|
if is_full:
|
|
kv_length = sliding_window + query_length - 1
|
|
# Not yet full, but becoming full on this update
|
|
elif self.cumulative_length + query_length > sliding_window:
|
|
kv_length = self.cumulative_length + query_length
|
|
# Here the Cache is still smaller than the local size, but we return the local size as it's static
|
|
else:
|
|
kv_length = sliding_window
|
|
|
|
return kv_length, kv_offset
|
|
|
|
def get_seq_length(self) -> int:
|
|
"""Returns the sequence length of the cached states."""
|
|
return self.cumulative_length
|
|
|
|
|
|
class QuantizedLayer(DynamicLayer):
|
|
"""
|
|
A quantized layer similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
|
|
It allows the model to generate longer sequence length without allocating too much memory for the key and value caches by
|
|
applying quantization.
|
|
|
|
The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length`
|
|
is set as a maximum capacity for the original precision cache. When the length goes beyond maximum capacity, the original
|
|
precision cache is discarded and moved into the quantized cache. The quantization is done per-channel with a set `q_group_size`
|
|
for both Keys and Values, in contrast to what was described in the paper.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
nbits: int = 4,
|
|
axis_key: int = 0,
|
|
axis_value: int = 0,
|
|
q_group_size: int = 64,
|
|
residual_length: int = 128,
|
|
):
|
|
super().__init__()
|
|
self.nbits = nbits
|
|
self.axis_key = axis_key
|
|
self.axis_value = axis_value
|
|
self.q_group_size = q_group_size
|
|
self.residual_length = residual_length
|
|
self.cumulative_length = 0
|
|
|
|
def update(
|
|
self,
|
|
key_states: torch.Tensor,
|
|
value_states: torch.Tensor,
|
|
cache_kwargs: Optional[dict[str, Any]] = None,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Update the key and value caches in-place, and return the necessary keys and value states.
|
|
|
|
Args:
|
|
key_states (`torch.Tensor`): The new key states to cache.
|
|
value_states (`torch.Tensor`): The new value states to cache.
|
|
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
|
|
|
|
Returns:
|
|
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
|
|
"""
|
|
self.cumulative_length += key_states.shape[-2]
|
|
|
|
# Lazy initialization
|
|
if not self.is_initialized:
|
|
self.lazy_initialization(key_states)
|
|
self._quantized_keys = self._quantize(key_states.contiguous(), axis=self.axis_key)
|
|
self._quantized_values = self._quantize(value_states.contiguous(), axis=self.axis_value)
|
|
return key_states, value_states
|
|
|
|
dequant_keys = self._dequantize(self._quantized_keys)
|
|
dequant_values = self._dequantize(self._quantized_values)
|
|
keys_to_return = torch.cat([dequant_keys, self.keys, key_states], dim=-2)
|
|
values_to_return = torch.cat([dequant_values, self.values, value_states], dim=-2)
|
|
if self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length:
|
|
self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
|
|
self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value)
|
|
self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
|
|
self.values = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
|
|
else:
|
|
self.keys = torch.cat([self.keys, key_states], dim=-2)
|
|
self.values = torch.cat([self.values, value_states], dim=-2)
|
|
|
|
return keys_to_return, values_to_return
|
|
|
|
@abstractmethod
|
|
def _quantize(self, tensor, axis): ...
|
|
|
|
@abstractmethod
|
|
def _dequantize(self, q_tensor): ...
|
|
|
|
def get_seq_length(self) -> int:
|
|
"""Returns the sequence length of the cached states."""
|
|
return self.cumulative_length
|
|
|
|
|
|
class QuantoQuantizedLayer(QuantizedLayer):
|
|
def __init__(
|
|
self,
|
|
nbits: int = 4,
|
|
axis_key: int = 0,
|
|
axis_value: int = 0,
|
|
q_group_size: int = 64,
|
|
residual_length: int = 128,
|
|
):
|
|
super().__init__(
|
|
nbits=nbits,
|
|
axis_key=axis_key,
|
|
axis_value=axis_value,
|
|
q_group_size=q_group_size,
|
|
residual_length=residual_length,
|
|
)
|
|
|
|
# We need to import quanto here to avoid circular imports due to optimum/quanto/models/transformers_models.py
|
|
if is_quanto_greater("0.2.5", accept_dev=True):
|
|
from optimum.quanto import MaxOptimizer, qint2, qint4
|
|
else:
|
|
raise ImportError(
"You need the `optimum-quanto` package version to be greater than or equal to 0.2.5 to use `QuantoQuantizedCache`."
)
|
|
|
|
if self.nbits not in [2, 4]:
|
|
raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
|
|
|
|
if self.axis_key not in [0, -1]:
|
|
raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")
|
|
|
|
if self.axis_value not in [0, -1]:
|
|
raise ValueError(
|
|
f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
|
|
)
|
|
|
|
self.qtype = qint4 if self.nbits == 4 else qint2
|
|
self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization
|
|
|
|
def _quantize(self, tensor, axis):
|
|
from optimum.quanto import quantize_weight
|
|
|
|
scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
|
|
qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
|
|
return qtensor
|
|
|
|
def _dequantize(self, qtensor):
|
|
return qtensor.dequantize()
|
|
|
|
|
|
class HQQQuantizedLayer(QuantizedLayer):
|
|
def __init__(
|
|
self,
|
|
nbits: int = 4,
|
|
axis_key: int = 0,
|
|
axis_value: int = 0,
|
|
q_group_size: int = 64,
|
|
residual_length: int = 128,
|
|
):
|
|
super().__init__(
|
|
nbits=nbits,
|
|
axis_key=axis_key,
|
|
axis_value=axis_value,
|
|
q_group_size=q_group_size,
|
|
residual_length=residual_length,
|
|
)
|
|
|
|
if not is_hqq_available():
|
|
raise ImportError("You need to install `hqq` to use `HQQQuantizedLayer`")
|
|
|
|
if self.nbits not in [1, 2, 3, 4, 8]:
|
|
raise ValueError(
|
|
f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
|
|
)
|
|
|
|
if self.axis_key not in [0, 1]:
|
|
raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")
|
|
|
|
if self.axis_value not in [0, 1]:
|
|
raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")
|
|
|
|
self.quantizer = HQQQuantizer
|
|
|
|
def _quantize(self, tensor, axis):
|
|
qtensor, meta = self.quantizer.quantize(
|
|
tensor,
|
|
axis=axis,
|
|
device=self.keys.device,
|
|
compute_dtype=self.keys.dtype,
|
|
nbits=self.nbits,
|
|
group_size=self.q_group_size,
|
|
)
|
|
meta["compute_dtype"] = self.keys.dtype
|
|
self.quantizer.cuda(qtensor, meta=meta, device=self.keys.device) # Move to device and cast to dtype
|
|
meta["scale"] = meta["scale"].to(qtensor.device)
|
|
meta["zero"] = meta["zero"].to(qtensor.device)
|
|
return qtensor, meta
|
|
|
|
def _dequantize(self, qtensor):
|
|
quant_tensor, meta = qtensor
|
|
tensor = self.quantizer.dequantize(quant_tensor, meta)
|
|
return tensor
|
|
|
|
|
|
class Cache:
|
|
"""
|
|
A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for
|
|
the Cache of each layer.
|
|
|
|
Args:
|
|
layers (`list[CacheLayerMixin]`, *optional*):
|
|
A list of pre-created `CacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate` will
|
|
be used.
|
|
layer_class_to_replicate (`type[CacheLayerMixin]`, *optional*):
|
|
Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each layer,
|
|
and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater than the current
|
|
list of layers.
|
|
offloading (`bool`, *optional*, defaults to `False`):
|
|
Whether to perform offloading of the layers to `cpu`, to save GPU memory.
|
|
offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
|
|
If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
|
|
usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
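
Example (a minimal sketch of the two construction modes; the layer count and cache size are illustrative,
and the import path assumes this module is exposed as `transformers.cache_utils`):

```python
>>> from transformers.cache_utils import Cache, DynamicLayer, StaticLayer

>>> # Explicit layers: one pre-built `CacheLayerMixin` per model layer
>>> cache = Cache(layers=[StaticLayer(max_cache_len=128) for _ in range(2)])

>>> # Lazy replication: a new `DynamicLayer` is appended whenever `update` sees a new `layer_idx`
>>> cache = Cache(layer_class_to_replicate=DynamicLayer)
```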
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
layers: Optional[list[CacheLayerMixin]] = None,
|
|
layer_class_to_replicate: Optional[type[CacheLayerMixin]] = None,
|
|
offloading: bool = False,
|
|
offload_only_non_sliding: bool = True,
|
|
):
|
|
if layers is not None and layer_class_to_replicate is not None:
|
|
raise ValueError(
|
|
"You can construct a Cache either from a list `layers` of all the predefined `CacheLayer`, or from a "
|
|
"`layer_class_to_replicate`, in which case the Cache will append a new layer corresponding to "
|
|
"`layer_class_to_replicate` for each new call to `update` with an idx not already in the Cache."
|
|
)
|
|
if layers is None and layer_class_to_replicate is None:
|
|
raise ValueError(
|
|
"You should provide exactly one of `layers` or `layer_class_to_replicate` to initialize a Cache."
|
|
)
|
|
self.layers = layers if layers is not None else []
|
|
self.layer_class_to_replicate = layer_class_to_replicate
|
|
self.offloading = offloading
|
|
if self.offloading:
|
|
self.only_non_sliding = offload_only_non_sliding
|
|
self.prefetch_stream = torch.Stream() if _is_torch_greater_or_equal_than_2_7 else torch.cuda.Stream()
|
|
|
|
def __repr__(self):
|
|
return f"{self.__class__.__name__}(layers={self.layers})"
|
|
|
|
def prefetch(self, layer_idx: int, only_non_sliding: bool = True):
|
|
"""
|
|
Prefetch a given layer on its device. If `only_non_sliding` is True, it will try to prefetch only the layers
|
|
which are non-sliding. If the `layer_idx` is outside the range, this will circle back to the first layers.
|
|
Note that we use a non-default stream for this, to avoid blocking.
|
|
"""
|
|
if only_non_sliding:
|
|
# Try to find next non-sliding, starting at `layer_idx`
|
|
try:
|
|
layer_idx = layer_idx + self.is_sliding[layer_idx:].index(False)
|
|
# In this case, we need to circle back to the beginning
|
|
except ValueError:
|
|
layer_idx = self.is_sliding.index(False)
|
|
else:
|
|
layer_idx = layer_idx if layer_idx < len(self.layers) else 0
|
|
|
|
# Prefetch
|
|
with self.prefetch_stream if _is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(self.prefetch_stream):
|
|
self.layers[layer_idx].prefetch()
|
|
|
|
def offload(self, layer_idx: int, only_non_sliding: bool = True):
|
|
"""
|
|
Offload a given `layer_idx`. If `only_non_sliding` is True, it will offload `layer_idx` only if it is a
|
|
non-sliding layer. Note that we do it on the default stream, so that we ensure all earlier
|
|
computation in the layer's `update` methods are finished.
|
|
"""
|
|
if not (only_non_sliding and self.is_sliding[layer_idx]):
|
|
self.layers[layer_idx].offload()
|
|
|
|
def update(
|
|
self,
|
|
key_states: torch.Tensor,
|
|
value_states: torch.Tensor,
|
|
layer_idx: int,
|
|
cache_kwargs: Optional[dict[str, Any]] = None,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
|
|
|
Parameters:
|
|
key_states (`torch.Tensor`):
|
|
The new key states to cache.
|
|
value_states (`torch.Tensor`):
|
|
The new value states to cache.
|
|
layer_idx (`int`):
|
|
The index of the layer to cache the states for.
|
|
cache_kwargs (`dict[str, Any]`, *optional*):
|
|
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
|
|
cache to be created.
|
|
|
|
Return:
|
|
A tuple containing the updated key and value states.
|
|
"""
|
|
# In this case, the `layers` were not provided, and we must append as much as `layer_idx`
|
|
if self.layer_class_to_replicate is not None:
|
|
while len(self.layers) <= layer_idx:
|
|
self.layers.append(self.layer_class_to_replicate())
|
|
|
|
if self.offloading:
|
|
# Wait for the stream to finish if needed, and start prefetching the next layer
|
|
torch.cuda.default_stream(key_states.device).wait_stream(self.prefetch_stream)
|
|
self.prefetch(layer_idx + 1, self.only_non_sliding)
|
|
|
|
keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
|
|
|
|
if self.offloading:
|
|
self.offload(layer_idx, self.only_non_sliding)
|
|
|
|
return keys, values
|
|
|
|
def early_initialization(
|
|
self, batch_size: int, num_heads: int, head_dim: int, dtype: torch.dtype, device: torch.device
|
|
):
|
|
"""
|
|
Initialize all the layers in advance (it's otherwise lazily initialized on the first `update` call).
|
|
This is useful for our `export` recipes, as `export` needs everything in advance.
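
Example (a hedged sketch; `model` stands for an already-loaded model, and the batch size, head count and
head dim are illustrative and must match the model that will consume the cache):

```python
>>> cache = StaticCache(config=model.config, max_cache_len=256)
>>> cache.early_initialization(batch_size=1, num_heads=8, head_dim=64, dtype=torch.float16, device="cpu")
>>> cache.is_initialized
True
```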
|
|
"""
|
|
# Note that the initialization needs all dimensions (except -2), as well as device and dtype, so we use
|
|
# this fake tensor approach. It has size 0 on the -2 dimension, so it does not allocate any data (it only
|
|
# creates an empty tensor with correct shape, dtype and device), which is very efficient and practical
|
|
fake_keys_tensor = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=dtype, device=device)
|
|
# Init all layers
|
|
for layer in self.layers:
|
|
layer.lazy_initialization(fake_keys_tensor)
|
|
|
|
def get_seq_length(self, layer_idx: int = 0) -> int:
|
|
"""Returns the sequence length of the cache for the given layer."""
|
|
if layer_idx >= len(self.layers):
|
|
return 0
|
|
return self.layers[layer_idx].get_seq_length()
|
|
|
|
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
|
"""
|
|
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
|
|
the given layer at `layer_idx`.
|
|
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
|
|
"""
|
|
# For DynamicCache, where the layers are created at runtime -> if it was not yet created, the size is
|
|
# simply the shape of `cache_position`
|
|
if layer_idx >= len(self.layers):
|
|
return cache_position.shape[0], 0
|
|
return self.layers[layer_idx].get_mask_sizes(cache_position)
|
|
|
|
def get_max_cache_shape(self, layer_idx: int = 0) -> int:
|
|
"""Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length."""
|
|
# For DynamicCache, where the layers are created at runtime -> if it was not yet created, return -1
|
|
# as DynamicLayer does
|
|
if layer_idx >= len(self.layers):
|
|
return -1
|
|
return self.layers[layer_idx].get_max_cache_shape()
|
|
|
|
def reset(self):
|
|
"""Recursively reset all layers tensors"""
|
|
for layer_idx in range(len(self.layers)):
|
|
self.layers[layer_idx].reset()
|
|
|
|
def reorder_cache(self, beam_idx: torch.LongTensor):
|
|
"""Reorder the cache for beam search"""
|
|
for layer_idx in range(len(self.layers)):
|
|
self.layers[layer_idx].reorder_cache(beam_idx)
|
|
|
|
def crop(self, max_length: int):
|
|
"""Crop the cache to the given length"""
|
|
for layer_idx in range(len(self.layers)):
|
|
self.layers[layer_idx].crop(max_length)
|
|
|
|
def batch_repeat_interleave(self, repeats: int):
|
|
"""Repeat and interleave the cache"""
|
|
for layer_idx in range(len(self.layers)):
|
|
self.layers[layer_idx].batch_repeat_interleave(repeats)
|
|
|
|
def batch_select_indices(self, indices: torch.Tensor):
|
|
"""Select indices from the cache"""
|
|
for layer_idx in range(len(self.layers)):
|
|
self.layers[layer_idx].batch_select_indices(indices)
|
|
|
|
@property
|
|
def max_batch_size(self) -> int:
|
|
"""Return the maximum batch size of the cache"""
|
|
values = [layer.max_batch_size for layer in self.layers]
|
|
if len(set(values)) > 1:
|
|
raise ValueError(f"Max batch size is not consistent across layers: {values}")
|
|
return values[0]
|
|
|
|
@property
|
|
def max_cache_len(self) -> int:
|
|
"""Return the maximum cache length of the cache"""
|
|
values = [layer.max_cache_len for layer in self.layers]
|
|
return max(values)
|
|
|
|
@property
|
|
def is_compileable(self) -> bool:
|
|
"""Return whether the cache is compileable"""
|
|
# For DynamicCache dispatching the layers lazily (otherwise, all([]) is True)
|
|
if len(self.layers) == 0:
|
|
return False
|
|
return all(layer.is_compileable for layer in self.layers)
|
|
|
|
@property
|
|
def is_initialized(self) -> bool:
|
|
"""Return whether the cache data is initialized"""
|
|
return len(self.layers) > 0 and all(layer.is_initialized for layer in self.layers)
|
|
|
|
@property
|
|
def is_sliding(self) -> list[bool]:
|
|
"""Return whether the layers of the cache are sliding window"""
|
|
return [getattr(layer, "is_sliding", False) for layer in self.layers]
|
|
|
|
def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the
|
|
sequence length.
|
|
"""
|
|
if layer_idx < len(self.layers):
|
|
return self.layers[layer_idx].keys, self.layers[layer_idx].values
|
|
else:
|
|
raise KeyError(
|
|
f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}"
|
|
)
|
|
|
|
def __iter__(self):
|
|
"""
|
|
Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over
|
|
keys and values
|
|
"""
|
|
for layer_idx in range(len(self)):
|
|
yield (self.layers[layer_idx].keys, self.layers[layer_idx].values)
|
|
|
|
def __len__(self):
|
|
"""
|
|
This value corresponds to the number of layers in the model.
|
|
"""
|
|
# Note: for DynamicCache, layers are initialized lazily, so this will not be accurate before the first
|
|
# forward through all the layers
|
|
return len(self.layers)
|
|
|
|
|
|
class DynamicCache(Cache):
|
|
"""
|
|
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
|
|
It stores the key and value states as a list of `CacheLayer`, one for each layer. The expected shape for each tensor
|
|
in the `CacheLayer`s is `[batch_size, num_heads, seq_len, head_dim]`.
|
|
If a config is passed, it will additionally check for sliding or hybrid cache structure, greatly reducing the
|
|
memory requirement of the cached tensors to `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
|
|
|
|
See `Cache` for details on common methods that are implemented by all cache classes.
|
|
|
|
Args:
|
|
ddp_cache_data (`Iterable[tuple[torch.Tensor, torch.Tensor]]`, *optional*):
|
|
It was originally added for compatibility with `torch.distributed` (DDP). In a nutshell, it is
|
|
`map(gather_map, zip(*caches))`, i.e. each item in the iterable contains the key and value states
|
|
for a layer gathered across replicas by torch.distributed (shape=[global batch size, num_heads, seq_len, head_dim]).
|
|
Note: it needs to be the 1st arg as well to work correctly
|
|
config (`PretrainedConfig`, *optional*):
|
|
The config of the model for which this Cache will be used. If passed, it will be used to check for sliding
|
|
or hybrid layer structure, greatly reducing the memory requirement of the cached tensors to
|
|
`[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
|
|
offloading (`bool`, *optional*, defaults to `False`):
|
|
Whether to perform offloading of the layers to `cpu`, to save GPU memory.
|
|
offload_only_non_sliding (`bool`, *optional*, defaults to `False`):
|
|
If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
|
|
usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
|
|
|
|
Example:
|
|
|
|
```python
|
|
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
|
|
|
|
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
|
|
|
>>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
|
|
|
|
>>> # Prepare a cache class and pass it to model's forward
|
|
>>> past_key_values = DynamicCache(config=model.config)
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
>>> outputs.past_key_values # access cache filled with key/values from generation
|
|
```
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
|
|
config: Optional[PretrainedConfig] = None,
|
|
offloading: bool = False,
|
|
offload_only_non_sliding: bool = False,
|
|
):
|
|
layers = []
|
|
# If a config is passed, use it to infer the layer types and initialize accordingly
|
|
if config is not None:
|
|
decoder_config = config.get_text_config(decoder=True)
|
|
sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
|
|
decoder_config, "attention_chunk_size", None
|
|
)
|
|
layer_types = getattr(decoder_config, "layer_types", None)
|
|
if layer_types is None:
|
|
layer_types = [
|
|
"sliding_attention" if sliding_window is not None else "full_attention"
|
|
for _ in range(decoder_config.num_hidden_layers)
|
|
]
|
|
# Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
|
|
if hasattr(decoder_config, "num_kv_shared_layers"):
|
|
layer_types = layer_types[: -decoder_config.num_kv_shared_layers]
|
|
|
|
for layer_type in layer_types:
|
|
# From a cache point of view, both sliding and chunked are the same in how they should behave and how many
|
|
# states they should return - only the mask changes to make them different at the end!
|
|
if layer_type in ("sliding_attention", "chunked_attention"):
|
|
layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
|
|
else:
|
|
layers.append(DynamicLayer())
|
|
|
|
# In this case, use the passed data to already fill in the Cache
|
|
if ddp_cache_data is not None:
|
|
# Init all the layers with the data
|
|
for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data):
|
|
# If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data
|
|
if config is None:
|
|
layers.append(DynamicLayer())
|
|
# Update the layer with the data
|
|
_, _ = layers[layer_idx].update(key_states, value_states)
|
|
|
|
# If neither of config nor ddp_data was passed, then simply lazy init a full cache of DynamicLayer
|
|
if len(layers) == 0:
|
|
super().__init__(
|
|
layer_class_to_replicate=DynamicLayer,
|
|
offloading=offloading,
|
|
offload_only_non_sliding=offload_only_non_sliding,
|
|
)
|
|
else:
|
|
super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
|
|
|
|
def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
|
|
"""
|
|
Converts the `Cache` instance into its equivalent in the legacy cache format. Used for
|
|
backward compatibility.
|
|
"""
|
|
legacy_cache = ()
|
|
for layer in self.layers:
|
|
legacy_cache += ((layer.keys, layer.values),)
|
|
return legacy_cache
|
|
|
|
@classmethod
|
|
def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tensor]]) -> "DynamicCache":
|
|
"""
|
|
Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
|
|
backward compatibility.
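
Example (sketch; `past_key_values` stands for a legacy cache, i.e. a tuple containing one
`(key_states, value_states)` pair per layer):

```python
>>> cache = DynamicCache.from_legacy_cache(past_key_values)
>>> legacy_format = cache.to_legacy_cache()
```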
|
|
"""
|
|
cache = cls()
|
|
if past_key_values is None:
|
|
logger.warning_once("past_key_values should not be None in from_legacy_cache()")
|
|
if past_key_values is not None:
|
|
for layer_idx in range(len(past_key_values)):
|
|
key_states, value_states = past_key_values[layer_idx]
|
|
cache.update(key_states, value_states, layer_idx)
|
|
return cache
|
|
|
|
|
|
class StaticCache(Cache):
|
|
"""
|
|
Static Cache class to be used with `torch.compile(model)` and `torch.export()`. It will check the `config`
|
|
for potential hybrid cache structure, and initialize each layer accordingly.
|
|
|
|
See `Cache` for details on common methods that are implemented by all cache classes.
|
|
|
|
Args:
|
|
config (`PretrainedConfig`):
|
|
The config of the model for which this Cache will be used. It will be used to check for sliding
|
|
or hybrid layer structure, and initialize each layer accordingly.
|
|
max_cache_len (`int`):
|
|
The maximum number of tokens that this Cache should hold.
|
|
offloading (`bool`, *optional*, defaults to `False`):
|
|
Whether to perform offloading of the layers to `cpu`, to save GPU memory.
|
|
offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
|
|
If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
|
|
usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
|
|
|
|
Example:
|
|
|
|
```python
|
|
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
|
|
|
|
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
|
|
|
>>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")
|
|
|
|
>>> # Prepare a cache class and pass it to model's forward
|
|
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
|
>>> past_key_values = StaticCache(config=model.config, max_cache_len=max_generated_length)
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
>>> outputs.past_key_values # access cache filled with key/values from generation
|
|
StaticCache()
|
|
```
|
|
"""
|
|
|
|
# Pass in kwargs as well to avoid crashing for BC (it used more arguments before)
|
|
def __init__(
|
|
self,
|
|
config: PretrainedConfig,
|
|
max_cache_len: int,
|
|
offloading: bool = False,
|
|
offload_only_non_sliding: bool = True,
|
|
**kwargs,
|
|
):
|
|
config = config.get_text_config(decoder=True)
|
|
layer_types = getattr(config, "layer_types", None)
|
|
# If `layer_types` is not explicitly provided, infer if the model is fully sliding
|
|
if layer_types is None:
|
|
if getattr(config, "sliding_window", None) is not None:
|
|
layer_types = ["sliding_attention" for _ in range(config.num_hidden_layers)]
|
|
elif getattr(config, "attention_chunk_size", None) is not None:
|
|
layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)]
|
|
else:
|
|
layer_types = ["full_attention" for _ in range(config.num_hidden_layers)]
|
|
# Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
|
|
if hasattr(config, "num_kv_shared_layers"):
|
|
layer_types = layer_types[: -config.num_kv_shared_layers]
|
|
|
|
layers = []
|
|
for layer_type in layer_types:
|
|
if layer_type == "sliding_attention":
|
|
layer = StaticSlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
|
|
elif layer_type == "chunked_attention":
|
|
# From a cache point of view, both sliding and chunked are the same in how they should behave and how many
|
|
# states they should return - only the mask changes to make them different at the end!
|
|
layer = StaticSlidingWindowLayer(
|
|
max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size
|
|
)
|
|
else:
|
|
layer = StaticLayer(max_cache_len=max_cache_len)
|
|
layers.append(layer)
|
|
|
|
super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
|
|
|
|
|
|
class QuantizedCache(Cache):
|
|
"""
|
|
A quantized cache similar to what is described in the
|
|
[KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
|
|
It allows the model to generate longer sequence length without allocating too much memory for keys and values
|
|
by applying quantization.
|
|
The cache has two types of storage, one for original precision and one for the
|
|
quantized cache. A `residual length` is set as a maximum capacity for the original precision cache. When the
|
|
length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache.
|
|
The quantization is done per-channel with a set `q_group_size` for both keys and values, in contrast to what was
|
|
described in the paper.
|
|
|
|
See `Cache` for details on common methods that are implemented by all cache classes.
|
|
|
|
Args:
|
|
backend (`str`):
|
|
The quantization backend to use. One of `("quanto", "hqq")`.
|
|
config (`PretrainedConfig`):
|
|
The config of the model for which this Cache will be used.
|
|
nbits (`int`, *optional*, defaults to 4):
|
|
The number of bits for quantization.
|
|
axis_key (`int`, *optional*, defaults to 0):
|
|
The axis on which to quantize the keys.
|
|
axis_value (`int`, *optional*, defaults to 0):
|
|
The axis on which to quantize the values.
|
|
q_group_size (`int`, *optional*, defaults to 64):
|
|
Quantization is done per-channel according to a set `q_group_size` for both keys and values.
|
|
residual_length (`int`, *optional*, defaults to 128):
|
|
Maximum capacity for the original precision cache.
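
Example (a minimal sketch; the checkpoint is illustrative, the chosen backend package, here `optimum-quanto`,
must be installed, and `QuantizedCache` is assumed to be exported at the top level like the other cache classes):

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantizedCache

>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
>>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

>>> past_key_values = QuantizedCache(backend="quanto", config=model.config, nbits=4)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
```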
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
backend: str,
|
|
config: PretrainedConfig,
|
|
nbits: int = 4,
|
|
axis_key: int = 0,
|
|
axis_value: int = 0,
|
|
q_group_size: int = 64,
|
|
residual_length: int = 128,
|
|
):
|
|
if backend == "quanto":
|
|
layer_class = QuantoQuantizedLayer
|
|
elif backend == "hqq":
|
|
layer_class = HQQQuantizedLayer
|
|
else:
|
|
raise ValueError(f"Unknown quantization backend `{backend}`")
|
|
|
|
config = config.get_text_config(decoder=True)
|
|
layers = [
|
|
layer_class(nbits, axis_key, axis_value, q_group_size, residual_length)
|
|
for _ in range(config.num_hidden_layers)
|
|
]
|
|
super().__init__(layers=layers)
|
|
|
|
|
|
class EncoderDecoderCache(Cache):
|
|
"""
|
|
Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
|
|
cross-attention caches.
|
|
|
|
See `Cache` for details on common methods that are implemented by all cache classes.
|
|
|
|
Args:
|
|
caches (`Iterable`):
|
|
Usually an iterable of length 2, containing 2 `Cache` objects, the first one for self-attention, the
|
|
second one for cross-attention. Can optionally also be an iterable of length 1, containing a
|
|
`tuple[tuple[torch.Tensor]]` (usually used for compatibility with torch dp and ddp).
|
|
|
|
Example:
|
|
|
|
```python
|
|
>>> from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, DynamicCache, EncoderDecoderCache
|
|
|
|
>>> model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small")
|
|
>>> processor = AutoProcessor.from_pretrained("openai/whisper-small")
|
|
|
|
>>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")
|
|
|
|
>>> # Prepare cache classes for encoder and decoder and pass it to model's forward
|
|
>>> self_attention_cache = DynamicCache(config=model.config)
>>> cross_attention_cache = DynamicCache(config=model.config)
|
|
>>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
|
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
|
>>> outputs.past_key_values # access cache filled with key/values from generation
|
|
EncoderDecoderCache()
|
|
```
|
|
"""
|
|
|
|
def __init__(self, *caches) -> None:
|
|
# For dp and ddp support, if only one argument is passed, it should be an iterable of tuples of tensors
|
|
if len(caches) == 1:
|
|
self.self_attention_cache = DynamicCache()
|
|
self.cross_attention_cache = DynamicCache()
|
|
# Populate cache from the iterable
|
|
for layer_idx, key_value_states in enumerate(caches[0]):
|
|
key_states, value_states = key_value_states[:2]
|
|
self.self_attention_cache.update(key_states, value_states, layer_idx)
|
|
if len(key_value_states) > 2:
|
|
key_states, value_states = key_value_states[2:]
|
|
self.cross_attention_cache.update(key_states, value_states, layer_idx)
|
|
# Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache
|
|
elif len(caches) == 2:
|
|
if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
|
|
raise TypeError(f"One of the two arguments is not a Cache: {type(caches[0]) = }, {type(caches[1]) = }")
|
|
self.self_attention_cache = caches[0]
|
|
self.cross_attention_cache = caches[1]
|
|
# Error case
|
|
else:
|
|
raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}")
|
|
|
|
self.is_updated = {}
|
|
for layer_idx in range(len(self.cross_attention_cache)):
|
|
self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)
|
|
|
|
def __repr__(self) -> str:
|
|
return (
|
|
f"{self.__class__.__name__}(self_attention_cache={self.self_attention_cache}, cross_attention_cache="
|
|
f"{self.cross_attention_cache})"
|
|
)
|
|
|
|
def __iter__(self):
|
|
"""
|
|
Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over
|
|
keys and values
|
|
"""
|
|
for layer_idx in range(len(self)):
|
|
yield (
|
|
self.self_attention_cache.layers[layer_idx].keys,
|
|
self.self_attention_cache.layers[layer_idx].values,
|
|
self.cross_attention_cache.layers[layer_idx].keys,
|
|
self.cross_attention_cache.layers[layer_idx].values,
|
|
)
|
|
|
|
def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the
|
|
sequence length.
|
|
"""
|
|
if layer_idx < len(self):
|
|
return (
|
|
self.self_attention_cache.layers[layer_idx].keys,
|
|
self.self_attention_cache.layers[layer_idx].values,
|
|
self.cross_attention_cache.layers[layer_idx].keys,
|
|
self.cross_attention_cache.layers[layer_idx].values,
|
|
)
|
|
else:
|
|
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
|
|
|
|
def __len__(self):
|
|
"""
|
|
Support for backwards-compatible `past_key_values` length, e.g. `len(past_key_values)`. This value corresponds
|
|
to the number of layers in the model.
|
|
"""
|
|
return len(self.self_attention_cache)
|
|
|
|
def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]:
|
|
"""Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
|
|
legacy_cache = ()
|
|
if len(self.cross_attention_cache) > 0:
|
|
for self_attn, cross_attn in zip(
|
|
self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
|
|
):
|
|
legacy_cache += (self_attn + cross_attn,)
|
|
else:
|
|
legacy_cache = self.self_attention_cache.to_legacy_cache()
|
|
return legacy_cache
|
|
|
|
@classmethod
|
|
def from_legacy_cache(
|
|
cls, past_key_values: Optional[Iterable[tuple[torch.FloatTensor, ...]]]
|
|
) -> "EncoderDecoderCache":
|
|
"""Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
|
|
cache = cls(DynamicCache(), DynamicCache())
|
|
if past_key_values is None:
|
|
logger.warning_once("past_key_values should not be None in from_legacy_cache()")
|
|
else:
|
|
for layer_idx, key_value_states in enumerate(past_key_values):
|
|
key_states, value_states = key_value_states[:2]
|
|
cache.self_attention_cache.update(key_states, value_states, layer_idx)
|
|
if len(key_value_states) > 2:
|
|
key_states, value_states = key_value_states[2:]
|
|
cache.cross_attention_cache.update(key_states, value_states, layer_idx)
|
|
cache.is_updated[layer_idx] = True
|
|
return cache
|
|
|
|
def get_seq_length(self, layer_idx: int = 0) -> int:
|
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
|
return self.self_attention_cache.get_seq_length(layer_idx)
|
|
|
|
def reset(self):
|
|
self.self_attention_cache.reset()
|
|
self.cross_attention_cache.reset()
|
|
for layer_idx in self.is_updated:
|
|
self.is_updated[layer_idx] = False
|
|
|
|
def reorder_cache(self, beam_idx: torch.LongTensor):
|
|
"""Reorders the cache for beam search, given the selected beam indices."""
|
|
self.self_attention_cache.reorder_cache(beam_idx)
|
|
self.cross_attention_cache.reorder_cache(beam_idx)
|
|
|
|
def check_dynamic_cache(self, method: str):
|
|
if not (
|
|
isinstance(self.self_attention_cache, DynamicCache)
|
|
and isinstance(self.cross_attention_cache, DynamicCache)
|
|
):
|
|
raise ValueError(
|
|
f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
|
|
f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
|
|
)
|
|
|
|
# TODO(gante, sanchit-gandhi): move following functionality into `.generate`
|
|
def crop(self, maximum_length: int):
|
|
"""
|
|
Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
|
|
negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search (on the Hub).
|
|
"""
|
|
self.check_dynamic_cache(self.crop.__name__)
|
|
self.self_attention_cache.crop(maximum_length)
|
|
|
|
def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]":
|
|
"""
|
|
Split the current instance into a list of `EncoderDecoderCache` by the batch size. This will be used by
|
|
`_split_model_inputs()` in `generation.utils`
|
|
"""
|
|
self.check_dynamic_cache(self.batch_split.__name__)
|
|
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
|
|
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
|
|
|
|
out = []
|
|
for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
|
|
out.append(EncoderDecoderCache(self_attn, cross_attn))
|
|
return out
|
|
|
|
def batch_repeat_interleave(self, repeats: int):
|
|
"""Repeat the cache `repeats` times in the batch dimension. Used in contrastive search (on the Hub)."""
|
|
self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
|
|
self.self_attention_cache.batch_repeat_interleave(repeats)
|
|
self.cross_attention_cache.batch_repeat_interleave(repeats)
|
|
|
|
def batch_select_indices(self, indices: torch.Tensor):
|
|
"""Only keep the `indices` in the batch dimension of the cache. Used in contrastive search (on the Hub)."""
|
|
self.check_dynamic_cache(self.batch_select_indices.__name__)
|
|
self.self_attention_cache.batch_select_indices(indices)
|
|
self.cross_attention_cache.batch_select_indices(indices)
|
|
|
|
def get_max_cache_shape(self) -> int:
|
|
"""Returns the maximum sequence length (i.e. max capacity) of the cache object"""
|
|
return self.self_attention_cache.get_max_cache_shape()
|
|
|
|
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
|
return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx)
|
|
|
|
@property
|
|
def is_sliding(self):
|
|
return self.self_attention_cache.is_sliding
|
|
|
|
@property
|
|
def is_compileable(self) -> bool:
|
|
return self.self_attention_cache.is_compileable
|
|
|
|
|
|
### Deprecated classes
|
|
|
|
|
|
class SlidingWindowLayer(StaticSlidingWindowLayer):
|
|
def __init__(self, max_cache_len: int, sliding_window: int):
|
|
logger.warning_once(
"`SlidingWindowLayer` is deprecated and will be removed in version v4.59. "
"Use `StaticSlidingWindowLayer` instead, which is a better name for it."
)
|
|
super().__init__(max_cache_len, sliding_window)
|
|
|
|
|
|
class ChunkedSlidingLayer(StaticSlidingWindowLayer):
|
|
def __init__(self, max_cache_len: int, sliding_window: int):
|
|
logger.warning_once(
"`ChunkedSlidingLayer` is deprecated and will be removed in version v4.59. "
"Use `StaticSlidingWindowLayer` instead, which has the exact same functionalities."
)
|
|
super().__init__(max_cache_len, sliding_window)
|
|
|
|
|
|
class OffloadedCache(DynamicCache):
|
|
def __init__(self) -> None:
|
|
logger.warning_once(
"`OffloadedCache` is deprecated and will be removed in version v4.59. "
"Use `DynamicCache(offloading=True)` instead."
)
|
|
super().__init__(offloading=True)
|
|
|
|
|
|
class OffloadedStaticCache(StaticCache):
|
|
def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
|
|
logger.warning_once(
"`OffloadedStaticCache` is deprecated and will be removed in version v4.59. "
"Use `StaticCache(..., offloading=True)` instead."
)
|
|
super().__init__(config=config, max_cache_len=max_cache_len, offloading=True)
|
|
|
|
|
|
class SlidingWindowCache(StaticCache):
|
|
def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
|
|
logger.warning_once(
"`SlidingWindowCache` is deprecated and will be removed in version v4.59. "
"Use `StaticCache(...)` instead, which will correctly infer the type of each layer."
)
|
|
super().__init__(config=config, max_cache_len=max_cache_len)
|
|
|
|
|
|
class HybridCache(StaticCache):
|
|
def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
|
|
logger.warning_once(
"`HybridCache` is deprecated and will be removed in version v4.59. "
"Use `StaticCache(...)` instead, which will correctly infer the type of each layer."
)
|
|
super().__init__(config=config, max_cache_len=max_cache_len)
|
|
|
|
|
|
class HybridChunkedCache(StaticCache):
|
|
def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
|
|
logger.warning_once(
"`HybridChunkedCache` is deprecated and will be removed in version v4.59. "
"Use `StaticCache(...)` instead, which will correctly infer the type of each layer."
)
|
|
super().__init__(config=config, max_cache_len=max_cache_len)
|
|
|
|
|
|
class OffloadedHybridCache(StaticCache):
|
|
def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
|
|
logger.warning_once(
"`OffloadedHybridCache` is deprecated and will be removed in version v4.59. "
"Use `StaticCache(..., offloading=True)` instead, which will correctly infer the type of each layer."
)
|
|
super().__init__(config=config, max_cache_len=max_cache_len, offloading=True)
|
|
|
|
|
|
class QuantoQuantizedCache(QuantizedCache):
|
|
def __init__(
|
|
self,
|
|
config: PretrainedConfig,
|
|
nbits: int = 4,
|
|
axis_key: int = 0,
|
|
axis_value: int = 0,
|
|
q_group_size: int = 64,
|
|
residual_length: int = 128,
|
|
):
|
|
logger.warning_once(
"`QuantoQuantizedCache` is deprecated and will be removed in version v4.59. "
"Use `QuantizedCache(backend='quanto', ...)` instead."
)
|
|
super().__init__("quanto", config, nbits, axis_key, axis_value, q_group_size, residual_length)
|
|
|
|
|
|
class HQQQuantizedCache(QuantizedCache):
|
|
def __init__(
|
|
self,
|
|
config: PretrainedConfig,
|
|
nbits: int = 4,
|
|
axis_key: int = 0,
|
|
axis_value: int = 0,
|
|
q_group_size: int = 64,
|
|
residual_length: int = 128,
|
|
):
|
|
logger.warning_once(
"`HQQQuantizedCache` is deprecated and will be removed in version v4.59. "
"Use `QuantizedCache(backend='hqq', ...)` instead."
)
|
|
super().__init__("hqq", config, nbits, axis_key, axis_value, q_group_size, residual_length)
|
|
|
|
|
|
class SinkCache(Cache):
|
|
"""
|
|
It is now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache.
|
|
See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for
|
|
general `custom_generate` usage.
|
|
"""
|
|
|
|
# TODO (joao, manuel): Remove this class in v4.59.0
|
|
def __init__(self, **kwargs) -> None:
|
|
raise NotImplementedError(
|
|
"`SinkCache` has been moved as a `custom_generate` repository on the Hub: "
|
|
"https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples."
|
|
)
|