Source code for ml.models.activations

"""Defines a general-purpose API for activation functions.

.. highlight:: python
.. code-block:: python

    from ml.models.activations import get_activation, cast_activation_type

    model = nn.Sequential(nn.Linear(4, 5), get_activation("relu"))

    # This lets you parametrize the activation function as a string.
    model = nn.Sequential(nn.Linear(4, 5), get_activation(cast_activation_type(my_activation)))

Choices for the activation functions are:

- ``"no_act"``
- ``"relu"``
- ``"relu6"``
- ``"relu2"``
- ``"clamp6"``
- ``"leaky_relu"``
- ``"elu"``
- ``"celu"``
- ``"selu"``
- ``"gelu"``
- ``"gelu_fast"``
- ``"sigmoid"``
- ``"log_sigmoid"``
- ``"hard_sigomid"``
- ``"tanh"``
- ``"softsign"``
- ``"softplus"``
- ``"silu"``
- ``"mish"``
- ``"swish"``
- ``"hard_swish"``
- ``"soft_shrink"``
- ``"hard_shrink"``
- ``"tanh_shrink"``
- ``"soft_sign"``
- ``"relu_squared"``
- ``"laplace"``
"""

import math
from typing import Literal, cast, get_args

import torch
from torch import Tensor, nn

ActivationType = Literal[
    "no_act",
    "relu",
    "relu6",
    "relu2",
    "clamp6",
    "leaky_relu",
    "elu",
    "celu",
    "selu",
    "gelu",
    "gelu_fast",
    "sigmoid",
    "log_sigmoid",
    "hard_sigomid",
    "tanh",
    "softsign",
    "softplus",
    "silu",
    "mish",
    "swish",
    "hard_swish",
    "soft_shrink",
    "hard_shrink",
    "tanh_shrink",
    "soft_sign",
    "relu_squared",
    "laplace",
]


def cast_activation_type(s: str) -> ActivationType:
    args = get_args(ActivationType)
    assert s in args, f"Invalid activation type: '{s}' Valid options are {args}"
    return cast(ActivationType, s)
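

# A minimal usage sketch (not part of the library API): `cast_activation_type`
# narrows an arbitrary string to the `ActivationType` literal and fails fast on
# anything not listed above, e.g.
#
#     act = cast_activation_type("relu")   # returns "relu", typed as ActivationType
#     act = cast_activation_type("rellu")  # raises AssertionError listing the valid options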


class Clamp(nn.Module):
    __constants__ = ["min_value", "max_value", "inplace"]

    def __init__(
        self,
        *,
        value: float | None = None,
        value_range: tuple[float, float] | None = None,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        assert (value is None) != (value_range is None), "Exactly one of `value` or `value_range` must be specified."
        if value is not None:
            value_range = (-value, value)
        else:
            assert value_range is not None
        self.min_value, self.max_value = value_range
        self.inplace = inplace
        assert self.min_value < self.max_value, f"{self.min_value=} >= {self.max_value=}"

    def forward(self, x: Tensor) -> Tensor:
        return x.clamp_(self.min_value, self.max_value) if self.inplace else x.clamp(self.min_value, self.max_value)


class Clamp6(Clamp):
    def __init__(self, inplace: bool = False) -> None:
        super().__init__(value=6.0, inplace=inplace)
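

# Illustrative note (assumed example, not library code): `Clamp` wraps tensor
# clamping as a module, either symmetric (`value`) or explicit (`value_range`),
# and `Clamp6` fixes the range to [-6, 6], so
#
#     Clamp6()(x)                       # equivalent to x.clamp(-6.0, 6.0)
#     Clamp(value_range=(0.0, 1.0))(x)  # equivalent to x.clamp(0.0, 1.0)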


class ReLUSquared(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        relu_applied = nn.functional.relu(x)
        squared = torch.square(relu_applied)
        return squared


class FastGELU(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))


class QuickGELU(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return x * torch.sigmoid(1.702 * x)
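

# Both classes above approximate the exact (erf-based) GELU: `FastGELU` uses
# the tanh approximation (0.7978845608 ~= sqrt(2 / pi)), which in recent
# PyTorch versions should closely track `nn.GELU(approximate="tanh")`, while
# `QuickGELU` uses the cheaper sigmoid form x * sigmoid(1.702 * x) popularized
# by models such as CLIP. A rough sanity check (illustrative only):
#
#     x = torch.linspace(-4.0, 4.0, steps=101)
#     torch.allclose(FastGELU()(x), nn.GELU(approximate="tanh")(x), atol=1e-6)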


class LaplaceActivation(nn.Module):
    __constants__ = ["mu", "sigma"]

    def __init__(self, mu: float = 0.707107, sigma: float = 0.282095) -> None:
        super().__init__()
        self.mu = mu
        self.sigma = sigma

    def forward(self, x: Tensor) -> Tensor:
        x = (x - self.mu).div(self.sigma * math.sqrt(2.0))
        return 0.5 * (1.0 + torch.erf(x))
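

# The `LaplaceActivation` defaults above appear to correspond to mu = sqrt(1/2)
# and sigma = sqrt(1/(4 * pi)), i.e. the Laplace function proposed in the
# literature as a smooth approximation to squared ReLU; treat the constants as
# that derivation's choice rather than freely tunable hyperparameters.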


def get_activation(act: ActivationType, *, inplace: bool = True) -> nn.Module:
    """Returns an activation function from a keyword string.

    Args:
        act: The keyword for the activation function ("no_act" for identity)
        inplace: If set, use the inplace version of the activation function

    Returns:
        The activation function as a module

    Raises:
        NotImplementedError: If the activation function is invalid
    """
    match act:
        case "no_act":
            return nn.Identity()
        case "relu":
            return nn.ReLU(inplace=inplace)
        case "relu2":
            return nn.ReLU(inplace=inplace)
        case "relu6":
            return nn.ReLU6(inplace=inplace)
        case "clamp6":
            return Clamp6(inplace=inplace)
        case "leaky_relu":
            return nn.LeakyReLU(inplace=inplace)
        case "elu":
            return nn.ELU(inplace=inplace)
        case "celu":
            return nn.CELU(inplace=inplace)
        case "selu":
            return nn.SELU(inplace=inplace)
        case "gelu":
            return nn.GELU()
        case "gelu_fast":
            return FastGELU()
        case "gelu_quick":
            return QuickGELU()
        case "sigmoid":
            return nn.Sigmoid()
        case "log_sigmoid":
            return nn.LogSigmoid()
        case "hard_sigmoid":
            return nn.Hardsigmoid(inplace=inplace)
        case "tanh":
            return nn.Tanh()
        case "softsign":
            return nn.Softsign()
        case "softplus":
            return nn.Softplus()
        case "silu" | "swish":
            return nn.SiLU()
        case "mish":
            return nn.Mish(inplace=inplace)
        case "hard_swish":
            return nn.Hardswish(inplace=inplace)
        case "soft_shrink":
            return nn.Softshrink()
        case "hard_shrink":
            return nn.Hardshrink()
        case "tanh_shrink":
            return nn.Tanhshrink()
        case "soft_sign":
            return nn.Softsign()
        case "relu_squared":
            return ReLUSquared()
        case "laplace":
            return LaplaceActivation()
        case _:
            raise NotImplementedError(f"Activation function '{act}' is not implemented.")
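

# A hedged usage sketch, not part of the module itself: it simply exercises
# `get_activation` and `cast_activation_type` on a toy model. The names
# `layer_act`, `model`, and `batch` are illustrative only.
if __name__ == "__main__":
    # Build a small MLP whose activation is chosen by a config string.
    layer_act = cast_activation_type("gelu_fast")
    model = nn.Sequential(nn.Linear(4, 8), get_activation(layer_act), nn.Linear(8, 2))

    batch = torch.randn(3, 4)
    print(model(batch).shape)  # torch.Size([3, 2])

    # "no_act" yields an identity module, useful as a configurable default.
    print(get_activation("no_act")(batch).equal(batch))  # True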