"""Helper utilities for using LoRA layers.

LoRA layers are drop-in replacements for certain modules that can be used to
fine-tune pre-trained models. The method is described in the paper
`LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`_.

.. highlight:: python
.. code-block:: python

    from ml.models.lora import lora

    # The pre-trained model weights can be loaded into the LoRA model.
    model = nn.Sequential(nn.Linear(5, 7), nn.Linear(7, 5))
    lora_model = nn.Sequential(lora(nn.Linear(5, 7), r=4), lora(nn.Linear(7, 5), r=4))
    lora_model.load_state_dict(model.state_dict())  # No errors

    from ml.models.lora import LoraLinear

    # Alternatively, you can just substitute the module name.
    model = nn.Sequential(LoraLinear(5, 7, r=4), LoraLinear(7, 5, r=4))

The modules which can be wrapped with LoRA modules are:

- ``nn.Embedding``
- ``nn.Linear``
- ``nn.Conv1d``
- ``nn.ConvTranspose1d``
- ``nn.Conv2d``
- ``nn.ConvTranspose2d``
- ``nn.LSTM``
- ``nn.GRU``
- ``nn.LSTMCell``
- ``nn.GRUCell``
- ``ColumnParallelLinear``
- ``RowParallelLinear``
- ``ParallelEmbedding``

In the paper, the authors typically use values of 1, 2, 4, or 8 for the
``r`` parameter. The ``lora_alpha`` parameter is typically set to 1.0, but
can be tuned to improve performance.
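
For example, a typical fine-tuning setup wraps only the layers of interest and
freezes everything else (a minimal sketch; the layer sizes and ``r=4`` below
are illustrative, not recommendations):

.. code-block:: python

    from ml.models.lora import freeze_non_lora_, lora

    model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 8))
    model[1] = lora(model[1], r=4)  # Wraps the layer, copying its pre-trained weights.
    freeze_non_lora_(model)  # Freezes all parameters outside the LoRA layer.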
"""

import math
import warnings
import weakref
from abc import abstractmethod
from typing import Any, TypeVar, Union, cast, overload

import torch
import torch.nn.functional as F
from torch import _VF, Tensor, nn
from torch.nn.modules.module import _IncompatibleKeys

from ml.models.init import InitializationType
from ml.models.parallel import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
    mp_copy,
    mp_gather,
    mp_reduce,
    mp_scatter,
)

T = TypeVar("T")

SupportedModuleNonParallel = Union[
    nn.Embedding,
    nn.Linear,
    nn.Conv1d,
    nn.ConvTranspose1d,
    nn.Conv2d,
    nn.ConvTranspose2d,
    nn.LSTM,
    nn.GRU,
    nn.LSTMCell,
    nn.GRUCell,
]

SupportedModule = Union[
    SupportedModuleNonParallel,
    ColumnParallelLinear,
    RowParallelLinear,
    ParallelEmbedding,
]


def _lora_post_hook(module: "_Lora", incompatible_keys: _IncompatibleKeys) -> None:
    lora_keys = [k for k in incompatible_keys.missing_keys if k.split(".")[-1].startswith("lora_")]
    for lora_key in lora_keys:
        incompatible_keys.missing_keys.remove(lora_key)


class _Lora(nn.Module):
    def __init__(self, *args: Any, **kwargs: Any) -> None:  # noqa: ANN401
        super().__init__(*args, **kwargs)

        # This allows modules to use LoRA layers as drop-in replacements for
        # non-LoRA pretrained models without throwing annoying errors for
        # state dict incompatibility.
        self.register_load_state_dict_post_hook(_lora_post_hook)

    @abstractmethod
    def reset_lora_parameters(self) -> None:
        """Resets LoRA parameters in-place."""


class LoraEmbedding(nn.Embedding, _Lora):
    __constants__ = nn.Embedding.__constants__ + ["r", "lora_alpha", "scaling", "merge", "merged"]

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int,
        lora_alpha: float = 1.0,
        lora_dropout: float = 0.0,
        merge: bool = False,
        padding_idx: int | None = None,
        max_norm: float | None = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
        )

        assert r > 0

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        self.merge = merge

        self.dropout = nn.Identity() if lora_dropout == 0.0 else nn.Dropout(p=lora_dropout)
        self.merged = False

        self.lora_a = nn.Parameter(self.weight.new_empty((r, num_embeddings)))
        self.lora_b = nn.Parameter(self.weight.new_empty((embedding_dim, r)))
        self.weight.requires_grad_(False)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        super().reset_parameters()
        if hasattr(self, "lora_a") and hasattr(self, "lora_b"):
            self.reset_lora_parameters()

    def reset_lora_parameters(self) -> None:
        nn.init.kaiming_normal_(self.lora_a, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b)

    def train(self, mode: bool = True) -> "LoraEmbedding":
        super().train(mode)
        if mode:
            if self.merge and self.merged:
                # Make sure that the weights are not merged
                if self.lora_a is not None and self.lora_b is not None:
                    self.weight.data -= (self.lora_b @ self.lora_a).transpose(0, 1) * self.scaling
                self.merged = False
        elif self.merge and not self.merged:
            # Merge the weights and mark it
            if self.lora_a is not None and self.lora_b is not None:
                self.weight.data += (self.lora_b @ self.lora_a).transpose(0, 1) * self.scaling
            self.merged = True
        return self

    def forward(self, x: Tensor) -> Tensor:
        if self.lora_a is not None and self.lora_b is not None and not self.merged:
            result = super().forward(x)
            after_a = F.embedding(
                x,
                self.lora_a.transpose(0, 1),
                self.padding_idx,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.sparse,
            )
            return result + (after_a @ self.lora_b.transpose(0, 1)) * self.scaling
        return super().forward(x)
class LoraLinear(nn.Linear, _Lora):
    __constants__ = nn.Linear.__constants__ + ["r", "lora_alpha", "scaling", "merge", "fan_in_fan_out", "merged"]

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int,
        lora_alpha: float = 1.0,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        merge: bool = False,
        bias: bool = True,
    ) -> None:
        super().__init__(
            in_features,
            out_features,
            bias=bias,
        )

        assert r > 0

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        self.merge = merge
        self.fan_in_fan_out = fan_in_fan_out

        self.dropout = nn.Identity() if lora_dropout == 0.0 else nn.Dropout(p=lora_dropout)
        self.merged = False

        self.lora_a = nn.Parameter(self.weight.new_empty((r, in_features)))
        self.lora_b = nn.Parameter(self.weight.new_empty((out_features, r)))
        self.weight.requires_grad_(False)

        self.reset_parameters()

        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self) -> None:
        super().reset_parameters()
        if hasattr(self, "lora_a") and hasattr(self, "lora_b"):
            self.reset_lora_parameters()

    def reset_lora_parameters(self) -> None:
        nn.init.kaiming_normal_(self.lora_a, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b)

    def _t(self, w: Tensor) -> Tensor:
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def train(self, mode: bool = True) -> "LoraLinear":
        super().train(mode)
        if mode:
            if self.merge and self.merged:
                # Make sure that the weights are not merged
                if self.lora_a is not None and self.lora_b is not None:
                    self.weight.data -= self._t(self.lora_b @ self.lora_a) * self.scaling
                self.merged = False
        elif self.merge and not self.merged:
            # Merge the weights and mark it
            if self.lora_a is not None and self.lora_b is not None:
                self.weight.data += self._t(self.lora_b @ self.lora_a) * self.scaling
            self.merged = True
        return self

    def forward(self, x: Tensor) -> Tensor:
        if self.lora_a is not None and self.lora_b is not None and not self.merged:
            result = F.linear(x, self._t(self.weight), bias=self.bias)
            mm = self.dropout(x) @ self.lora_a.transpose(0, 1) @ self.lora_b.transpose(0, 1)
            return result + mm * self.scaling
        return F.linear(x, self._t(self.weight), bias=self.bias)
class LoraConv1d(nn.Conv1d, _Lora):
    __constants__ = nn.Conv1d.__constants__ + ["r", "lora_alpha", "scaling", "merge", "merged"]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int],
        r: int,
        lora_alpha: float = 1.0,
        lora_dropout: float = 0.0,
        merge: bool = False,
        stride: int | tuple[int] = 1,
        padding: str | int | tuple[int] = 0,
        dilation: int | tuple[int] = 1,
        groups: int = 1,
        bias: bool = True,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        assert r > 0

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        self.merge = merge

        self.dropout = nn.Identity() if lora_dropout == 0.0 else nn.Dropout(p=lora_dropout)
        self.merged = False

        self.lora_a = nn.Parameter(self.weight.new_empty((r, in_channels, *self.kernel_size)))
        self.lora_b = nn.Parameter(self.weight.new_empty((out_channels, r, 1)))
        self.weight.requires_grad_(False)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        super().reset_parameters()
        if hasattr(self, "lora_a") and hasattr(self, "lora_b"):
            self.reset_lora_parameters()

    def reset_lora_parameters(self) -> None:
        nn.init.kaiming_normal_(self.lora_a, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b)

    def train(self, mode: bool = True) -> "LoraConv1d":
        super().train(mode)
        if mode:
            if self.merge and self.merged:
                # Make sure that the weights are not merged
                if self.lora_a is not None and self.lora_b is not None:
                    self.weight.data -= self.lora_b @ self.lora_a * self.scaling
                self.merged = False
        elif self.merge and not self.merged:
            # Merge the weights and mark it
            if self.lora_a is not None and self.lora_b is not None:
                self.weight.data += self.lora_b @ self.lora_a * self.scaling
            self.merged = True
        return self

    def forward(self, x: Tensor) -> Tensor:
        if self.lora_a is not None and self.lora_b is not None and not self.merged:
            result = F.conv1d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
            mm_a = F.conv1d(self.dropout(x), self.lora_a, None, self.stride, self.padding, self.dilation, self.groups)
            mm = F.conv1d(mm_a, self.lora_b)
            return result + mm * self.scaling
        return F.conv1d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class LoraConvTranspose1d(nn.ConvTranspose1d, _Lora):
    __constants__ = nn.ConvTranspose1d.__constants__ + ["r", "lora_alpha", "scaling", "merge", "merged"]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int],
        r: int,
        lora_alpha: float = 1.0,
        lora_dropout: float = 0.0,
        merge: bool = False,
        stride: int | tuple[int] = 1,
        padding: int | tuple[int] = 0,
        output_padding: int | tuple[int] = 0,
        dilation: int | tuple[int] = 1,
        groups: int = 1,
        bias: bool = True,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        assert r > 0

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        self.merge = merge

        self.dropout = nn.Identity() if lora_dropout == 0.0 else nn.Dropout(p=lora_dropout)
        self.merged = False

        self.lora_a = nn.Parameter(self.weight.new_empty((in_channels, r, *self.kernel_size)))
        self.lora_b = nn.Parameter(self.weight.new_empty((r, out_channels, 1)))
        self.weight.requires_grad_(False)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        super().reset_parameters()
        if hasattr(self, "lora_a") and hasattr(self, "lora_b"):
            self.reset_lora_parameters()

    def reset_lora_parameters(self) -> None:
        nn.init.kaiming_normal_(self.lora_a, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b)

    def train(self, mode: bool = True) -> "LoraConvTranspose1d":
        super().train(mode)
        if mode:
            if self.merge and self.merged:
                # Make sure that the weights are not merged
                if self.lora_a is not None and self.lora_b is not None:
                    self.weight.data -= self.lora_b @ self.lora_a * self.scaling
                self.merged = False
        elif self.merge and not self.merged:
            # Merge the weights and mark it
            if self.lora_a is not None and self.lora_b is not None:
                self.weight.data += self.lora_b @ self.lora_a * self.scaling
            self.merged = True
        return self

    def forward(self, x: Tensor, output_size: list[int] | None = None) -> Tensor:
        assert isinstance(self.padding, tuple)
        if self.lora_a is not None and self.lora_b is not None and not self.merged:
            result = F.conv_transpose1d(
                x,
                self.weight,
                self.bias,
                self.stride,
                self.padding,
                self.output_padding,
                self.groups,
                self.dilation,
            )
            mm_a = F.conv_transpose1d(
                self.dropout(x),
                self.lora_a,
                None,
                self.stride,
                self.padding,
                self.output_padding,
                self.groups,
                self.dilation,
            )
            mm = F.conv_transpose1d(mm_a, self.lora_b)
            return result + mm * self.scaling
        return F.conv_transpose1d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.output_padding,
            self.groups,
            self.dilation,
        )
class LoraConv2d(nn.Conv2d, _Lora):
    __constants__ = nn.Conv2d.__constants__ + ["r", "lora_alpha", "scaling", "merge", "merged"]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        r: int,
        lora_alpha: float = 1.0,
        lora_dropout: float = 0.0,
        merge: bool = False,
        stride: int | tuple[int, int] = (1, 1),
        padding: str | int | tuple[int, int] = (0, 0),
        dilation: int | tuple[int, int] = (1, 1),
        groups: int = 1,
        bias: bool = True,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        assert r > 0

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        self.merge = merge

        self.dropout = nn.Identity() if lora_dropout == 0.0 else nn.Dropout(p=lora_dropout)
        self.merged = False

        self.lora_a = nn.Parameter(self.weight.new_empty((r, in_channels, *self.kernel_size)))
        self.lora_b = nn.Parameter(self.weight.new_empty((out_channels, r, 1, 1)))
        self.weight.requires_grad_(False)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        super().reset_parameters()
        if hasattr(self, "lora_a") and hasattr(self, "lora_b"):
            self.reset_lora_parameters()

    def reset_lora_parameters(self) -> None:
        nn.init.kaiming_normal_(self.lora_a, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b)

    def train(self, mode: bool = True) -> "LoraConv2d":
        super().train(mode)
        if mode:
            if self.merge and self.merged:
                # Make sure that the weights are not merged
                if self.lora_a is not None and self.lora_b is not None:
                    self.weight.data -= self.lora_b @ self.lora_a * self.scaling
                self.merged = False
        elif self.merge and not self.merged:
            # Merge the weights and mark it
            if self.lora_a is not None and self.lora_b is not None:
                self.weight.data += self.lora_b @ self.lora_a * self.scaling
            self.merged = True
        return self

    def forward(self, x: Tensor) -> Tensor:
        if self.lora_a is not None and self.lora_b is not None and not self.merged:
            result = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
            mm_a = F.conv2d(self.dropout(x), self.lora_a, None, self.stride, self.padding, self.dilation, self.groups)
            mm = F.conv2d(mm_a, self.lora_b)
            return result + mm * self.scaling
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class LoraConvTranspose2d(nn.ConvTranspose2d, _Lora):
    __constants__ = nn.ConvTranspose2d.__constants__ + ["r", "lora_alpha", "scaling", "merge", "merged"]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        r: int,
        lora_alpha: float = 1.0,
        lora_dropout: float = 0.0,
        merge: bool = False,
        stride: int | tuple[int, int] = (1, 1),
        padding: int | tuple[int, int] = (0, 0),
        output_padding: int | tuple[int, int] = (0, 0),
        dilation: int | tuple[int, int] = (1, 1),
        groups: int = 1,
        bias: bool = True,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        assert r > 0

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        self.merge = merge

        self.dropout = nn.Identity() if lora_dropout == 0.0 else nn.Dropout(p=lora_dropout)
        self.merged = False

        self.lora_a = nn.Parameter(self.weight.new_empty((in_channels, r, *self.kernel_size)))
        self.lora_b = nn.Parameter(self.weight.new_empty((r, out_channels, 1, 1)))
        self.weight.requires_grad_(False)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        super().reset_parameters()
        if hasattr(self, "lora_a") and hasattr(self, "lora_b"):
            self.reset_lora_parameters()

    def reset_lora_parameters(self) -> None:
        nn.init.kaiming_normal_(self.lora_a, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b)

    def train(self, mode: bool = True) -> "LoraConvTranspose2d":
        super().train(mode)
        if mode:
            if self.merge and self.merged:
                # Make sure that the weights are not merged
                if self.lora_a is not None and self.lora_b is not None:
                    self.weight.data -= self.lora_b @ self.lora_a * self.scaling
                self.merged = False
        elif self.merge and not self.merged:
            # Merge the weights and mark it
            if self.lora_a is not None and self.lora_b is not None:
                self.weight.data += self.lora_b @ self.lora_a * self.scaling
            self.merged = True
        return self

    def forward(self, x: Tensor, output_size: list[int] | None = None) -> Tensor:
        assert isinstance(self.padding, tuple)
        if self.lora_a is not None and self.lora_b is not None and not self.merged:
            result = F.conv_transpose2d(
                x,
                self.weight,
                self.bias,
                self.stride,
                self.padding,
                self.output_padding,
                self.groups,
                self.dilation,
            )
            mm_a = F.conv_transpose2d(
                self.dropout(x),
                self.lora_a,
                None,
                self.stride,
                self.padding,
                self.output_padding,
                self.groups,
                self.dilation,
            )
            mm = F.conv_transpose2d(mm_a, self.lora_b)
            return result + mm * self.scaling
        return F.conv_transpose2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.output_padding,
            self.groups,
            self.dilation,
        )


class _LoraRNN(nn.RNNBase, _Lora):
    __constants__ = nn.RNNBase.__constants__ + ["r", "lora_alpha", "scaling"]

    def __init__(
        self,
        mode: str,
        input_size: int,
        hidden_size: int,
        gate_mul: int,
        r: int,
        lora_alpha: float = 1.0,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        proj_size: int = 0,
    ) -> None:
        super().__init__(
            mode=mode,
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=batch_first,
            dropout=dropout,
            bidirectional=bidirectional,
            proj_size=proj_size,
        )

        assert r > 0

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r

        num_directions = 2 if bidirectional else 1
        gate_size = gate_mul * hidden_size

        for layer in range(num_layers):
            for direction in range(num_directions):
                real_hidden_size = proj_size if proj_size > 0 else hidden_size
                layer_input_size = input_size if layer == 0 else real_hidden_size * num_directions

                suffix = "_reverse" if direction == 1 else ""
                w_ih: Tensor = getattr(self, f"weight_ih_l{layer}{suffix}")
                w_hh: Tensor = getattr(self, f"weight_hh_l{layer}{suffix}")
                w_ih.requires_grad_(False)
                w_hh.requires_grad_(False)

                lora_a_ih = nn.Parameter(w_ih.new_empty((r, gate_size)))
                lora_b_ih = nn.Parameter(w_ih.new_empty((layer_input_size, r)))
                lora_a_hh = nn.Parameter(w_hh.new_empty((r, gate_size)))
                lora_b_hh = nn.Parameter(w_hh.new_empty((real_hidden_size, r)))

                setattr(self, f"lora_a_ih_l{layer}{suffix}", lora_a_ih)
                setattr(self, f"lora_b_ih_l{layer}{suffix}", lora_b_ih)
                setattr(self, f"lora_a_hh_l{layer}{suffix}", lora_a_hh)
                setattr(self, f"lora_b_hh_l{layer}{suffix}", lora_b_hh)

                if self.proj_size != 0:
                    w_hr: Tensor = getattr(self, f"weight_hr_l{layer}{suffix}")
                    w_hr.requires_grad_(False)
                    lora_a_hr = nn.Parameter(w_hr.new_empty((r, proj_size)))
                    lora_b_hr = nn.Parameter(w_hr.new_empty((hidden_size, r)))
                    setattr(self, f"lora_a_hr_l{layer}{suffix}", lora_a_hr)
                    setattr(self, f"lora_b_hr_l{layer}{suffix}", lora_b_hr)

        self._init_flat_weights()
        self.reset_parameters()

    def _lora_names(self, weight_name: str) -> tuple[str, str]:
        weight_name = weight_name[len("weight_") :]
        lora_a_name, lora_b_name = f"lora_a_{weight_name}", f"lora_b_{weight_name}"
        return lora_a_name, lora_b_name

    def _get_weight(self, weight_name: str) -> Tensor:
        weight = getattr(self, weight_name)
        if weight_name.startswith("bias_"):
            return weight
        lora_a_name, lora_b_name = self._lora_names(weight_name)
        if not hasattr(self, lora_a_name) or not hasattr(self, lora_b_name):
            return weight
        lora_a, lora_b = getattr(self, lora_a_name), getattr(self, lora_b_name)
        return weight + (lora_a.transpose(0, 1) @ lora_b.transpose(0, 1)) * self.scaling

    def _init_flat_weights(self) -> None:
        self._flat_weights = [self._get_weight(wn) if hasattr(self, wn) else None for wn in self._flat_weights_names]
        self._flat_weight_refs = [weakref.ref(w) if w is not None else None for w in self._flat_weights]
        self.flatten_parameters()

    def reset_parameters(self) -> None:
        super().reset_parameters()
        self.reset_lora_parameters()

    def reset_lora_parameters(self) -> None:
        for wn in self._flat_weights_names:
            lora_a_name, lora_b_name = self._lora_names(wn)
            if hasattr(self, lora_a_name) and hasattr(self, lora_b_name):
                lora_a, lora_b = getattr(self, lora_a_name), getattr(self, lora_b_name)
                nn.init.kaiming_normal_(lora_a, a=math.sqrt(5))
                nn.init.zeros_(lora_b)


class LoraLSTM(nn.LSTM, _LoraRNN):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        r: int,
        lora_alpha: float = 1.0,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        proj_size: int = 0,
    ) -> None:
        _LoraRNN.__init__(
            self,
            mode="LSTM",
            input_size=input_size,
            hidden_size=hidden_size,
            gate_mul=4,
            r=r,
            lora_alpha=lora_alpha,
            num_layers=num_layers,
            bias=bias,
            batch_first=batch_first,
            dropout=dropout,
            bidirectional=bidirectional,
            proj_size=proj_size,
        )


class LoraGRU(nn.GRU, _LoraRNN):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        r: int,
        lora_alpha: float = 1.0,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        proj_size: int = 0,
    ) -> None:
        _LoraRNN.__init__(
            self,
            mode="GRU",
            input_size=input_size,
            hidden_size=hidden_size,
            gate_mul=3,
            r=r,
            lora_alpha=lora_alpha,
            num_layers=num_layers,
            bias=bias,
            batch_first=batch_first,
            dropout=dropout,
            bidirectional=bidirectional,
            proj_size=proj_size,
        )


class _LoraRNNCellBase(nn.RNNCellBase, _Lora):
    __constants__ = nn.RNNCell.__constants__ + ["r", "lora_alpha", "scaling"]

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool,
        num_chunks: int,
        r: int,
        lora_alpha: float = 1.0,
    ) -> None:
        super().__init__(input_size=input_size, hidden_size=hidden_size, bias=bias, num_chunks=num_chunks)

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r

        self.lora_a_ih = nn.Parameter(self.weight_ih.new_empty((r, input_size)))
        self.lora_b_ih = nn.Parameter(self.weight_ih.new_empty((hidden_size * num_chunks, r)))
        self.lora_a_hh = nn.Parameter(self.weight_hh.new_empty((r, hidden_size)))
        self.lora_b_hh = nn.Parameter(self.weight_hh.new_empty((hidden_size * num_chunks, r)))
        self.weight_ih.requires_grad_(False)
        self.weight_hh.requires_grad_(False)

    def reset_parameters(self) -> None:
        super().reset_parameters()
        self.reset_lora_parameters()

    def reset_lora_parameters(self) -> None:
        if hasattr(self, "lora_a_ih") and hasattr(self, "lora_b_ih"):
            nn.init.kaiming_normal_(self.lora_a_ih, a=math.sqrt(5))
            nn.init.zeros_(self.lora_b_ih)
        if hasattr(self, "lora_a_hh") and hasattr(self, "lora_b_hh"):
            nn.init.kaiming_normal_(self.lora_a_hh, a=math.sqrt(5))
            nn.init.zeros_(self.lora_b_hh)


class LoraLSTMCell(nn.LSTMCell, _LoraRNNCellBase):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        r: int,
        bias: bool = True,
        lora_alpha: float = 1.0,
    ) -> None:
        _LoraRNNCellBase.__init__(
            self,
            input_size=input_size,
            hidden_size=hidden_size,
            bias=bias,
            num_chunks=4,
            r=r,
            lora_alpha=lora_alpha,
        )

    def forward(self, input: Tensor, hx: tuple[Tensor, Tensor] | None = None) -> tuple[Tensor, Tensor]:
        assert input.dim() in (1, 2), f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)

        if hx is None:
            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        else:
            hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx

        lora_ih = (self.lora_b_ih @ self.lora_a_ih) * self.scaling
        lora_hh = (self.lora_b_hh @ self.lora_a_hh) * self.scaling
        ret = _VF.lstm_cell(input, hx, self.weight_ih + lora_ih, self.weight_hh + lora_hh, self.bias_ih, self.bias_hh)

        if not is_batched:
            ret = (ret[0].squeeze(0), ret[1].squeeze(0))
        return ret


class LoraGRUCell(nn.GRUCell, _LoraRNNCellBase):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        r: int,
        bias: bool = True,
        lora_alpha: float = 1.0,
    ) -> None:
        _LoraRNNCellBase.__init__(
            self,
            input_size=input_size,
            hidden_size=hidden_size,
            bias=bias,
            num_chunks=3,
            r=r,
            lora_alpha=lora_alpha,
        )

    def forward(self, input: Tensor, hx: Tensor | None = None) -> Tensor:
        assert input.dim() in (1, 2), f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)

        if hx is None:
            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
        else:
            hx = hx.unsqueeze(0) if not is_batched else hx

        lora_ih = (self.lora_b_ih @ self.lora_a_ih) * self.scaling
        lora_hh = (self.lora_b_hh @ self.lora_a_hh) * self.scaling
        ret = _VF.gru_cell(input, hx, self.weight_ih + lora_ih, self.weight_hh + lora_hh, self.bias_ih, self.bias_hh)

        if not is_batched:
            ret = ret.squeeze(0)
        return ret


class LoraParallelEmbedding(ParallelEmbedding, _Lora):
    __constants__ = ParallelEmbedding.__constants__ + ["r", "lora_alpha", "merge", "scaling", "merged"]

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int,
        lora_alpha: float = 1.0,
        lora_dropout: float = 0.0,
        merge: bool = False,
        padding_idx: int | None = None,
        max_norm: float | None = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        init_type: InitializationType = "xavier_normal",
    ) -> None:
        super().__init__(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            init_type=init_type,
        )

        assert r > 0

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        self.merge = merge

        self.dropout = nn.Identity() if lora_dropout == 0.0 else nn.Dropout(p=lora_dropout)
        self.merged = False

        self.lora_a = nn.Parameter(self.weight.new_empty((r, num_embeddings)))
        self.lora_b = nn.Parameter(self.weight.new_empty((self.embedding_dim_per_rank, r)))
        self.weight.requires_grad_(False)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        super().reset_parameters()
        if hasattr(self, "lora_a") and hasattr(self, "lora_b"):
            self.reset_lora_parameters()

    def reset_lora_parameters(self) -> None:
        nn.init.kaiming_normal_(self.lora_a, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b)

    def train(self, mode: bool = True) -> "LoraParallelEmbedding":
        super().train(mode)
        if mode:
            if self.merge and self.merged:
                # Make sure that the weights are not merged
                if self.lora_a is not None and self.lora_b is not None:
                    self.weight.data -= (self.lora_b @ self.lora_a).transpose(0, 1) * self.scaling
                self.merged = False
        elif self.merge and not self.merged:
            # Merge the weights and mark it
            if self.lora_a is not None and self.lora_b is not None:
                self.weight.data += (self.lora_b @ self.lora_a).transpose(0, 1) * self.scaling
            self.merged = True
        return self

    def forward(self, x: Tensor) -> Tensor:
        x = mp_copy(x)
        if self.lora_a is not None and self.lora_b is not None and not self.merged:
            output_parallel = F.embedding(
                x,
                self.weight,
                self.padding_idx,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.sparse,
            )
            after_a_parallel = F.embedding(
                x,
                self.lora_a.transpose(0, 1),
                self.padding_idx,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.sparse,
            )
            return mp_gather(output_parallel + (after_a_parallel @ self.lora_b.transpose(0, 1)) * self.scaling)
        # When the LoRA weights are merged (or absent), fall back to the plain
        # parallel embedding lookup.
        output_parallel = F.embedding(
            x,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        return mp_gather(output_parallel)


class LoraColumnParallelLinear(ColumnParallelLinear, _Lora):
    __constants__ = ColumnParallelLinear.__constants__ + [
        "r",
        "lora_alpha",
        "scaling",
        "merge",
        "fan_in_fan_out",
        "merged",
    ]

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int,
        lora_alpha: float = 1.0,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        merge: bool = False,
        bias: bool = True,
        gather_output: bool = True,
        init_type: InitializationType = "xavier_normal",
        stride: int = 1,
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            gather_output=gather_output,
            init_type=init_type,
            stride=stride,
        )

        assert r > 0

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        self.merge = merge
        self.fan_in_fan_out = fan_in_fan_out

        self.dropout = nn.Identity() if lora_dropout == 0.0 else nn.Dropout(p=lora_dropout)
        self.merged = False

        self.lora_a = nn.Parameter(self.weight.new_empty((r, in_features)))
        self.lora_b = nn.Parameter(self.weight.new_empty((self.output_size_per_partition, r)))
        self.weight.requires_grad_(False)

        self.reset_parameters()

        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self) -> None:
        super().reset_parameters()
        if hasattr(self, "lora_a") and hasattr(self, "lora_b"):
            self.reset_lora_parameters()

    def reset_lora_parameters(self) -> None:
        nn.init.kaiming_normal_(self.lora_a, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b)

    def _t(self, w: Tensor) -> Tensor:
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def train(self, mode: bool = True) -> "LoraColumnParallelLinear":
        super().train(mode)
        if mode:
            if self.merge and self.merged:
                # Make sure that the weights are not merged
                if self.lora_a is not None and self.lora_b is not None:
                    self.weight.data -= self._t(self.lora_b @ self.lora_a) * self.scaling
                self.merged = False
        elif self.merge and not self.merged:
            # Merge the weights and mark it
            if self.lora_a is not None and self.lora_b is not None:
                self.weight.data += self._t(self.lora_b @ self.lora_a) * self.scaling
            self.merged = True
        return self

    def forward(self, x: Tensor) -> Tensor:
        input_parallel = mp_copy(x)
        if self.lora_a is not None and self.lora_b is not None and not self.merged:
            output_parallel = F.linear(input_parallel, self._t(self.weight), bias=self.bias)
            mm = self.dropout(input_parallel) @ self.lora_a.transpose(0, 1) @ self.lora_b.transpose(0, 1)
            output_parallel = output_parallel + mm * self.scaling
            return mp_gather(output_parallel) if self.gather_output else output_parallel
        output_parallel = F.linear(input_parallel, self.weight, self.bias)
        return mp_gather(output_parallel) if self.gather_output else output_parallel


class LoraRowParallelLinear(RowParallelLinear, _Lora):
    __constants__ = RowParallelLinear.__constants__ + [
        "r",
        "lora_alpha",
        "scaling",
        "merge",
        "fan_in_fan_out",
        "merged",
    ]

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int,
        lora_alpha: float = 1.0,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        merge: bool = False,
        bias: bool = True,
        input_is_parallel: bool = False,
        init_type: InitializationType = "xavier_normal",
        stride: int = 1,
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            input_is_parallel=input_is_parallel,
            init_type=init_type,
            stride=stride,
        )

        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        self.merge = merge
        self.fan_in_fan_out = fan_in_fan_out

        self.dropout = nn.Identity() if lora_dropout == 0.0 else nn.Dropout(p=lora_dropout)
        self.merged = False

        self.lora_a = nn.Parameter(self.weight.new_empty((r, self.input_size_per_partition)))
        self.lora_b = nn.Parameter(self.weight.new_empty((out_features, r)))
        self.weight.requires_grad_(False)

        self.reset_parameters()

        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self) -> None:
        super().reset_parameters()
        if hasattr(self, "lora_a") and hasattr(self, "lora_b"):
            self.reset_lora_parameters()

    def reset_lora_parameters(self) -> None:
        nn.init.kaiming_normal_(self.lora_a, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b)

    def _t(self, w: Tensor) -> Tensor:
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def train(self, mode: bool = True) -> "LoraRowParallelLinear":
        super().train(mode)
        if mode:
            if self.merge and self.merged:
                # Make sure that the weights are not merged
                if self.lora_a is not None and self.lora_b is not None:
                    self.weight.data -= self._t(self.lora_b @ self.lora_a) * self.scaling
                self.merged = False
        elif self.merge and not self.merged:
            # Merge the weights and mark it
            if self.lora_a is not None and self.lora_b is not None:
                self.weight.data += self._t(self.lora_b @ self.lora_a) * self.scaling
            self.merged = True
        return self

    def forward(self, x: Tensor) -> Tensor:
        input_parallel = x if self.input_is_parallel else mp_scatter(x)
        if self.lora_a is not None and self.lora_b is not None and not self.merged:
            output_parallel = F.linear(input_parallel, self._t(self.weight), bias=self.bias)
            mm = self.dropout(input_parallel) @ self.lora_a.transpose(0, 1) @ self.lora_b.transpose(0, 1)
            output_parallel = output_parallel + mm * self.scaling
            output = mp_reduce(output_parallel)
            return output if self.bias is None else output + self.bias
        output_parallel = F.linear(input_parallel, self.weight, self.bias)
        output = mp_reduce(output_parallel)
        return output if self.bias is None else output + self.bias


@overload
def lora(module: nn.Embedding, r: int, alpha: float = 1.0, dropout: float = 0.0, merge: bool = False) -> LoraEmbedding:
    ...


@overload
def lora(module: nn.Linear, r: int, alpha: float = 1.0, dropout: float = 0.0, merge: bool = False) -> LoraLinear:
    ...


@overload
def lora(module: nn.Conv1d, r: int, alpha: float = 1.0, dropout: float = 0.0, merge: bool = False) -> LoraConv1d:
    ...


@overload
def lora(
    module: nn.ConvTranspose1d,
    r: int,
    alpha: float = 1.0,
    dropout: float = 0.0,
    merge: bool = False,
) -> LoraConvTranspose1d:
    ...


@overload
def lora(module: nn.Conv2d, r: int, alpha: float = 1.0, dropout: float = 0.0, merge: bool = False) -> LoraConv2d:
    ...


@overload
def lora(
    module: nn.ConvTranspose2d,
    r: int,
    alpha: float = 1.0,
    dropout: float = 0.0,
    merge: bool = False,
) -> LoraConvTranspose2d:
    ...


@overload
def lora(module: nn.LSTM, r: int, alpha: float = 1.0, dropout: float = 0.0, merge: bool = False) -> LoraLSTM:
    ...


@overload
def lora(module: nn.GRU, r: int, alpha: float = 1.0, dropout: float = 0.0, merge: bool = False) -> LoraGRU:
    ...


@overload
def lora(module: nn.LSTMCell, r: int, alpha: float = 1.0, dropout: float = 0.0, merge: bool = False) -> LoraLSTMCell:
    ...


@overload
def lora(module: nn.GRUCell, r: int, alpha: float = 1.0, dropout: float = 0.0, merge: bool = False) -> LoraGRUCell:
    ...


@overload
def lora(
    module: ParallelEmbedding,
    r: int,
    alpha: float = 1.0,
    dropout: float = 0.0,
    merge: bool = False,
) -> LoraParallelEmbedding:
    ...


@overload
def lora(
    module: ColumnParallelLinear,
    r: int,
    alpha: float = 1.0,
    dropout: float = 0.0,
    merge: bool = False,
) -> LoraColumnParallelLinear:
    ...


@overload
def lora(
    module: RowParallelLinear,
    r: int,
    alpha: float = 1.0,
    dropout: float = 0.0,
    merge: bool = False,
) -> LoraRowParallelLinear:
    ...


@overload
def lora(module: SupportedModule, r: int, alpha: float = 1.0, dropout: float = 0.0, merge: bool = False) -> nn.Module:
    ...


def lora(module: SupportedModule, r: int, alpha: float = 1.0, dropout: float = 0.0, merge: bool = False) -> nn.Module:
    """Wraps a module with LoRA.

    This function takes a base module and returns the LoRA version of that
    module. The new module is effectively a drop-in replacement for the
    original module; for example, it can load the same state dict, and it has
    the same input and output shapes.

    Args:
        module: The module to wrap.
        r: The rank of the LoRA decomposition to use; must be greater than
            zero. To make LoRA optional, use :func:`maybe_lora` with ``r=None``.
        alpha: The scaling factor for the LoRA components. A higher value
            means that more weight is given to the LoRA components.
        dropout: The dropout probability applied to the input value before
            computing the LoRA components. This parameter is not supported
            for RNNs (because it would require modifying the underlying
            kernel).
        merge: Whether to merge the LoRA components into the original weights.
            If True, the LoRA components are kept separate during training and
            merged into the weights when the module is switched to evaluation
            mode. If False, the LoRA components are kept separate during both
            training and evaluation.

    Returns:
        The LoRA version of the module.

    Raises:
        ValueError: If the module is not supported.
    """
    if isinstance(module, nn.Embedding):
        embedding = LoraEmbedding(
            module.num_embeddings,
            module.embedding_dim,
            padding_idx=module.padding_idx,
            max_norm=module.max_norm,
            norm_type=module.norm_type,
            scale_grad_by_freq=module.scale_grad_by_freq,
            sparse=module.sparse,
            r=r,
            lora_alpha=alpha,
            merge=merge,
        )
        embedding.weight.data.copy_(module.weight.data)
        return embedding

    if isinstance(module, nn.Linear):
        linear = LoraLinear(
            module.in_features,
            module.out_features,
            r=r,
            lora_alpha=alpha,
            merge=merge,
            bias=module.bias is not None,
        )
        linear.weight.data.copy_(module.weight.data)
        if module.bias is not None and linear.bias is not None:
            linear.bias.data.copy_(module.bias.data)
        return linear

    if isinstance(module, nn.Conv1d):
        conv_1d = LoraConv1d(
            module.in_channels,
            module.out_channels,
            cast(tuple[int], module.kernel_size),
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            merge=merge,
            stride=cast(tuple[int], module.stride),
            padding=cast(str | tuple[int], module.padding),
            dilation=cast(tuple[int], module.dilation),
            groups=module.groups,
            bias=module.bias is not None,
        )
        conv_1d.weight.data.copy_(module.weight.data)
        if module.bias is not None and conv_1d.bias is not None:
            conv_1d.bias.data.copy_(module.bias.data)
        return conv_1d

    if isinstance(module, nn.ConvTranspose1d):
        conv_transpose_1d = LoraConvTranspose1d(
            module.in_channels,
            module.out_channels,
            cast(tuple[int], module.kernel_size),
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            merge=merge,
            stride=cast(tuple[int], module.stride),
            padding=cast(tuple[int], module.padding),
            output_padding=cast(tuple[int], module.output_padding),
            dilation=cast(tuple[int], module.dilation),
            groups=module.groups,
            bias=module.bias is not None,
        )
        conv_transpose_1d.weight.data.copy_(module.weight.data)
        if module.bias is not None and conv_transpose_1d.bias is not None:
            conv_transpose_1d.bias.data.copy_(module.bias.data)
        return conv_transpose_1d

    if isinstance(module, nn.Conv2d):
        conv_2d = LoraConv2d(
            module.in_channels,
            module.out_channels,
            cast(tuple[int, int], module.kernel_size),
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            merge=merge,
            stride=cast(tuple[int, int], module.stride),
            padding=cast(str | tuple[int, int], module.padding),
            dilation=cast(tuple[int, int], module.dilation),
            groups=module.groups,
            bias=module.bias is not None,
        )
        conv_2d.weight.data.copy_(module.weight.data)
        if module.bias is not None and conv_2d.bias is not None:
            conv_2d.bias.data.copy_(module.bias.data)
        return conv_2d

    if isinstance(module, nn.ConvTranspose2d):
        conv_transpose_2d = LoraConvTranspose2d(
            module.in_channels,
            module.out_channels,
            cast(tuple[int, int], module.kernel_size),
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            merge=merge,
            stride=cast(tuple[int, int], module.stride),
            padding=cast(tuple[int, int], module.padding),
            output_padding=cast(tuple[int, int], module.output_padding),
            dilation=cast(tuple[int, int], module.dilation),
            groups=module.groups,
            bias=module.bias is not None,
        )
        conv_transpose_2d.weight.data.copy_(module.weight.data)
        if module.bias is not None and conv_transpose_2d.bias is not None:
            conv_transpose_2d.bias.data.copy_(module.bias.data)
        return conv_transpose_2d

    if isinstance(module, nn.LSTM):
        if dropout > 0.0:
            warnings.warn("LoRA dropout is not supported for LSTMs")
        lstm = LoraLSTM(
            module.input_size,
            module.hidden_size,
            r=r,
            lora_alpha=alpha,
            num_layers=module.num_layers,
            batch_first=module.batch_first,
            dropout=module.dropout,
            bidirectional=module.bidirectional,
            proj_size=module.proj_size,
            bias=module.bias,
        )
        for param_name, param_value in module.named_parameters():
            getattr(lstm, param_name).data.copy_(param_value.data)
        return lstm

    if isinstance(module, nn.GRU):
        if dropout > 0.0:
            warnings.warn("LoRA dropout is not supported for GRUs")
        gru = LoraGRU(
            module.input_size,
            module.hidden_size,
            r=r,
            lora_alpha=alpha,
            num_layers=module.num_layers,
            bias=module.bias,
            batch_first=module.batch_first,
            dropout=module.dropout,
            bidirectional=module.bidirectional,
            proj_size=module.proj_size,
        )
        for param_name, param_value in module.named_parameters():
            getattr(gru, param_name).data.copy_(param_value.data)
        return gru

    if isinstance(module, nn.LSTMCell):
        if dropout > 0.0:
            warnings.warn("LoRA dropout is not supported for LSTMCells")
        lstm_cell = LoraLSTMCell(
            module.input_size,
            module.hidden_size,
            r=r,
            lora_alpha=alpha,
            bias=module.bias,
        )
        lstm_cell.weight_hh.data.copy_(module.weight_hh.data)
        lstm_cell.weight_ih.data.copy_(module.weight_ih.data)
        if module.bias:
            lstm_cell.bias_hh.data.copy_(module.bias_hh.data)
            lstm_cell.bias_ih.data.copy_(module.bias_ih.data)
        return lstm_cell

    if isinstance(module, nn.GRUCell):
        if dropout > 0.0:
            warnings.warn("LoRA dropout is not supported for GRUCells")
        gru_cell = LoraGRUCell(
            module.input_size,
            module.hidden_size,
            r=r,
            lora_alpha=alpha,
            bias=module.bias,
        )
        gru_cell.weight_hh.data.copy_(module.weight_hh.data)
        gru_cell.weight_ih.data.copy_(module.weight_ih.data)
        if module.bias:
            gru_cell.bias_hh.data.copy_(module.bias_hh.data)
            gru_cell.bias_ih.data.copy_(module.bias_ih.data)
        return gru_cell

    if isinstance(module, ParallelEmbedding):
        parallel_embedding = LoraParallelEmbedding(
            module.num_embeddings,
            module.embedding_dim,
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            merge=merge,
            padding_idx=module.padding_idx,
            max_norm=module.max_norm,
            norm_type=module.norm_type,
            scale_grad_by_freq=module.scale_grad_by_freq,
            sparse=module.sparse,
        )
        parallel_embedding.weight.data.copy_(module.weight.data)
        return parallel_embedding

    if isinstance(module, RowParallelLinear):
        row_parallel_linear = LoraRowParallelLinear(
            module.in_features,
            module.out_features,
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            merge=merge,
            bias=module.bias is not None,
        )
        row_parallel_linear.weight.data.copy_(module.weight.data)
        if module.bias is not None and row_parallel_linear.bias is not None:
            row_parallel_linear.bias.data.copy_(module.bias.data)
        return row_parallel_linear

    if isinstance(module, ColumnParallelLinear):
        column_parallel_linear = LoraColumnParallelLinear(
            module.in_features,
            module.out_features,
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            merge=merge,
            bias=module.bias is not None,
        )
        column_parallel_linear.weight.data.copy_(module.weight.data)
        if module.bias is not None and column_parallel_linear.bias is not None:
            column_parallel_linear.bias.data.copy_(module.bias.data)
        return column_parallel_linear

    raise ValueError(f"Unsupported module type {type(module)}")
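

# Example (sketch): swapping a single pre-trained layer for its LoRA counterpart.
# ``model.fc`` below is a hypothetical attribute name. Because the wrapper copies
# the original weights and `_lora_post_hook` drops the extra ``lora_*`` entries
# from the missing-key report, checkpoints saved from the un-wrapped model can
# still be loaded afterwards:
#
#     checkpoint = model.state_dict()
#     model.fc = lora(model.fc, r=8, alpha=1.0)
#     model.load_state_dict(checkpoint)
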
T_module = TypeVar("T_module", bound=SupportedModule)


def maybe_lora(
    module: T_module,
    r: int | None,
    alpha: float = 1.0,
    dropout: float = 0.0,
    merge: bool = False,
    freeze: bool = True,
) -> T_module:
    """Applies LoRA to a supported module, if a LoRA rank is provided.

    Args:
        module: A supported module.
        r: The LoRA rank.
        alpha: The LoRA alpha parameter.
        dropout: The LoRA dropout rate.
        merge: Whether to merge the LoRA weights into the base weights when
            the module is switched to evaluation mode.
        freeze: Whether to freeze the module's parameters if a LoRA rank is
            not provided. This argument has no effect if a LoRA rank is
            provided, since downstream users can always freeze just the module
            themselves. Typically, when trying out LoRA fine-tuning,
            downstream users will want to freeze most of the module parameters
            and apply LoRA only to a subset of the module's layers, so this is
            the default behavior.

    Returns:
        The module with LoRA applied, if a LoRA rank is provided.
    """
    if freeze and r is None:
        module = cast(T_module, module.requires_grad_(False))
    return module if r is None else lora(module, r, alpha, dropout, merge)
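

# Example (sketch): using `maybe_lora` so a config flag decides whether to apply
# LoRA. ``lora_rank`` is a hypothetical config value; passing ``None`` freezes
# the layer instead of wrapping it.
#
#     lora_rank: int | None = 4
#     proj = maybe_lora(nn.Linear(256, 256), r=lora_rank)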


def maybe_lora_weight_norm(
    module: T_module,
    r: int | None,
    alpha: float = 1.0,
    dropout: float = 0.0,
    merge: bool = False,
    freeze: bool = True,
) -> T_module:
    module = maybe_lora(module, r=r, alpha=alpha, dropout=dropout, merge=merge, freeze=freeze)
    return nn.utils.weight_norm(module)


def reset_lora_weights_(module: nn.Module) -> None:
    """Resets any LoRA weights in the module.

    All of the LoRA modules have a ``reset_lora_parameters`` method that will
    reset the LoRA weights in-place. This function looks for any modules with
    this method and calls it.

    Args:
        module: The module to reset, in-place.
    """
    for _, submodule in module.named_modules():
        if isinstance(submodule, _Lora):
            submodule.reset_lora_parameters()


def freeze_non_lora_(module: nn.Module) -> None:
    """Freezes any non-LoRA parameters in the module.

    Args:
        module: The module to freeze, in-place.
    """
    for _, submodule in module.named_modules():
        if isinstance(submodule, _Lora):
            continue
        for param in submodule.parameters(recurse=False):
            param.requires_grad_(False)