from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import matplotlib.pyplot as plt
Logit Divergence Between Models Differently Converted to torch.bfloat16

Comparing two models converted to torch.bfloat16 differently (one model with torch_dtype specified in from_pretrained, and the other with .to(torch.bfloat16) applied after the model is loaded).
Background
In this blog post I'll illustrate a recent head-scratcher I came across: how you convert a model to torch.bfloat16 changes its intermediate and final outputs. I don't know why this happens, and I'm not sure of a path to figuring that out.
In model1 I specify torch_dtype in AutoModelForCausalLM.from_pretrained. In model2, I don't, and instead use .to(torch.bfloat16) after the model is loaded.
= "HuggingFaceTB/SmolLM2-135M"
checkpoint = "cuda"
device = AutoTokenizer.from_pretrained(checkpoint)
tokenizer = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to(device)
model1 = AutoModelForCausalLM.from_pretrained(checkpoint).to(device).to(torch.bfloat16) model2
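As a quick sanity check (this check is my addition, not part of the original notebook), both conversion paths should leave the parameters in bfloat16:

# My addition: confirm both models end up with bfloat16 parameters.
print(next(model1.parameters()).dtype, next(model2.parameters()).dtype)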
Comparing Logits
Given a set of input tokens, the output logits of the two models are not identical.
inputs = tokenizer.encode("Gravity is", return_tensors="pt").to(device)
inputs
tensor([[22007, 6463, 314]], device='cuda:0')
model1.eval()
logits1 = model1(inputs).logits
logits1
tensor([[[18.0000, 14.5625, 14.6875, ..., 16.2500, 16.2500, 22.1250],
[15.6875, -0.4180, -0.3477, ..., 8.2500, 12.1250, 7.3438],
[12.1875, -2.2812, -2.2031, ..., 7.3750, 10.6875, 8.1875]]],
device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)
model2.eval()
logits2 = model2(inputs).logits
logits2
tensor([[[18.0000, 14.5625, 14.6875, ..., 16.2500, 16.2500, 22.1250],
[15.7500, -0.2715, -0.2002, ..., 8.4375, 12.2500, 7.5000],
[12.3125, -2.2188, -2.1406, ..., 7.5000, 10.6875, 8.3125]]],
device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)
torch.allclose(logits1, logits2)
False
(logits1 == logits2).float().mean()
tensor(0.3457, device='cuda:0')
torch.abs(logits1 - logits2).mean()
tensor(0.0762, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
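Out of curiosity (this check is mine, not part of the original notebook), we can also ask whether the divergence changes the greedy next-token prediction for this prompt:

# My addition: compare the argmax of the last-position logits from each model.
next1 = logits1[0, -1].argmax().item()
next2 = logits2[0, -1].argmax().item()
print(next1, next2, next1 == next2, repr(tokenizer.decode([next1])), repr(tokenizer.decode([next2])))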
Comparing Weights
A helper function to inspect a particular submodule in a particular layer.
def _print(model1, model2, module, submodule, layer_idx):
    w1 = getattr(getattr(model1.model.layers[layer_idx], module), submodule).weight
    w2 = getattr(getattr(model2.model.layers[layer_idx], module), submodule).weight
    print(f"{module}.{submodule} torch.allclose: {torch.allclose(w1, w2)}")

_print(model1, model2, "self_attn", "q_proj", 0)
self_attn.q_proj torch.allclose: True
Looping through all weight matrices in the two state dicts shows they are all identical. Why, then, are the output logits not identical? I would assume that something in the matrix ops is causing the divergence.
n = 0
d = 0
for k in model1.state_dict().keys():
    w1 = model1.state_dict()[k]
    w2 = model2.state_dict()[k]
    if torch.allclose(w1, w2): n += 1
    d += 1
n, d, n/d
(273, 273, 1.0)
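One caveat (my aside, not from the original analysis): state_dict() covers parameters and persistent buffers only, so non-persistent buffers (e.g., rotary embedding inv_freq in Llama-style models) are not compared by the loop above. A sketch that also compares every named buffer:

# Sketch (my addition): compare all named buffers, which state_dict() may omit.
for (k1, b1), (k2, b2) in zip(model1.named_buffers(), model2.named_buffers()):
    if not torch.allclose(b1.float(), b2.float()):
        print(k1, (b1.float() - b2.float()).abs().max().item())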
Forward Hooks
Hooking the two models to track intermediate layer outputs.
model1.eval()
model2.eval()
outputs_dict = {}

def capture_output(name):
    def hook_fn(module, input, output):
        outputs_dict[name] = output[0].detach()
    return hook_fn
hooks = []
for i in range(30):
    hooks.append(model1.model.layers[i].register_forward_hook(capture_output(f"model1_{i}")))
    hooks.append(model2.model.layers[i].register_forward_hook(capture_output(f"model2_{i}")))

with torch.no_grad():
    model1(inputs)
    model2(inputs)

for h in hooks: h.remove()
The difference in intermediate outputs grows as you pass through the model. That smells of typical floating point precision error.
= "mean"
metric for i in range(30):
= outputs_dict[f"model1_{i}"]
o1 = outputs_dict[f"model2_{i}"]
o2
if not torch.allclose(o1, o2):
= (o1-o2).abs().max().item()
max_diff = (o1-o2).abs().mean().item()
mean_diff if metric == "max": print(f"Layer {i}: max diff = {max_diff}")
if metric == "mean": print(f"Layer {i}: mean diff = {mean_diff}")
Layer 0: mean diff = 0.0017547607421875
Layer 1: mean diff = 0.005035400390625
Layer 2: mean diff = 0.00830078125
Layer 3: mean diff = 0.010986328125
Layer 4: mean diff = 0.011962890625
Layer 5: mean diff = 0.01251220703125
Layer 6: mean diff = 0.01312255859375
Layer 7: mean diff = 0.0137939453125
Layer 8: mean diff = 0.015380859375
Layer 9: mean diff = 0.0172119140625
Layer 10: mean diff = 0.0189208984375
Layer 11: mean diff = 0.0185546875
Layer 12: mean diff = 0.01953125
Layer 13: mean diff = 0.020751953125
Layer 14: mean diff = 0.021728515625
Layer 15: mean diff = 0.0234375
Layer 16: mean diff = 0.026123046875
Layer 17: mean diff = 0.0263671875
Layer 18: mean diff = 0.0269775390625
Layer 19: mean diff = 0.0301513671875
Layer 20: mean diff = 0.03271484375
Layer 21: mean diff = 0.036376953125
Layer 22: mean diff = 0.044921875
Layer 23: mean diff = 0.05322265625
Layer 24: mean diff = 0.05810546875
Layer 25: mean diff = 0.06689453125
Layer 26: mean diff = 0.0771484375
Layer 27: mean diff = 0.091796875
Layer 28: mean diff = 0.1005859375
Layer 29: mean diff = 0.1484375
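Since matplotlib is already imported at the top, here is an optional plot (my addition) of the per-layer mean absolute difference, which makes the drift easier to see:

# My addition: visualize the layer-wise mean absolute difference.
mean_diffs = [
    (outputs_dict[f"model1_{i}"] - outputs_dict[f"model2_{i}"]).abs().mean().item()
    for i in range(30)
]
plt.plot(mean_diffs)
plt.xlabel("layer")
plt.ylabel("mean |output1 - output2|")
plt.show()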
The max difference in outputs reaches 6.0
by the 30th layer!
= "max"
metric for i in range(30):
= outputs_dict[f"model1_{i}"]
o1 = outputs_dict[f"model2_{i}"]
o2
if not torch.allclose(o1, o2):
= (o1-o2).abs().max().item()
max_diff = (o1-o2).abs().mean().item()
mean_diff if metric == "max": print(f"Layer {i}: max diff = {max_diff}")
if metric == "mean": print(f"Layer {i}: mean diff = {mean_diff}")
Layer 0: max diff = 0.0625
Layer 1: max diff = 0.25
Layer 2: max diff = 0.25
Layer 3: max diff = 0.25
Layer 4: max diff = 0.25
Layer 5: max diff = 0.25
Layer 6: max diff = 0.25
Layer 7: max diff = 0.5
Layer 8: max diff = 0.5
Layer 9: max diff = 0.5
Layer 10: max diff = 1.0
Layer 11: max diff = 1.0
Layer 12: max diff = 1.0
Layer 13: max diff = 0.5
Layer 14: max diff = 0.5
Layer 15: max diff = 0.5
Layer 16: max diff = 0.5
Layer 17: max diff = 0.5
Layer 18: max diff = 0.5
Layer 19: max diff = 0.75
Layer 20: max diff = 0.5
Layer 21: max diff = 0.5
Layer 22: max diff = 0.5
Layer 23: max diff = 0.75
Layer 24: max diff = 1.0
Layer 25: max diff = 2.0
Layer 26: max diff = 2.0
Layer 27: max diff = 2.0
Layer 28: max diff = 1.0
Layer 29: max diff = 6.0
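None of this identifies which operation the two models compute differently, but as a toy illustration (my addition, not tied to these models) of how sensitive bfloat16 matrix ops are to the numeric path taken, the same matmul computed natively in bfloat16 versus computed in float32 and then cast does not agree exactly:

# Toy illustration (my addition): identical inputs, different numeric path,
# different bfloat16 results.
torch.manual_seed(0)
a = torch.randn(256, 256)
b = torch.randn(256, 256)
native_bf16 = a.to(torch.bfloat16) @ b.to(torch.bfloat16)
fp32_then_cast = (a @ b).to(torch.bfloat16)
print((native_bf16.float() - fp32_then_cast.float()).abs().mean().item())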
Reloading the models and inspecting the outputs of intermediate modules like self_attn
and mlp
.
model1 = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to(device)
model2 = AutoModelForCausalLM.from_pretrained(checkpoint).to(device).to(torch.bfloat16)
modules = {
    "self_attn": ["q_proj", "k_proj", "v_proj", "o_proj"],
    "mlp": ["gate_proj", "up_proj", "down_proj"],
    "input_layernorm": [],
    "post_attention_layernorm": []
}
outputs_dict = {}
for module in modules.keys():
    if module == "input_layernorm" or module == "post_attention_layernorm":
        hooks.append(getattr(model1.model.layers[0], module).register_forward_hook(capture_output(f"model1_{module}")))
        hooks.append(getattr(model2.model.layers[0], module).register_forward_hook(capture_output(f"model2_{module}")))
    else:
        for submodule in modules[module]:
            hooks.append(getattr(getattr(model1.model.layers[0], module), submodule).register_forward_hook(capture_output(f"model1_{module}_{submodule}")))
            hooks.append(getattr(getattr(model2.model.layers[0], module), submodule).register_forward_hook(capture_output(f"model2_{module}_{submodule}")))
with torch.no_grad():
    model1(inputs)
    model2(inputs)

for h in hooks: h.remove()
Interestingly, the intermediate attention projections (q_proj, k_proj, v_proj) produce identical outputs, but divergence appears once the attention output passes through o_proj.
for module in modules.keys():
    if module == "input_layernorm" or module == "post_attention_layernorm":
        o1 = outputs_dict[f"model1_{module}"]
        o2 = outputs_dict[f"model2_{module}"]
        diff = (o1 - o2).abs().mean().item()
        print(f"{module}: {diff}")
    else:
        for submodule in modules[module]:
            o1 = outputs_dict[f"model1_{module}_{submodule}"]
            o2 = outputs_dict[f"model2_{module}_{submodule}"]
            diff = (o1 - o2).abs().mean().item()
            print(f"{module}.{submodule}: {diff}")
self_attn.q_proj: 0.0
self_attn.k_proj: 0.0
self_attn.v_proj: 0.0
self_attn.o_proj: 2.1457672119140625e-05
mlp.gate_proj: 0.000476837158203125
mlp.up_proj: 0.0003833770751953125
mlp.down_proj: 0.00177764892578125
input_layernorm: 0.0
post_attention_layernorm: 1.8715858459472656e-05
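A follow-up I'd try next (just a sketch, my addition): capture the input to o_proj as well, to check whether the divergence already exists before the o_proj matmul (i.e., somewhere in the attention score/softmax computation) or is introduced by the matmul itself.

# Sketch (my addition): hook the *input* of o_proj in layer 0 for both models.
inputs_dict = {}

def capture_input(name):
    def hook_fn(module, input, output):
        inputs_dict[name] = input[0].detach()
    return hook_fn

h1 = model1.model.layers[0].self_attn.o_proj.register_forward_hook(capture_input("model1_o_proj_in"))
h2 = model2.model.layers[0].self_attn.o_proj.register_forward_hook(capture_input("model2_o_proj_in"))

with torch.no_grad():
    model1(inputs)
    model2(inputs)

h1.remove()
h2.remove()
print((inputs_dict["model1_o_proj_in"] - inputs_dict["model2_o_proj_in"]).abs().mean().item())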
Again, I haven't dug into why these differences exist, but I wanted to document that they do.