# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import math
from collections import OrderedDict

import torch
from torch import Tensor, nn

from .integrations.hub_kernels import use_kernel_forward_from_hub
from .utils import logging
from .utils.import_utils import is_torchdynamo_compiling


logger = logging.get_logger(__name__)


@use_kernel_forward_from_hub("GeluTanh")
class GELUTanh(nn.Module):
    """
    A fast C implementation of the tanh approximation of the GeLU activation function. See
    https://huggingface.co/papers/1606.08415.

    This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
    match due to rounding errors.
    """

    def __init__(self, use_gelu_tanh_python: bool = False):
        super().__init__()
        if use_gelu_tanh_python:
            self.act = self._gelu_tanh_python
        else:
            self.act = functools.partial(nn.functional.gelu, approximate="tanh")

    def _gelu_tanh_python(self, input: Tensor) -> Tensor:
        return input * 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)
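

# Illustrative sketch: both construction paths above compute the same tanh approximation and should
# agree up to floating-point rounding, e.g.
#     x = torch.randn(8)
#     torch.allclose(GELUTanh()(x), GELUTanh(use_gelu_tanh_python=True)(x), atol=1e-6)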


@use_kernel_forward_from_hub("NewGELU")
class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently used in the Google BERT repo (identical to OpenAI GPT).
    Also see the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
    """

    def forward(self, input: Tensor) -> Tensor:
        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))


@use_kernel_forward_from_hub("GeLU")
class GELUActivation(nn.Module):
    """
    Original implementation of the GELU activation function in the Google BERT repo when initially created. For
    information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). This is now written in C in
    nn.functional. Also see the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
    """

    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        if use_gelu_python:
            self.act = self._gelu_python
        else:
            self.act = nn.functional.gelu

    def _gelu_python(self, input: Tensor) -> Tensor:
        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)
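

# Illustrative sketch: the exact, erf-based GELU differs slightly from the tanh approximation used by
# NewGELUActivation / GELUTanh, e.g.
#     x = torch.randn(8)
#     (GELUActivation()(x) - NewGELUActivation()(x)).abs().max()  # small but nonzero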


@use_kernel_forward_from_hub("SiLU")
class SiLUActivation(nn.Module):
    """
    See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
    Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
    Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
    Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
    later.
    """

    def forward(self, input: Tensor) -> Tensor:
        return nn.functional.silu(input)


@use_kernel_forward_from_hub("FastGELU")
class FastGELUActivation(nn.Module):
    """
    Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))


@use_kernel_forward_from_hub("QuickGELU")
class QuickGELUActivation(nn.Module):
    """
    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        return input * torch.sigmoid(1.702 * input)


class ClippedGELUActivation(nn.Module):
    """
    Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purposes,
    as it allows mapping negative values in the GeLU spectrum. For more information on this trick, please refer to
    https://huggingface.co/papers/2004.09602.

    Gaussian Error Linear Unit. Original implementation of the gelu activation function in the Google BERT repo when
    initially created.

    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://huggingface.co/papers/1606.08415
    """

    def __init__(self, min: float, max: float):
        if min > max:
            raise ValueError(f"min should be < max (got min: {min}, max: {max})")

        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x: Tensor) -> Tensor:
        return torch.clip(gelu(x), self.min, self.max)
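

# Illustrative sketch: ACT2CLS below registers this as "gelu_10" with min=-10 and max=10, so activations
# stay in a bounded range for quantization, e.g.
#     act = ClippedGELUActivation(min=-10, max=10)
#     act(torch.tensor([-100.0, 0.0, 100.0]))  # values are clipped into [-10, 10]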


class AccurateGELUActivation(nn.Module):
    """
    Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
    https://github.com/hendrycks/GELUs

    Implemented along with MEGA (Moving Average Equipped Gated Attention)
    """

    def __init__(self):
        super().__init__()
        self.precomputed_constant = math.sqrt(2 / math.pi)

    def forward(self, input: Tensor) -> Tensor:
        return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))


class MishActivation(nn.Module):
    """
    See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra, https://huggingface.co/papers/1908.08681).
    Also visit the official repository for the paper: https://github.com/digantamisra98/Mish
    """

    def __init__(self):
        super().__init__()
        self.act = nn.functional.mish

    def _mish_python(self, input: Tensor) -> Tensor:
        return input * torch.tanh(nn.functional.softplus(input))

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)


class LinearActivation(nn.Module):
    """
    Applies the linear activation function, i.e. forwarding input directly to output.
    """

    def forward(self, input: Tensor) -> Tensor:
        return input


class LaplaceActivation(nn.Module):
    """
    Applies elementwise activation based on the Laplace function, introduced in MEGA as an attention activation. See
    https://huggingface.co/papers/2209.10655

    Inspired by squared ReLU, but with bounded range and gradient for better stability.
    """

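    # The defaults mu = 0.707107 and sigma = 0.282095 numerically match sqrt(1/2) and sqrt(1/(4 * pi)),
    # the constants chosen in the MEGA paper so that this bounded Laplace curve closely tracks the
    # squared-ReLU attention function.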
    def forward(self, input, mu=0.707107, sigma=0.282095):
        input = (input - mu).div(sigma * math.sqrt(2.0))
        return 0.5 * (1.0 + torch.erf(input))


class ReLUSquaredActivation(nn.Module):
    """
    Applies the relu^2 activation introduced in https://huggingface.co/papers/2109.08668v2
    """

    def forward(self, input):
        relu_applied = nn.functional.relu(input)
        squared = torch.square(relu_applied)
        return squared


class ClassInstantier(OrderedDict):
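    """
    An OrderedDict whose item lookup instantiates the stored activation. Values are either a class or a
    (class, kwargs) tuple, so ACT2FN["gelu_10"] builds ClippedGELUActivation(min=-10, max=10) while
    ACT2FN["relu"] builds nn.ReLU(). Every lookup returns a fresh module instance.
    """
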
    def __getitem__(self, key):
        content = super().__getitem__(key)
        cls, kwargs = content if isinstance(content, tuple) else (content, {})
        return cls(**kwargs)


class XIELUActivation(nn.Module):
    """
    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010

    If the user has installed the nickjbrowning/XIELU wheel, we import the xIELU CUDA kernel;
    otherwise, we emit a single warning and use the xIELU Python implementation.
    """

    def __init__(
        self,
        alpha_p_init=0.8,
        alpha_n_init=0.8,
        beta=0.5,
        eps=-1e-6,
        dtype=torch.bfloat16,
        with_vector_loads=False,
    ):
        super().__init__()
        self.alpha_p = nn.Parameter(torch.log(torch.expm1(torch.tensor(alpha_p_init, dtype=dtype))).unsqueeze(0))
        self.alpha_n = nn.Parameter(
            torch.log(torch.expm1(torch.tensor(alpha_n_init - beta, dtype=dtype))).unsqueeze(0)
        )
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            msg = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                msg += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                msg += f" Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance."
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(msg)
        except Exception as err:
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) – falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: Tensor) -> Tensor:
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
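        # Piecewise form: for x > 0 this is alpha_p * x**2 + beta * x; for x <= 0 it is
        # alpha_n * (expm1(min(x, eps)) - x) + beta * x, with alpha_p and alpha_n kept positive
        # through the softplus reparametrization above.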
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
        )

    def _xielu_cuda(self, x: Tensor) -> Tensor:
        """Firewall function to prevent torch.compile from seeing .item() calls"""
        original_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            x = x.view(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "xIELU CUDA kernel expects a 3D tensor but got shape %s; reshaping to %s.",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p.to(x.dtype),
            self.alpha_n.to(x.dtype),
            # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        return result.view(original_shape)

    def forward(self, input: Tensor) -> Tensor:
        if self._xielu_cuda_obj is not None and input.is_cuda:
            if not is_torchdynamo_compiling():
                return self._xielu_cuda_fn(input)
            else:
                logger.warning_once("torch._dynamo is compiling, using Python version of xIELU.")
        return self._xielu_python(input)
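

# Illustrative sketch: without the optional CUDA extension, or for CPU inputs, the module transparently
# uses the pure-Python path, e.g.
#     act = XIELUActivation()  # learnable alpha_p / alpha_n stored in bfloat16 by default
#     y = act(torch.randn(2, 3, 4))  # same shape as the input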


ACT2CLS = {
    "gelu": GELUActivation,
    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "gelu_pytorch_tanh": GELUTanh,
    "gelu_python_tanh": (GELUTanh, {"use_gelu_tanh_python": True}),
    "gelu_accurate": AccurateGELUActivation,
    "laplace": LaplaceActivation,
    "leaky_relu": nn.LeakyReLU,
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
    "relu": nn.ReLU,
    "relu2": ReLUSquaredActivation,
    "relu6": nn.ReLU6,
    "sigmoid": nn.Sigmoid,
    "silu": SiLUActivation,
    "swish": nn.SiLU,
    "tanh": nn.Tanh,
    "prelu": nn.PReLU,
    "xielu": XIELUActivation,
}
ACT2FN = ClassInstantier(ACT2CLS)


def get_activation(activation_string):
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")


# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")