Source code for ml.models.architectures.rwkv

# mypy: disable-error-code="import, override"
r"""Defines modules for RWKV blocks.

RWKV blocks are similar to Transformer blocks, but use a different attention
mechanism that doesn't require a linearly growing KV cache.

Training with the CUDA kernel requires installing ``triton``:

.. code-block:: bash

    pip install triton
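
A minimal usage sketch of the high-level module (the embedding size, layer
count and input shapes below are illustrative assumptions):

.. code-block:: python

    import torch

    from ml.models.architectures.rwkv import RwkvStack

    model = RwkvStack(emb_dim=64, num_layers=2)
    x = torch.randn(1, 16, 64)  # (batch, time, channels)
    y, state = model(x)  # Full-sequence forward pass.
    y_next, state = model(torch.randn(1, 1, 64), state)  # Single-step update.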
"""

import math
import os
import warnings
from typing import Callable, Literal, cast, get_args

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd.function import Function, FunctionCtx, once_differentiable

from ml.utils.triton import supports_triton

WkvFnKey = Literal["eps", "log"]

RwkvAttentionState = tuple[Tensor, Tensor]
RwkvFeedForwardState = Tensor
RwkvState = tuple[RwkvAttentionState, RwkvFeedForwardState]

EPS = 1e-4


@torch.jit.script
def wkv_with_eps_forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]:
    bsz, tsz, chans = k.shape

    assert w.shape == u.shape == (chans,)
    assert v.shape == (bsz, tsz, chans)
    assert state.shape == (bsz, 3, 1, chans)

    alpha, beta, eps = state[:, :, -1].chunk(3, dim=1)  # (B, 1, D), (B, 1, D), (B, 1, D)

    wkvs = []
    alphas = [alpha]
    betas = [beta]
    epss = [eps]

    for t in range(tsz):
        kt, vt = k[:, t : t + 1], v[:, t : t + 1]
        ukt = u + kt
        tau = torch.maximum(ukt, eps)
        e1 = torch.exp(eps - tau)
        e2 = torch.exp(ukt - tau)
        wkv = (e1 * alpha + e2 * vt) / (e1 * beta + e2)
        wkvs.append(wkv)

        w_eps = eps - w
        eps = torch.maximum(w_eps, kt)
        e1 = torch.exp(w_eps - eps)
        e2 = torch.exp(kt - eps)
        alpha = e1 * alpha + e2 * vt
        beta = e1 * beta + e2

        alphas.append(alpha)
        betas.append(beta)
        epss.append(eps)

    alpha = torch.stack(alphas, dim=2)
    beta = torch.stack(betas, dim=2)
    eps = torch.stack(epss, dim=2)

    return torch.cat(wkvs, 1), torch.cat((alpha, beta, eps), dim=1)
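

def _wkv_with_eps_reference(w: Tensor, u: Tensor, k: Tensor, v: Tensor) -> Tensor:
    # Hypothetical reference implementation (not part of the original module),
    # assuming the zero initial state from ``initial_state_with_eps``. The
    # recurrence above is a running-max ("eps") stabilization of the quadratic
    # form
    #
    #     wkv_t = (sum_{i<t} exp(-(t - 1 - i) * w + k_i) * v_i + exp(u + k_t) * v_t)
    #             / (sum_{i<t} exp(-(t - 1 - i) * w + k_i) + exp(u + k_t))
    #
    # which this O(T^2) loop computes directly, without overflow protection.
    _, tsz, _ = k.shape
    wkvs = []
    for t in range(tsz):
        num = torch.exp(u + k[:, t]) * v[:, t]
        den = torch.exp(u + k[:, t])
        for i in range(t):
            decay = torch.exp(-(t - 1 - i) * w + k[:, i])
            num = num + decay * v[:, i]
            den = den + decay
        wkvs.append(num / den)
    return torch.stack(wkvs, dim=1)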


@torch.jit.script
def wkv_with_eps_backward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    state: Tensor,
    grad_wkv: Tensor,
    grad_state: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    bsz, tsz, chans = k.shape

    assert w.shape == u.shape == (chans,)
    assert v.shape == (bsz, tsz, chans)
    assert state.shape == (bsz, 3, tsz + 1, chans)
    assert grad_wkv.shape == (bsz, tsz, chans)
    assert grad_state.shape == (bsz, 3, 1, chans)

    alpha, beta, eps = state.chunk(3, dim=1)  # (B, 1, T + 1, D), (B, 1, T + 1, D), (B, 1, T + 1, D)
    grad_alpha, grad_beta, grad_eps = grad_state[:, :, 0].chunk(3, dim=1)  # (B, 1, D), (B, 1, D), (B, 1, D)
    grad_eps = grad_eps.clone()

    grad_w = torch.zeros_like(w)
    grad_u = torch.zeros_like(u)
    grad_k = torch.zeros_like(k)
    grad_v = torch.zeros_like(v)

    for t in range(tsz - 1, -1, -1):
        kt, vt = k[:, t : t + 1], v[:, t : t + 1]
        alpha_prev, beta_prev, eps_prev = alpha[:, :, t], beta[:, :, t], eps[:, :, t]
        alpha_curr, beta_curr, eps_curr = alpha[:, :, t + 1], beta[:, :, t + 1], eps[:, :, t + 1]
        ukt = u + kt
        tau = torch.maximum(ukt, eps_prev)
        e1 = torch.exp(eps_prev - tau)
        e2 = torch.exp(ukt - tau)

        euke = torch.exp(ukt + eps_prev - 2 * tau)

        denom = e1 * beta_prev + e2
        denom_sq = denom * denom

        grad_wkvt = grad_wkv[:, t : t + 1]

        # Backpropagates wkv gradients.
        grad_uk = grad_wkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq
        grad_u += grad_uk.flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += grad_uk
        grad_v[:, t : t + 1] += grad_wkvt * e2 / denom

        grad_alpha_wkv = grad_wkvt * e1 / denom
        grad_beta_wkv = -grad_wkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq
        grad_eps_wkv = grad_wkvt * euke * (alpha_prev - vt * beta_prev) / (e1 * beta_prev + e2) ** 2

        e1 = torch.exp(eps_prev - eps_curr - w)
        e2 = torch.exp(kt - eps_curr)

        # Backpropagates alpha gradients.
        grad_alpha_we = grad_alpha * e1 * alpha_prev
        grad_w -= grad_alpha_we.flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += grad_alpha * e2 * vt
        grad_v[:, t : t + 1] += grad_alpha * e2
        grad_eps += grad_alpha * -alpha_curr

        # Backpropagates beta gradients.
        grad_beta_we = grad_beta * e1 * beta_prev
        grad_w -= grad_beta_we.flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += grad_beta * e2
        grad_eps += grad_beta * -beta_curr

        # Backpropagates epsilon gradients.
        eps_grad_mask = eps_prev - w > kt
        grad_eps_we = torch.where(eps_grad_mask, grad_eps, torch.zeros_like(grad_eps))
        grad_w -= grad_eps_we.flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += torch.where(eps_grad_mask, torch.zeros_like(grad_eps), grad_eps)

        # Computes gradients for alpha, beta and epsilon.
        grad_alpha = grad_alpha * e1 + grad_alpha_wkv
        grad_beta = grad_beta * e1 + grad_beta_wkv
        grad_eps = grad_alpha_we + grad_beta_we + grad_eps_we + grad_eps_wkv

    return grad_w, grad_u, grad_k, grad_v, torch.stack((grad_alpha, grad_beta, grad_eps), dim=1)


class WkvWithEps(Function):
    @staticmethod
    def forward(
        ctx: FunctionCtx,
        w: Tensor,
        u: Tensor,
        k: Tensor,
        v: Tensor,
        state: Tensor,
    ) -> tuple[Tensor, Tensor]:
        wkv, state_out = wkv_with_eps_forward(w, u, k, v, state)
        ctx.save_for_backward(w, u, k, v, state_out)
        return wkv, state_out[:, :, -1:]

    @staticmethod
    @once_differentiable
    def backward(
        ctx: FunctionCtx,
        grad_wkv: Tensor,
        grad_state: Tensor,
    ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors)
        return wkv_with_eps_backward(w, u, k, v, state, grad_wkv, grad_state)


def initial_state_with_eps(emb_dim: int) -> Tensor:
    return torch.zeros(1, 3, 1, emb_dim)


def wkv_with_eps(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]:
    """Runs the core WKV computation.

    Args:
        w: The decay tensor, with shape (D)
        u: The output multiplier tensor, with shape (D)
        k: The K tensor, with shape (B, T, D)
        v: The V tensor, with shape (B, T, D)
        state: The state tensor, with shape (B, 3, 1, D), consisting of the
            alpha, beta and eps tensors, each with shape (B, 1, 1, D)

    Returns:
        The WKV tensor, with shape (B, T, D), and the next state, with shape
        (B, 3, 1, D), consisting of the next alpha, beta and eps tensors, each
        with shape (B, 1, 1, D)
    """
    return WkvWithEps.apply(w, u, k, v, state)
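

def _example_wkv_with_eps_streaming() -> None:
    # Illustrative sketch (not part of the original module): carrying the
    # returned state across chunks should reproduce the full-sequence result,
    # since the recurrence only depends on the running (alpha, beta, eps).
    w, u = torch.rand(8), torch.rand(8)
    k, v = torch.randn(1, 6, 8), torch.randn(1, 6, 8)
    state = initial_state_with_eps(8)
    full, _ = wkv_with_eps(w, u, k, v, state)
    first, mid_state = wkv_with_eps(w, u, k[:, :3], v[:, :3], state)
    second, _ = wkv_with_eps(w, u, k[:, 3:], v[:, 3:], mid_state)
    assert torch.allclose(full, torch.cat((first, second), dim=1), atol=1e-6)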


@torch.jit.script
def logaddexp(a: Tensor, b: Tensor) -> Tensor:
    max_ab = torch.maximum(a, b)
    return max_ab + torch.log(torch.exp(a - max_ab) + torch.exp(b - max_ab))


@torch.jit.script
def logsubexp(a: Tensor, b: Tensor, log_eps: float) -> Tensor:
    max_ab = torch.clamp_min(torch.maximum(a, b), log_eps)
    return max_ab + torch.log(torch.exp(a - max_ab) - torch.exp(b - max_ab))


@torch.jit.script
def wkv_log_space_forward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    state: Tensor,
    eps: float = EPS,
    normalize: bool = False,
) -> tuple[Tensor, Tensor]:
    bsz, tsz, chans = k.shape

    assert w.shape == u.shape == (chans,)
    assert v.shape == (bsz, tsz, chans)
    assert state.shape == (bsz, 3, 1, chans)

    ln_alpha_p, ln_alpha_m, ln_beta = state[:, :, -1].chunk(3, dim=1)

    log_eps = math.log(eps)

    wkvs = []
    ln_alpha_ps = [ln_alpha_p]
    ln_alpha_ms = [ln_alpha_m]
    ln_betas = [ln_beta]

    for t in range(tsz):
        kt, vt = k[:, t : t + 1], v[:, t : t + 1]
        vt_p, vt_m = torch.clamp_min(vt, 0) + eps, torch.clamp_min(-vt, 0) + eps
        ln_v_p, ln_v_m = torch.log(vt_p), torch.log(vt_m)

        if normalize:
            ln_alpha_pm = torch.minimum(ln_alpha_p, ln_alpha_m) - eps
            ln_alpha_p = logsubexp(ln_alpha_p, ln_alpha_pm, log_eps)
            ln_alpha_m = logsubexp(ln_alpha_m, ln_alpha_pm, log_eps)

        ln_wkv_p = logaddexp(u + kt + ln_v_p, ln_alpha_p) - logaddexp(u + kt, ln_beta)
        ln_wkv_m = logaddexp(u + kt + ln_v_m, ln_alpha_m) - logaddexp(u + kt, ln_beta)

        wkv = torch.exp(ln_wkv_p) - torch.exp(ln_wkv_m)
        wkvs.append(wkv)

        ln_alpha_p = logaddexp(ln_alpha_p - w, kt + ln_v_p)
        ln_alpha_m = logaddexp(ln_alpha_m - w, kt + ln_v_m)
        ln_beta = logaddexp(ln_beta - w, kt)

        ln_alpha_ps.append(ln_alpha_p)
        ln_alpha_ms.append(ln_alpha_m)
        ln_betas.append(ln_beta)

    ln_alpha_p = torch.stack(ln_alpha_ps, dim=2)
    ln_alpha_m = torch.stack(ln_alpha_ms, dim=2)
    ln_beta = torch.stack(ln_betas, dim=2)

    return torch.cat(wkvs, 1), torch.cat((ln_alpha_p, ln_alpha_m, ln_beta), dim=1)


@torch.jit.script
def wkv_log_space_backward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    state: Tensor,
    grad_wkv: Tensor,
    grad_state: Tensor,
    eps: float = EPS,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    bsz, tsz, chans = k.shape

    assert w.shape == u.shape == (chans,)
    assert v.shape == (bsz, tsz, chans)
    assert state.shape == (bsz, 3, tsz, chans)
    assert grad_wkv.shape == (bsz, tsz, chans)
    assert grad_state.shape == (bsz, 3, 1, chans)

    grad_ln_alpha_p, grad_ln_alpha_m, grad_ln_beta = grad_state[:, :, 0].chunk(3, dim=1)

    grad_w = torch.zeros_like(w)
    grad_u = torch.zeros_like(u)
    grad_k = torch.zeros_like(k)
    grad_v = torch.zeros_like(v)

    for t in range(tsz - 1, -1, -1):
        kt, vt = k[:, t : t + 1], v[:, t : t + 1]
        vt_p, vt_m = torch.clamp_min(vt, 0) + eps, torch.clamp_min(-vt, 0) + eps
        ln_v_p, ln_v_m = torch.log(vt_p), torch.log(vt_m)

        ln_alpha_p_prev, ln_alpha_m_prev, ln_beta_prev = state[:, :, t].chunk(3, dim=1)

        uk = u + kt
        ukv_p, ukv_m = uk + ln_v_p, uk + ln_v_m

        ukb = logaddexp(uk, ln_beta_prev)
        wkv_p = torch.exp(logaddexp(ukv_p, ln_alpha_p_prev) - ukb)
        wkv_m = torch.exp(logaddexp(ukv_m, ln_alpha_m_prev) - ukb)

        grad_wkvt = grad_wkv[:, t : t + 1]
        grad_ln_wkv_p, grad_ln_wkv_m = grad_wkvt * wkv_p, grad_wkvt * -wkv_m

        # Backpropagates wkv gradients.
        e_num_p = torch.exp(ln_alpha_p_prev - ukv_p)
        e_num_m = torch.exp(ln_alpha_m_prev - ukv_m)
        e_den = torch.exp(ln_beta_prev - uk)
        grad_wkv_den_p = grad_ln_wkv_p / (1 + e_den)
        grad_wkv_den_m = grad_ln_wkv_m / (1 + e_den)
        grad_kv_p = grad_ln_wkv_p / (1 + e_num_p)
        grad_kv_m = grad_ln_wkv_m / (1 + e_num_m)
        grad_uk = grad_kv_p + grad_kv_m - grad_wkv_den_p - grad_wkv_den_m
        grad_u += grad_uk.flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += grad_uk
        grad_v[:, t : t + 1] += torch.where(vt > 0, grad_kv_p / vt_p, grad_kv_m / -vt_m)

        grad_ln_alpha_wkv_p = grad_ln_wkv_p / (1 + (1 / e_num_p))
        grad_ln_alpha_wkv_m = grad_ln_wkv_m / (1 + (1 / e_num_m))
        grad_ln_beta_wkv = -grad_ln_wkv_p / (1 + (1 / e_den)) - grad_ln_wkv_m / (1 + (1 / e_den))

        # Backpropagates alpha gradients.
        e_alpha_p = torch.exp(kt + ln_v_p + w - ln_alpha_p_prev)
        e_alpha_m = torch.exp(kt + ln_v_m + w - ln_alpha_m_prev)
        grad_wa_p = grad_ln_alpha_p / (1 + e_alpha_p)
        grad_wa_m = grad_ln_alpha_m / (1 + e_alpha_m)
        grad_w -= (grad_wa_p + grad_wa_m).flatten(0, -2).sum(0)
        grad_kv_p = grad_ln_alpha_p / (1 + (1 / e_alpha_p))
        grad_kv_m = grad_ln_alpha_m / (1 + (1 / e_alpha_m))
        grad_k[:, t : t + 1] += grad_kv_p + grad_kv_m
        grad_v[:, t : t + 1] += torch.where(vt > 0, grad_kv_p / vt_p, -grad_kv_m / vt_m)

        # Backpropagates beta gradients.
        e_beta = torch.exp(kt + w - ln_beta_prev)
        grad_wb = grad_ln_beta / (1 + e_beta)
        grad_w -= grad_wb.flatten(0, -2).sum(0)
        grad_k[:, t : t + 1] += grad_ln_beta / (1 + (1 / e_beta))

        # Compute gradients for log alpha and log beta.
        grad_ln_alpha_p = grad_wa_p + grad_ln_alpha_wkv_p
        grad_ln_alpha_m = grad_wa_m + grad_ln_alpha_wkv_m
        grad_ln_beta = grad_wb + grad_ln_beta_wkv

    return grad_w, grad_u, grad_k, grad_v, torch.stack((grad_ln_alpha_p, grad_ln_alpha_m, grad_ln_beta), dim=1)


class WkvLogSpace(Function):
    @staticmethod
    def forward(
        ctx: FunctionCtx,
        w: Tensor,
        u: Tensor,
        k: Tensor,
        v: Tensor,
        state: Tensor,
    ) -> tuple[Tensor, Tensor]:
        wkv, state_out = wkv_log_space_forward(w, u, k, v, state)
        ctx.save_for_backward(w, u, k, v, state_out[:, :, :-1])
        return wkv, state_out[:, :, -1:]

    @staticmethod
    @once_differentiable
    def backward(
        ctx: FunctionCtx,
        grad_wkv: Tensor,
        grad_state: Tensor,
    ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors)
        return wkv_log_space_backward(w, u, k, v, state, grad_wkv, grad_state)


def initial_state_log_space(emb_dim: int) -> Tensor:
    return torch.full((1, 3, 1, emb_dim), float("-inf"))


def wkv_log_space(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]:
    """Runs the core WKV computation in log space.

    Args:
        w: The decay tensor, with shape (D)
        u: The output multiplier tensor, with shape (D)
        k: The K tensor, with shape (B, T, D)
        v: The V tensor, with shape (B, T, D)
        state: The state tensor, with shape (B, 3, 1, D), consisting of the
            log alpha plus, log alpha minus and log beta tensors, each with
            shape (B, 1, 1, D)

    Returns:
        The WKV tensor, with shape (B, T, D), and the next state, with shape
        (B, 3, 1, D), consisting of the next log alpha plus, log alpha minus
        and log beta tensors, each with shape (B, 1, 1, D)
    """
    return WkvLogSpace.apply(w, u, k, v, state)
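

def _example_wkv_log_space() -> None:
    # Illustrative sketch (not part of the original module): the log-space
    # kernel starts from a -inf ("empty") state and returns the next state in
    # the same (B, 3, 1, D) layout as the eps-based kernel.
    w, u = torch.rand(8), torch.rand(8)
    k, v = torch.randn(1, 4, 8), torch.randn(1, 4, 8)
    state = initial_state_log_space(8)
    wkv, state = wkv_log_space(w, u, k, v, state)
    assert wkv.shape == (1, 4, 8) and state.shape == (1, 3, 1, 8)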


def get_wkv_fn(key: WkvFnKey) -> Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], tuple[Tensor, Tensor]]:
    match key:
        case "eps":
            return wkv_with_eps
        case "log":
            return wkv_log_space
        case _:
            raise ValueError(f"Unsupported key: {key}")


def get_wkv_fn_cuda(key: WkvFnKey) -> Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], tuple[Tensor, Tensor]]:
    if not supports_triton():
        return get_wkv_fn(key)

    from ml.utils.triton.rwkv import wkv_triton_log_space, wkv_triton_with_eps

    match key:
        case "eps":
            return wkv_triton_with_eps
        case "log":
            return wkv_triton_log_space
        case _:
            raise ValueError(f"Unsupported key: {key}")


def get_default_wkv_fn_key() -> WkvFnKey:
    if "WKV_FN" in os.environ:
        assert (wkv_fn_str := os.environ["WKV_FN"]) in get_args(WkvFnKey), f"Unsupported WKV_FN: {wkv_fn_str}"
        return cast(WkvFnKey, wkv_fn_str)

    warnings.warn("Using default WKV_FN: eps")
    return "eps"


class RwkvAttention(nn.Module):
    init_x: Tensor
    init_state: Tensor

    def __init__(self, dim: int, wkv_key: WkvFnKey | None = None) -> None:
        super().__init__()

        self.time_decay = nn.Parameter(torch.ones(dim))
        self.time_first = nn.Parameter(torch.ones(dim))

        self.time_mix_k = nn.Parameter(torch.ones(1, 1, dim))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, dim))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, dim))

        self.key = nn.Linear(dim, dim, False)
        self.value = nn.Linear(dim, dim, False)
        self.receptance = nn.Linear(dim, dim, False)
        self.output = nn.Linear(dim, dim, False)

        if wkv_key is None:
            wkv_key = get_default_wkv_fn_key()

        self.wkv_fn = get_wkv_fn(wkv_key)
        self.wkv_fn_cuda = get_wkv_fn_cuda(wkv_key)

        self.register_buffer("init_x", torch.zeros(1, 1, dim), persistent=False)
        # Matches the initial state to the selected kernel; the log-space
        # kernel expects log-domain values, so its "empty" state is -inf.
        init_state = initial_state_with_eps(dim) if wkv_key == "eps" else initial_state_log_space(dim)
        self.register_buffer("init_state", init_state, persistent=False)

    def time_shift(self, last_x: Tensor, x: Tensor) -> Tensor:
        _, tsz, _ = x.shape
        if tsz > 1:
            last_x = torch.cat((last_x, x[..., :-1, :]), dim=-2)
        return last_x

    def forward(self, x: Tensor, state: RwkvAttentionState | None) -> tuple[Tensor, RwkvAttentionState]:
        bsz, _, _ = x.shape

        if state is None:
            last_x = self.init_x.repeat_interleave(bsz, dim=0)
            last_state = self.init_state.repeat_interleave(bsz, dim=0)
        else:
            last_x, last_state = state
        last_x = self.time_shift(last_x, x)

        k = self.key(x * self.time_mix_k + last_x * (1 - self.time_mix_k))
        v = self.value(x * self.time_mix_v + last_x * (1 - self.time_mix_v))
        r = self.receptance(x * self.time_mix_r + last_x * (1 - self.time_mix_r))
        sr = torch.sigmoid(r)

        w, u = self.time_decay, self.time_first
        w = torch.exp(w)
        wkv_fn = self.wkv_fn_cuda if x.is_cuda else self.wkv_fn
        wkv, next_state = wkv_fn(w, u, k, v, last_state)
        rwkv = wkv * sr

        return self.output(rwkv), (x[..., -1:, :], next_state)


class RwkvFeedForward(nn.Module):
    init_state: Tensor

    def __init__(self, dim: int, ffn_dim: int) -> None:
        super().__init__()

        self.time_mix_k = nn.Parameter(torch.ones(1, 1, dim))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, dim))

        self.key = nn.Linear(dim, ffn_dim, False)
        self.receptance = nn.Linear(dim, dim, False)
        self.value = nn.Linear(ffn_dim, dim, False)

        self.register_buffer("init_state", torch.zeros(1, 1, dim), persistent=False)

    def time_shift(self, last_x: Tensor, x: Tensor) -> Tensor:
        _, tsz, _ = x.shape
        if tsz > 1:
            last_x = torch.cat((last_x, x[..., :-1, :]), dim=-2)
        return last_x

    def forward(self, x: Tensor, state: RwkvFeedForwardState | None = None) -> tuple[Tensor, RwkvFeedForwardState]:
        bsz = x.shape[0]
        last_x = self.time_shift(self.init_state.repeat(bsz, 1, 1) if state is None else state, x)
        k = self.key(x * self.time_mix_k + last_x * (1 - self.time_mix_k))
        r = self.receptance(x * self.time_mix_r + last_x * (1 - self.time_mix_r))
        vk = self.value(F.relu(k) ** 2)
        return torch.sigmoid(r) * vk, x[..., -1:, :]


class RwkvBlock(nn.Module):
    def __init__(self, emb_dim: int, pre_norm: bool, wkv_key: WkvFnKey | None = None) -> None:
        super().__init__()

        self.ln0 = nn.LayerNorm(emb_dim) if pre_norm else None
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.att = RwkvAttention(emb_dim, wkv_key=wkv_key)
        self.ffn = RwkvFeedForward(emb_dim, emb_dim * 4)

    def run_attn(self, x: Tensor, state: RwkvState | None = None) -> tuple[Tensor, RwkvAttentionState]:
        return self.att.forward(self.ln1(x), None if state is None else state[0])

    def run_ffn(self, x: Tensor, state: RwkvState | None = None) -> tuple[Tensor, RwkvFeedForwardState]:
        return self.ffn.forward(self.ln2(x), None if state is None else state[1])

    def forward(self, x: Tensor, state: RwkvState | None = None) -> tuple[Tensor, RwkvState]:
        if self.ln0 is not None:
            x = self.ln0(x)
        dx, att_state_out = self.run_attn(x, state)
        x = x + dx
        dx, ffn_state_out = self.run_ffn(x, state)
        x = x + dx
        return x, (att_state_out, ffn_state_out)


class RwkvStack(nn.Module):
    """Defines a stack of RWKV modules.

    Parameters:
        emb_dim: The number of embedding dimensions in each block
        num_layers: The number of layers in the stack
        wkv_key: The WKV algorithm to use

    Inputs:
        x: The input tensor, with shape ``(B, T, D)``
        state: The previous state

    Outputs:
        The output tensor, with shape ``(B, T, D)``, and the next state
    """

    def __init__(self, emb_dim: int, num_layers: int, wkv_key: WkvFnKey | None = None) -> None:
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                RwkvBlock(
                    emb_dim,
                    pre_norm=i == 0,
                    wkv_key=wkv_key,
                )
                for i in range(num_layers)
            ]
        )

    def forward(self, x: Tensor, state: list[RwkvState] | None = None) -> tuple[Tensor, list[RwkvState]]:
        state_out: list[RwkvState] = []
        for i, block in enumerate(self.blocks):
            x, state_out_i = block(x, None if state is None else state[i])
            state_out.append(state_out_i)
        return x, state_out
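

def _example_rwkv_stack_step_decoding() -> None:
    # Illustrative smoke test (not part of the original module): because the
    # WKV state replaces a growing KV cache, running the stack one timestep at
    # a time while threading the state through should track the full-sequence
    # output up to floating-point error.
    model = RwkvStack(emb_dim=16, num_layers=2)
    x = torch.randn(1, 8, 16)
    with torch.no_grad():
        y_full, _ = model(x)
        y_steps = []
        state: list[RwkvState] | None = None
        for t in range(8):
            y_t, state = model(x[:, t : t + 1], state)
            y_steps.append(y_t)
    print((y_full - torch.cat(y_steps, dim=1)).abs().max())  # Should be ~0.


if __name__ == "__main__":
    _example_rwkv_stack_step_decoding()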