DoRA’s Magnitude Vector

python
deep learning
machine learning
LLM
In this blog post I highlight a key difference I saw between Raschka’s and peft’s implementation of DoRA.
Author

Vishal Bakshi

Published

February 1, 2025

Setup

from peft import LoraConfig, get_peft_model
import transformers

# the following imports are from dora.py
from copy import deepcopy

import torch
import torch.nn.functional as F
from torch import nn

from peft.utils.integrations import dequantize_module_weight, gather_params_ctx
from peft.utils.other import transpose

Background

I am currently re-reading the DoRA (Weight-Decomposed Low-Rank Adaptation) paper. I took a bit of a detour and worked through the fantastic article Improving LoRA: Implementing Weight-Decomposed Low-Rank Adaptation (DoRA) from Scratch by Sebastian Raschka (I am also reading his book Building a Large Language Model (from scratch) as part of a fastai study group). The article is full of helpful diagrams and breakdowns of concepts, as well as easily digestible code implementations. One particular breakthrough for me while reading the article was his demonstration of the distributive property of matrix multiplication:

$$x \cdot (W + \Delta W) = x \cdot W + x \cdot \Delta W$$

Similarly, for LoRA, where the weight update ΔW is approximated by the product of two low-rank matrices A and B, we can write:

$$x \cdot (W + A \cdot B) = x \cdot W + x \cdot A \cdot B$$

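Reading this made it click for me why and how LoRA adapters are such an efficient way of handling downstream tasks: the pretrained W stays frozen, and the low-rank path x.A.B can be computed separately and simply added to the base layer’s output.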
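Here is a quick numerical check of the two identities above. It is a minimal sketch: the sizes and the rank of 8 are hypothetical, and float64 keeps the two sides equal within `allclose`'s default tolerance. It also counts how many values a full update ΔW needs compared with the two LoRA factors.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: the 784 x 128 layer used later in this post, with a small rank
in_dim, out_dim, rank = 784, 128, 8

x = torch.randn(4, in_dim, dtype=torch.float64)        # batch of 4 inputs
W = torch.randn(in_dim, out_dim, dtype=torch.float64)  # frozen weight in (in, out) layout, as in x.W
A = torch.randn(in_dim, rank, dtype=torch.float64)     # LoRA factor A: (in, rank)
B = torch.randn(rank, out_dim, dtype=torch.float64)    # LoRA factor B: (rank, out)

merged = x @ (W + A @ B)    # x.(W + A.B)
split = x @ W + x @ A @ B   # x.W + x.A.B
print(torch.allclose(merged, split))  # True

# Values needed for a full update vs. the two low-rank factors
full_params = W.numel()
lora_params = A.numel() + B.numel()
print(full_params, lora_params)  # 100352 7296
```

At rank 8 the factors hold 7,296 values, versus 100,352 for a full 784 x 128 update. The saving depends on the rank being small relative to the layer dimensions.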

I also took a deep dive into the peft library’s implementation of DoRA. I recently made a video of this deep dive.

In this blog post I am going to compare the implementation in Raschka’s article with peft’s and highlight a key difference I found between them: how they decompose a weight matrix into its magnitude and directional components.

I’ll start by reviewing both approaches.

Raschka’s Implementation

One caveat: I assume this implementation is by no means a “final” or “production” implementation; as I understand it, it is meant to be educational and illustrative.

I’ll start by copy/pasting the relevant code: LoRALayer (DoRA uses LoRA to fine-tune the directional component) and LinearWithDoRAMerged.

class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        x = self.alpha * (x @ self.A @ self.B)
        return x


class LinearWithDoRAMerged(nn.Module):

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )
        self.m = nn.Parameter(
            self.linear.weight.norm(p=2, dim=0, keepdim=True))

    # Code loosely inspired by
    # https://github.com/catid/dora/blob/main/dora.py

    def forward(self, x):
        lora = self.lora.A @ self.lora.B
        numerator = self.linear.weight + self.lora.alpha*lora.T
        denominator = numerator.norm(p=2, dim=0, keepdim=True)
        directional_component = numerator / denominator
        new_weight = self.m * directional_component
        return F.linear(x, new_weight, self.linear.bias)
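LinearWithDoRAMerged starts out as a drop-in replacement for the layer it wraps. Because B is initialized to zeros, the numerator is just W and the directional component is W divided by its dim=0 norms. Since m holds exactly those norms, m times the direction reconstructs W. The sketch below is a condensed copy of the two classes above (same math, fewer intermediate names) plus a check. The seed and batch size are arbitrary.

```python
import torch
import torch.nn.functional as F
from torch import nn


class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))  # zero init
        self.alpha = alpha

    def forward(self, x):
        return self.alpha * (x @ self.A @ self.B)


class LinearWithDoRAMerged(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        self.m = nn.Parameter(self.linear.weight.norm(p=2, dim=0, keepdim=True))

    def forward(self, x):
        lora = self.lora.A @ self.lora.B
        numerator = self.linear.weight + self.lora.alpha * lora.T
        denominator = numerator.norm(p=2, dim=0, keepdim=True)
        return F.linear(x, self.m * (numerator / denominator), self.linear.bias)


torch.manual_seed(0)
linear = nn.Linear(in_features=784, out_features=128, bias=True)
dora_layer = LinearWithDoRAMerged(linear, 256, 512)

x = torch.randn(2, 784)
# With B == 0: m = ||W|| per column and direction = W / ||W||, so m * direction == W
print(torch.allclose(dora_layer(x), linear(x), atol=1e-6))  # True
```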

I’ll also create a regular linear layer using one of the in/out feature values in the Raschka article:

linear = nn.Linear(in_features=784, out_features=128, bias=True)
linear
Linear(in_features=784, out_features=128, bias=True)
dora_layer = LinearWithDoRAMerged(linear, 256, 512)
dora_layer
LinearWithDoRAMerged(
  (linear): Linear(in_features=784, out_features=128, bias=True)
  (lora): LoRALayer()
)

Here’s the key value: the shape of the magnitude vector. In Raschka’s code, it’s 1 x 784, where 784 is the linear layer’s in_features.

dora_layer.m.shape
torch.Size([1, 784])

Looking at LinearWithDoRAMerged.__init__:

self.m = nn.Parameter(
    self.linear.weight.norm(p=2, dim=0, keepdim=True))

The norm is taken over dim=0, which is the out_features dimension of the weight:

linear.weight.shape
torch.Size([128, 784])

In other words, we end up with 1 magnitude value for each of the 784 input neurons.
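To make the dim=0 reduction concrete, this sketch recomputes the magnitude vector by hand for a 784-to-128 nn.Linear. The weight is freshly initialized, not pretrained. The check confirms that entry j is the L2 norm of column j of the stored (128, 784) weight, which is the set of 128 weights fanning out from input feature j.

```python
import torch
from torch import nn

torch.manual_seed(0)
linear = nn.Linear(in_features=784, out_features=128, bias=True)
W = linear.weight.detach()  # stored as (out_features, in_features) = (128, 784)

# Same reduction as Raschka's self.m: collapse dim=0 (the 128 output rows)
m = W.norm(p=2, dim=0, keepdim=True)
print(m.shape)  # torch.Size([1, 784])

# Manual version: entry j is the L2 norm of column j, i.e. the 128 weights
# connecting input feature j to every output neuron
manual = torch.stack([W[:, j].pow(2).sum().sqrt() for j in range(W.shape[1])])
print(torch.allclose(m.squeeze(0), manual))  # True
```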

peft Implementation

From src/peft/tuners/lora/dora.py:

class DoraLinearLayer(nn.Module):
    def __init__(self, fan_in_fan_out):
        super().__init__()
        self.fan_in_fan_out = fan_in_fan_out

    def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
        # calculate L2 norm of weight matrix, column-wise
        weight = transpose(weight, self.fan_in_fan_out)
        weight = weight + scaling * lora_weight
        weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
        return weight_norm

    ...

The key attribute here is fan_in_fan_out. I found a few places in the peft codebase that document it as follows:

Set this to True if the layer to replace stores weight like (fan_in, fan_out)

How I interpret this: if the weights are stored as (in, out), fan_in_fan_out is True; if they are stored as (out, in), as PyTorch’s nn.Linear does, fan_in_fan_out is False.
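This sketch shows what fan_in_fan_out changes. It makes two assumptions: that peft’s transpose (imported from peft.utils.other in Setup) behaves as in its source, copied locally here, and that the LoRA update reaches get_weight_norm as lora_B @ lora_A with shape (out, in), which the k_proj module shapes in this post suggest. The same weights stored either way end up as (out, in) after transpose, so the dim=1 norm always yields one value per output feature. The rank of 8 and the scaling of 1.0 are arbitrary illustration values.

```python
import torch
from torch import nn


def transpose(weight, fan_in_fan_out):
    # Local copy of peft.utils.other.transpose
    if not fan_in_fan_out:
        return weight
    if isinstance(weight, torch.nn.Parameter):
        return torch.nn.Parameter(weight.T)
    return weight.T


def get_weight_norm(weight, lora_weight, scaling, fan_in_fan_out):
    # Function version of DoraLinearLayer.get_weight_norm
    weight = transpose(weight, fan_in_fan_out)  # (out, in) after this line
    weight = weight + scaling * lora_weight
    return torch.linalg.norm(weight, dim=1).to(weight.dtype)


torch.manual_seed(0)
linear = nn.Linear(in_features=576, out_features=192, bias=False)
w_out_in = linear.weight.detach()   # nn.Linear stores (out, in): fan_in_fan_out=False
w_in_out = w_out_in.T.contiguous()  # same weights stored (in, out): fan_in_fan_out=True

# Assumption: the LoRA update is lora_B @ lora_A, shape (out, in);
# lora_B starts at zero, so the update is all zeros here
r = 8
lora_weight = torch.zeros(192, r) @ torch.randn(r, 576)

norm_a = get_weight_norm(w_out_in, lora_weight, scaling=1.0, fan_in_fan_out=False)
norm_b = get_weight_norm(w_in_out, lora_weight, scaling=1.0, fan_in_fan_out=True)
print(norm_a.shape, norm_b.shape)      # torch.Size([192]) torch.Size([192])
print(torch.allclose(norm_a, norm_b))  # True
```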

Looking at an example, I’ll peft-ify SmolLM2-135M:

model_nm = 'HuggingFaceTB/SmolLM2-135M'
model_nm
'HuggingFaceTB/SmolLM2-135M'
model = transformers.AutoModelForSequenceClassification.from_pretrained(model_nm, num_labels=2)
peft_config = LoraConfig(r=256, use_rslora=False, use_dora=True, target_modules=['down_proj', 'gate_proj', 'k_proj', 'o_proj', 'q_proj', 'up_proj', 'v_proj'])
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
trainable params: 78,307,200 || all params: 212,823,360 || trainable%: 36.7945
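The trainable-parameter count reported above is itself evidence that peft uses a per-output magnitude vector. The arithmetic below is a sketch that makes two assumptions. First, SmolLM2-135M has a hidden size of 576, an MLP intermediate size of 1536, 30 decoder layers, and 192-wide k/v projections (the last matches the k_proj shown in this post). Second, only the LoRA A/B matrices and the DoRA magnitude vectors are trainable. Each targeted module then contributes r x (in + out) LoRA parameters plus one magnitude per output feature.

```python
# Assumed SmolLM2-135M dimensions (check against the model's config.json)
hidden = 576         # hidden size
intermediate = 1536  # MLP intermediate size
kv_dim = 192         # k_proj / v_proj out_features (matches the k_proj shown in this post)
n_layers = 30        # decoder layers
r = 256              # LoraConfig(r=256)

# (in_features, out_features) for each targeted module in one decoder layer
modules = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, kv_dim),
    "v_proj": (hidden, kv_dim),
    "o_proj": (hidden, hidden),
    "gate_proj": (hidden, intermediate),
    "up_proj": (hidden, intermediate),
    "down_proj": (intermediate, hidden),
}

# lora_A is (r, in) and lora_B is (out, r): r * (in + out) values per module
lora_per_layer = sum(r * (i + o) for i, o in modules.values())
mag_out = sum(o for _, o in modules.values())  # one magnitude per output feature
mag_in = sum(i for i, _ in modules.values())   # hypothetical: one per input feature

print(n_layers * (lora_per_layer + mag_out))  # 78307200
print(n_layers * (lora_per_layer + mag_in))   # 78301440
```

Under those assumptions, the per-output count lands exactly on the reported 78,307,200. A hypothetical per-input magnitude vector would instead give 78,301,440.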

Looking at one of the layers which has a different number of input and output features, k_proj:

k_proj = model.base_model.model.model.layers[0].self_attn.k_proj
k_proj
lora.Linear(
  (base_layer): Linear(in_features=576, out_features=192, bias=False)
  (lora_dropout): ModuleDict(
    (default): Identity()
  )
  (lora_A): ModuleDict(
    (default): Linear(in_features=576, out_features=256, bias=False)
  )
  (lora_B): ModuleDict(
    (default): Linear(in_features=256, out_features=192, bias=False)
  )
  (lora_embedding_A): ParameterDict()
  (lora_embedding_B): ParameterDict()
  (lora_magnitude_vector): ModuleDict(
    (default): lora.dora.DoraLinearLayer()
  )
)

The base layer has 576 in_features and 192 out_features:

k_proj.base_layer.weight.shape
torch.Size([192, 576])

The fan_in_fan_out attribute is False, which checks out given the (out, in) shape above:

k_proj.fan_in_fan_out
False

Why is fan_in_fan_out such a big deal to me? Look at how get_weight_norm is written:

def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
    # calculate L2 norm of weight matrix, column-wise
    weight = transpose(weight, self.fan_in_fan_out)
    weight = weight + scaling * lora_weight
    weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
    return weight_norm

I’ll walk through each line, starting with the base layer’s weight matrix:

weight = k_proj.base_layer.weight
weight.shape
torch.Size([192, 576])

We then pass the weight and fan_in_fan_out to transpose:

weight = transpose(weight, k_proj.fan_in_fan_out)
weight.shape
torch.Size([192, 576])

It doesn’t transpose it! That’s because transpose returns the weight as-is when fan_in_fan_out is False:

def transpose(weight, fan_in_fan_out):
    if not fan_in_fan_out:
        return weight

    if isinstance(weight, torch.nn.Parameter):
        return torch.nn.Parameter(weight.T)
    return weight.T

This is absolutely critical in determining the shape of the magnitude vector:

weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
weight_norm.shape
torch.Size([192])

Note two things here:

- they take the norm across dim=1 (the in_features dimension)
- we end up with 1 magnitude value for each of the 192 output neurons

Contrast this with Raschka’s article, where we had 1 magnitude value for each of the input neurons.
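This small sketch shows what that contrast means at forward time. It uses random float64 tensors with k_proj’s shape, and the magnitudes are arbitrary random values rather than trained ones. A magnitude per output neuron is equivalent to rescaling each output activation. A magnitude per input feature is equivalent to rescaling each input feature before the layer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
out_f, in_f = 192, 576  # k_proj's shape
W = torch.randn(out_f, in_f, dtype=torch.float64)  # nn.Linear layout: (out, in)
x = torch.randn(3, in_f, dtype=torch.float64)

# One magnitude per output neuron (norm over dim=1, as in peft's get_weight_norm)
m_out = torch.rand(out_f, dtype=torch.float64)
V_out = W / W.norm(p=2, dim=1, keepdim=True)  # each row has unit norm
y_out = F.linear(x, m_out.view(-1, 1) * V_out)
# ...is the same as rescaling each output activation
print(torch.allclose(y_out, F.linear(x, V_out) * m_out))  # True

# One magnitude per input feature (norm over dim=0, as in Raschka's self.m)
m_in = torch.rand(in_f, dtype=torch.float64)
V_in = W / W.norm(p=2, dim=0, keepdim=True)  # each column has unit norm
y_in = F.linear(x, V_in * m_in.view(1, -1))
# ...is the same as rescaling each input feature before the layer
print(torch.allclose(y_in, F.linear(x * m_in, V_in)))  # True
```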

Aside: Claude Conversation

While I was re-reading the DoRA paper, I found the emphasis on the column-wise vector norm interesting, so I asked Claude about it:

Prompt: Why columns? What is special about doing this for columns and not rows. Is there something fundamental about columns in a matrix of pretrained weights?

Response: For neural networks, columns of weight matrices map features from one layer’s dimension to another’s. Each column represents all the connections to a single output neuron. So columns form natural functional units, making column-wise normalization more meaningful than row-wise for understanding and modifying the network’s learned transformations.

Prompt: So suppose we have an input vector x which contains values corresponding to each input variable. The pretrained weight matrix is d x k, so x would be 1 x d. Each row in the pretrained weight matrix would represent one input variable. Multiplying 1 x d with d x k gives us 1 x k, where k is the number of activations (i.e. the number of neurons).

Response: Yes, that’s correct. The output for a single neuron (one column of the weight matrix) is the dot product between the input vector and that column’s weights. So when we do column-wise normalization, we’re normalizing the set of weights that feed into each individual output neuron. This helps explain why column-wise normalization is natural - each column represents a complete “unit” in terms of how the network processes information to produce one output activation.

So it seems that decomposing the weight matrix into a magnitude vector and a direction matrix corresponding to the output neurons is really important.
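This sketch connects the Claude conversation back to the code. In the d x k convention from the prompt (x is 1 x d, y = x W), nn.Linear’s stored weight is the transpose of W. Using a randomly initialized 576-to-192 layer, the sketch checks two things. First, the column-wise norm of the d x k matrix gives one value per output neuron. Second, those values match a dim=1 norm on nn.Linear’s (out, in) weight, which is the reduction peft’s get_weight_norm uses when fan_in_fan_out is False.

```python
import torch
from torch import nn

torch.manual_seed(0)
linear = nn.Linear(in_features=576, out_features=192, bias=False)

# d x k convention from the prompt: d = 576 inputs, k = 192 outputs, y = x W
W_dk = linear.weight.detach().T  # (576, 192)
x = torch.randn(1, 576)          # 1 x d input
print(torch.allclose(x @ W_dk, linear(x).detach(), atol=1e-5))  # True

# Column-wise norm of the d x k matrix: one value per output neuron
col_norm = W_dk.norm(p=2, dim=0)                    # shape (192,)
# The same values from nn.Linear's (out, in) layout need dim=1
row_norm = linear.weight.detach().norm(p=2, dim=1)  # shape (192,)
print(col_norm.shape, torch.allclose(col_norm, row_norm))  # torch.Size([192]) True
```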

Final Thoughts

Explorations like this are why I’m grateful for open source code. The efforts of folks like HuggingFace and Raschka to democratize LLMs allow folks like me to curate insightful explorations for myself, and I’m happy to share them in this blog post. If you find something incorrect in my interpretation of Raschka’s or peft’s code, please at me on Twitter @vishal_learner.

A future exercise that I want to do: train DoRA implementations with a column-wise norm and a row-wise norm and compare results. This would truly highlight whether it’s important to align the decomposition with the output neurons of the weight matrix.
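As a starting point for that exercise, here is a hypothetical variant of Raschka’s LinearWithDoRAMerged. The class name and the magnitude_per flag are my own and come from neither codebase. The flag switches the magnitude vector between one value per output neuron (a dim=1 norm on the (out, in) weight) and one value per input feature (dim=0, as in the original class). The class also freezes the wrapped linear layer so that only A, B and m would train. It is a sketch, not a tuned implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn


class LinearWithDoRASwitchable(nn.Module):
    """Hypothetical DoRA layer: magnitude per output neuron ("out", norm over
    dim=1 of the (out, in) weight) or per input feature ("in", norm over dim=0)."""

    def __init__(self, linear, rank, alpha, magnitude_per="out"):
        super().__init__()
        if magnitude_per not in ("out", "in"):
            raise ValueError("magnitude_per must be 'out' or 'in'")
        self.linear = linear
        for p in self.linear.parameters():
            p.requires_grad_(False)  # freeze the pretrained layer; only A, B, m train
        self.norm_dim = 1 if magnitude_per == "out" else 0
        # LoRA factors in Raschka's layout: A is (in, rank), B is (rank, out)
        std_dev = 1 / rank ** 0.5
        self.A = nn.Parameter(torch.randn(linear.in_features, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, linear.out_features))
        self.alpha = alpha  # applied directly, as in Raschka's LoRALayer
        # (out, 1) when magnitude_per="out", (1, in) when magnitude_per="in"
        self.m = nn.Parameter(
            linear.weight.detach().norm(p=2, dim=self.norm_dim, keepdim=True)
        )

    def forward(self, x):
        numerator = self.linear.weight + self.alpha * (self.A @ self.B).T
        denominator = numerator.norm(p=2, dim=self.norm_dim, keepdim=True)
        new_weight = self.m * (numerator / denominator)
        return F.linear(x, new_weight, self.linear.bias)


torch.manual_seed(0)
x = torch.randn(2, 784)
for mode in ("out", "in"):
    layer = LinearWithDoRASwitchable(nn.Linear(784, 128), rank=8, alpha=16, magnitude_per=mode)
    # B starts at zero, so both variants initially reproduce the base layer
    same = torch.allclose(layer(x), layer.linear(x), atol=1e-6)
    print(mode, tuple(layer.m.shape), same)
# prints "out (128, 1) True", then "in (1, 784) True"
```

Training both variants on the same task, with everything else held fixed, would isolate the effect of aligning the magnitude vector with output neurons versus input features.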