from transformers import AutoModelForCausalLM
from peft import PeftModel
import os
import torch
import psutil
import copy
import gc
TIL: PeftModel Base Model Behavior
Background
In this TIL blog post I share some unexpected behavior when using PeftModel. In short, when merging LoRA adapter weights with the base model, the base model gets overwritten in place. While unexpected, in hindsight this makes sense if you want to minimize memory usage.
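In a nutshell, here is the behavior (a minimal sketch of what the rest of this post demonstrates, using the same base model and adapter):
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
peft_model = PeftModel.from_pretrained(
    model=base,
    model_id="LoRA-TMLR-2024/magicoder-lora-rank-64-alpha-128"
)
merged = peft_model.merge_and_unload()
# merge_and_unload() folds the LoRA deltas into the base weights in place
# and returns the same underlying object, so the original base weights are gone.
assert merged is base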
from google.colab import userdata
os.environ['HUGGING_FACE_HUB_TOKEN'] = userdata.get('HUGGING_FACE_HUB_TOKEN')
def _mem():
    vm = psutil.virtual_memory()
    print(f"RAM Usage: {vm.percent}% (Used: {vm.used / (1024**3):.2f} GB / Total: {vm.total / (1024**3):.2f} GB)")
Merging LoRA Adapter Weights
Before loading any model, here is the memory usage. I’m using an A100 GPU with Colab Pro.
_mem()
RAM Usage: 3.5% (Used: 2.10 GB / Total: 83.48 GB)
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to("cpu")
After loading the base model (Llama 2 7B), the memory usage increases to about 27 GB.
_mem()
RAM Usage: 33.8% (Used: 27.35 GB / Total: 83.48 GB)
Loading the LoRA adapter weights increases the memory usage to 28 GB.
model_to_merge = PeftModel.from_pretrained(
    model=base_model,
    model_id="LoRA-TMLR-2024/magicoder-lora-rank-64-alpha-128"
).to("cpu")
_mem()
RAM Usage: 34.8% (Used: 28.22 GB / Total: 83.48 GB)
merged_model = model_to_merge.merge_and_unload()
Merging the model keeps the memory usage essentially constant at 28 GB.
_mem()
RAM Usage: 34.9% (Used: 28.28 GB / Total: 83.48 GB)
Comparing base_model and merged_model Weights
However, saving memory comes at a cost! You no longer have access to the base model. I’ll first do a visual inspection of one of the weight matrices.
base_model.model.layers[0].self_attn.q_proj.weight
Parameter containing:
tensor([[-0.0020, -0.0156, 0.0023, ..., 0.0098, -0.0017, -0.0031],
[ 0.0283, -0.0176, 0.0062, ..., -0.0076, 0.0004, 0.0087],
[-0.0230, 0.0225, 0.0001, ..., 0.0028, 0.0190, -0.0063],
...,
[ 0.0003, 0.0016, -0.0013, ..., 0.0081, -0.0308, 0.0110],
[ 0.0259, 0.0203, 0.0045, ..., -0.0310, -0.0147, -0.0111],
[-0.0077, -0.0174, 0.0012, ..., 0.0182, 0.0181, -0.0070]])
merged_model.model.layers[0].self_attn.q_proj.weight
Parameter containing:
tensor([[-0.0020, -0.0156, 0.0023, ..., 0.0098, -0.0017, -0.0031],
[ 0.0283, -0.0176, 0.0062, ..., -0.0076, 0.0004, 0.0087],
[-0.0230, 0.0225, 0.0001, ..., 0.0028, 0.0190, -0.0063],
...,
[ 0.0003, 0.0016, -0.0013, ..., 0.0081, -0.0308, 0.0110],
[ 0.0259, 0.0203, 0.0045, ..., -0.0310, -0.0147, -0.0111],
[-0.0077, -0.0174, 0.0012, ..., 0.0182, 0.0181, -0.0070]])
Both matrices are equal. Analyzing weight matrix differences more systematically:
def _diffs(model1, model2):
    # Compare the self-attention projection matrices across all 32 layers.
    n_diff = 0
    for layer_idx in range(32):
        for component in ["q_proj", "k_proj", "o_proj", "v_proj"]:
            W1 = getattr(model1.model.layers[layer_idx].self_attn, component).weight
            W2 = getattr(model2.model.layers[layer_idx].self_attn, component).weight
            if not torch.allclose(W1, W2, rtol=1e-5, atol=1e-8): n_diff += 1
    print(f"Different Self-Attention Matrices: {n_diff}")

    # Compare the MLP projection matrices across all 32 layers.
    n_diff = 0
    for layer_idx in range(32):
        for component in ["up_proj", "down_proj", "gate_proj"]:
            W1 = getattr(model1.model.layers[layer_idx].mlp, component).weight
            W2 = getattr(model2.model.layers[layer_idx].mlp, component).weight
            if not torch.allclose(W1, W2, rtol=1e-5, atol=1e-8): n_diff += 1
    print(f"Different MLP Weight Matrices: {n_diff}")
_diffs(base_model, merged_model)
Different Self-Attention Matrices: 0
Different MLP Weight Matrices: 0
For both self-attention and MLP modules, all weight matrices between the base_model and the merged_model are the same. Using the is operator we can see that they reference the same object in memory (which is where the memory savings come from):
base_model is merged_model
True
Reloading the Base Model for Comparison
I’ll now load the base model again to compare with the merged model weights.
_mem()
RAM Usage: 35.4% (Used: 28.68 GB / Total: 83.48 GB)
del base_model
gc.collect()
483
_mem()
RAM Usage: 35.5% (Used: 28.78 GB / Total: 83.48 GB)
Note that deleting the base_model name did not change the memory usage, since merged_model still references the same underlying object.
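This is just standard Python reference semantics. As a standalone illustration (a minimal sketch unrelated to the models above), del removes a name, but the memory is only released once no references to the object remain:
import gc

big = bytearray(10**8)   # allocate a ~100 MB buffer
alias = big              # bind a second name to the same object
del big                  # removes the name 'big', not the buffer
gc.collect()             # nothing to collect: 'alias' still references it
print(len(alias))        # 100000000 -- the buffer is still alive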
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to("cpu")
_mem()
RAM Usage: 65.7% (Used: 53.94 GB / Total: 83.48 GB)
With a new base model loaded, the memory usage jumps up to 54 GB.
_diffs(base_model, merged_model)
Different Self-Attention Matrices: 128
Different MLP Weight Matrices: 96
There are 32 layers in this Llama model, and each layer's self-attention module has 4 weight matrices we are comparing, resulting in 128 matrices in total. Each layer's MLP module has 3 weight matrices we are comparing, resulting in 96 in total across the model. Every single compared matrix differs, so the freshly loaded base model and the merged model are fully different in terms of weight values.
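Those counts can be sanity-checked against the model config (a quick sketch; num_hidden_layers is the standard transformers config field for the layer count):
n_layers = base_model.config.num_hidden_layers   # 32 for Llama-2-7B
print(n_layers * 4)   # 128 self-attention matrices (q/k/v/o projections)
print(n_layers * 3)   # 96 MLP matrices (up/down/gate projections)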
Using .get_base_model
Looking at the PeftModel documentation, I noted the method get_base_model, which seems relevant to this exercise. However, using that method results in the same weights as the merged model:
model_to_merge.get_base_model
peft.peft_model.PeftModel.get_base_model
def get_base_model() -> torch.nn.Module
Returns the base model.
_diffs(merged_model, model_to_merge.get_base_model())
Different Self-Attention Matrices: 0
Different MLP Weight Matrices: 0
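The takeaway: if you need the original base weights after merging, keep your own copy. One option (a hedged sketch, not something I ran above, and it roughly doubles the RAM footprint that PEFT's in-place behavior is designed to avoid) is to deep-copy the base model before wrapping it in a PeftModel:
# Hypothetical workaround sketch: preserve an untouched copy of the base weights.
base_model_copy = copy.deepcopy(base_model)
model_to_merge = PeftModel.from_pretrained(
    model=base_model,
    model_id="LoRA-TMLR-2024/magicoder-lora-rank-64-alpha-128"
)
merged_model = model_to_merge.merge_and_unload()
# merged_model and base_model are the same (merged) object;
# base_model_copy still holds the original, un-merged weights.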
I am planning to do more of these short TIL blog posts this year! Writing them helps me solidify concepts as I come across them. I hope you enjoyed this blog post!