Logging Data Types for Activations, Gradients, Weights, Optimizer States and Loss during Training with LLM-Foundry

import re
import pandas as pd
import json
import requests
Background
In a previous blog post I shared my first couple of iterations of a custom Composer callback used to log the data types of different entities (activations, gradients, weights, optimizer states, and loss) during training with LLM-Foundry. In this blog post I'll share my final callback iteration's code and some lessons I learned along the way (e.g., LLaMA's self-attention forward pass isn't called with any positional arguments!), and analyze the logging results to observe entity data types throughout the training loop.
Composer Callback Walkthrough
The data types of entities (activations, gradients, weights, loss, and optimizer states) are logged during training with a custom Composer callback, DtypeLogger, passed to the Composer Trainer. The callback was built up and tested event by event using Claude. There is one event handler in the callback for each Composer event from <FIT_START> to <EPOCH_END>:
# <INIT>
# <BEFORE_LOAD>
# <AFTER_LOAD>
# <FIT_START>
for epoch in range(NUM_EPOCHS):
    # <EPOCH_START>
    while True:
        # <BEFORE_DATALOADER>
        batch = next(dataloader)
        if batch is None:
            break
        inputs, targets = batch
        # <AFTER_DATALOADER>

        # <BATCH_START>

        # <BEFORE_FORWARD>
        outputs = model.forward(inputs)
        # <AFTER_FORWARD>

        # <BEFORE_LOSS>
        loss = model.loss(outputs, targets)
        # <AFTER_LOSS>

        # <BEFORE_BACKWARD>
        loss.backward()
        # <AFTER_BACKWARD>

        optimizer.step()
        optimizer.zero_grad()
        # <BATCH_END>
    # <EPOCH_END>
There are four explicit logging functions:

- _log_model_weight_dtypes
- _log_gradient_dtypes
- _log_optimizer_state_dtypes
- _log_loss_dtype
Additionally, activations are logged using register_forward_hook for all modules except self-attention (more on that below). Self-attention inputs are logged using a monkey-patched forward pass.
import json
import os
from pathlib import Path

import torch
from composer import Callback, Logger, State


class DtypeLogger(Callback):
    def __init__(self, save_path="/model-checkpoints/dtype_tracking", log_interval=10):
        self.save_path = Path(save_path)
        self.dtype_logs = {'log': {}}
        self.log_interval = log_interval
        self.hooks = []

    def fit_start(self, state: State, logger: Logger) -> None:
        self._log_model_weight_dtypes(state, "fit_start")
        self._save_logs()

    def epoch_start(self, state: State, logger: Logger) -> None:
        self._log_model_weight_dtypes(state, "epoch_start")
        self._save_logs()

    def before_dataloader(self, state: State, logger: Logger) -> None:
        if state.timestamp.batch.value % self.log_interval == 0:
            self._log_model_weight_dtypes(state, "before_dataloader")
            self._save_logs()

    def after_dataloader(self, state: State, logger: Logger) -> None:
        if state.timestamp.batch.value % self.log_interval == 0:
            self._log_model_weight_dtypes(state, "after_dataloader")
            self._save_logs()

    def batch_start(self, state: State, logger: Logger) -> None:
        if state.timestamp.batch.value % self.log_interval == 0:
            self._log_model_weight_dtypes(state, "batch_start")
            self._save_logs()

    def before_forward(self, state: State, logger: Logger) -> None:
        if state.timestamp.batch.value % self.log_interval == 0:
            self._log_model_weight_dtypes(state, "before_forward")

            # Clear old hooks
            for hook in self.hooks:
                hook.remove()
            self.hooks = []

            # Get the model
            model = state.model.model.base_model.model
            transformer_model = model.model  # This is the transformer part
            batch_id = state.timestamp.batch.value

            # Store original forward methods to restore later
            self.original_forward_methods = {}

            def hook_fn(layer_name, module_name):
                def _hook(module, inputs, outputs):
                    # Log input activation dtype
                    if isinstance(inputs, tuple) and len(inputs) > 0:
                        self.dtype_logs["log"][f"forward:{module_name}:{layer_name}:activation_input"] = str(inputs[0].dtype)
                    # Log output activation dtype
                    if isinstance(outputs, torch.Tensor):
                        self.dtype_logs["log"][f"forward:{module_name}:{layer_name}:activation_output"] = str(outputs.dtype)
                    elif isinstance(outputs, tuple) and len(outputs) > 0:
                        self.dtype_logs["log"][f"forward:{module_name}:{layer_name}:activation_output"] = str(outputs[0].dtype)
                return _hook

            # Monkey patch self-attention modules
            for layer_idx, layer in enumerate(transformer_model.layers):
                # Store the original forward method
                original_forward = layer.self_attn.forward
                self.original_forward_methods[layer_idx] = original_forward

                # Define a closure to capture the current layer_idx
                def make_patched_forward(layer_idx, orig_forward):
                    def patched_forward(self_attn, *args, **kwargs):
                        # Log the hidden_states dtype
                        if 'hidden_states' in kwargs and hasattr(kwargs['hidden_states'], 'dtype'):
                            self.dtype_logs["log"][f"forward:self_attn:layer_{layer_idx}:activation_input"] = str(kwargs['hidden_states'].dtype)
                        # Call the original method as a bound method
                        # This ensures 'self_attn' is correctly passed as 'self'
                        return orig_forward.__get__(self_attn, type(self_attn))(**kwargs)
                    return patched_forward

                # Replace the forward method
                layer.self_attn.forward = make_patched_forward(layer_idx, original_forward).__get__(layer.self_attn, type(layer.self_attn))

            # Register hook for lm_head
            if hasattr(model, 'lm_head'):
                self.hooks.append(model.lm_head.register_forward_hook(hook_fn("output", "lm_head")))

            # Register hook for embedding layer
            self.hooks.append(transformer_model.embed_tokens.register_forward_hook(hook_fn("embeddings", "embed_tokens")))

            # Register hooks for each transformer layer
            for layer_idx, layer in enumerate(transformer_model.layers):
                # Self-attention components - we still register hooks for outputs
                self.hooks.append(layer.self_attn.register_forward_hook(hook_fn(f"layer_{layer_idx}", "self_attn")))
                self.hooks.append(layer.self_attn.q_proj.register_forward_hook(hook_fn(f"layer_{layer_idx}", "q_proj")))
                self.hooks.append(layer.self_attn.k_proj.register_forward_hook(hook_fn(f"layer_{layer_idx}", "k_proj")))
                self.hooks.append(layer.self_attn.v_proj.register_forward_hook(hook_fn(f"layer_{layer_idx}", "v_proj")))
                self.hooks.append(layer.self_attn.o_proj.register_forward_hook(hook_fn(f"layer_{layer_idx}", "o_proj")))

                # MLP components
                self.hooks.append(layer.mlp.register_forward_hook(hook_fn(f"layer_{layer_idx}", "mlp")))
                self.hooks.append(layer.mlp.gate_proj.register_forward_hook(hook_fn(f"layer_{layer_idx}", "gate_proj")))
                self.hooks.append(layer.mlp.up_proj.register_forward_hook(hook_fn(f"layer_{layer_idx}", "up_proj")))
                self.hooks.append(layer.mlp.down_proj.register_forward_hook(hook_fn(f"layer_{layer_idx}", "down_proj")))

                # Layer norms
                self.hooks.append(layer.input_layernorm.register_forward_hook(hook_fn(f"layer_{layer_idx}", "input_layernorm")))
                self.hooks.append(layer.post_attention_layernorm.register_forward_hook(hook_fn(f"layer_{layer_idx}", "post_attention_layernorm")))

            # Final layer norm
            self.hooks.append(transformer_model.norm.register_forward_hook(hook_fn("final", "norm")))

            self._save_logs()

    def after_forward(self, state: State, logger: Logger) -> None:
        if state.timestamp.batch.value % self.log_interval == 0:
            self._log_model_weight_dtypes(state, "after_forward")

            # Restore original forward methods
            if hasattr(self, 'original_forward_methods'):
                model = state.model.model.base_model.model
                transformer_model = model.model
                for layer_idx, original_forward in self.original_forward_methods.items():
                    transformer_model.layers[layer_idx].self_attn.forward = original_forward
                self.original_forward_methods = {}

            # Clear hooks
            for hook in self.hooks:
                hook.remove()
            self.hooks = []

            self._save_logs()

    def before_loss(self, state: State, logger: Logger) -> None:
        if state.timestamp.batch.value % self.log_interval == 0:
            self._log_model_weight_dtypes(state, "before_loss")
            self._save_logs()

    def after_loss(self, state: State, logger: Logger) -> None:
        if state.timestamp.batch.value % self.log_interval == 0:
            self._log_model_weight_dtypes(state, "after_loss")
            self._log_loss_dtype(state, "after_loss")
            self._save_logs()

    def before_backward(self, state: State, logger: Logger) -> None:
        if state.timestamp.batch.value % self.log_interval == 0:
            self._log_model_weight_dtypes(state, "before_backward")
            self._save_logs()

    def after_backward(self, state: State, logger: Logger) -> None:
        if state.timestamp.batch.value % self.log_interval == 0:
            # Log gradient dtypes as before
            self._log_gradient_dtypes(state, "after_backward")
            # Track weight dtypes before optimizer step
            self._log_model_weight_dtypes(state, "before_optim_step")
            # Log optimizer state dtypes
            self._log_optimizer_state_dtypes(state, "optimizer_step")
            self._save_logs()

    def batch_end(self, state: State, logger: Logger) -> None:
        if state.timestamp.batch.value % self.log_interval == 0:
            # Track weight dtypes after optimizer step to detect precision changes
            self._log_model_weight_dtypes(state, "after_optim_step")
            self._save_logs()

    def epoch_end(self, state: State, logger: Logger) -> None:
        self._log_model_weight_dtypes(state, "epoch_end")
        self._save_logs()

    def _log_model_weight_dtypes(self, state: State, event_name: str) -> None:
        model = state.model
        for name, param in model.named_parameters():
            name = name.removeprefix("model.base_model.model.model.")
            self.dtype_logs["log"][f"{event_name}:{name}:weights"] = str(param.dtype)

    def _log_gradient_dtypes(self, state: State, event_name: str) -> None:
        model = state.model
        for name, param in model.named_parameters():
            name = name.removeprefix("model.base_model.model.model.")
            if param.grad is not None:
                self.dtype_logs['log'][f"{event_name}:{name}:gradients"] = str(param.grad.dtype)
            else:
                self.dtype_logs['log'][f"{event_name}:{name}:gradients"] = "None"

    def _log_loss_dtype(self, state: State, event_name: str) -> None:
        if hasattr(state, 'loss') and hasattr(state.loss, 'dtype'):
            self.dtype_logs["log"][f"{event_name}:loss"] = str(state.loss.dtype)

    def _log_optimizer_state_dtypes(self, state: State, event_name: str) -> None:
        if hasattr(state, 'optimizers') and state.optimizers is not None:
            # Handle single optimizer or list of optimizers
            optimizers = state.optimizers if isinstance(state.optimizers, list) else [state.optimizers]
            for opt_idx, optimizer in enumerate(optimizers):
                # Get optimizer state dict
                opt_state = optimizer.state_dict()
                # Check if 'state' exists in the optimizer state dict
                if 'state' in opt_state:
                    for param_id, param_state in opt_state['state'].items():
                        for state_name, state_value in param_state.items():
                            if isinstance(state_value, torch.Tensor):
                                # Store dtype of optimizer state tensors (momentum buffers, etc.)
                                key = f"optimizer_{opt_idx}_param_{param_id}_{state_name}"
                                self.dtype_logs["log"][f"{event_name}:{key}:optimizer_states"] = str(state_value.dtype)

    def _save_logs(self) -> None:
        os.makedirs(self.save_path, exist_ok=True)
        log_file = self.save_path / "dtype_logs.json"
        with open(log_file, 'w') as f:
            json.dump(self.dtype_logs, f, indent=2)

The most involved event handler is before_forward, which creates a hook function (hook_fn) that is passed to PyTorch's register_forward_hook, which exposes the positional inputs and outputs of a module's forward pass. The hook function modifies self.dtype_logs directly, storing the data type string of inputs and outputs. hook_fn is used for all modules except self-attention.
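To see why this matters for self-attention, here is a small standalone sketch (a hypothetical toy module, not LLaMA or LLM-Foundry code): by default, register_forward_hook only exposes the positional arguments passed to forward, so a keyword-only call shows up as an empty inputs tuple.

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, hidden_states):
        return hidden_states * 2

def print_inputs(module, inputs, outputs):
    # With the default register_forward_hook, `inputs` contains only the
    # positional arguments passed to forward(); keyword arguments are not included.
    print("hook saw inputs:", inputs)

toy = Toy()
toy.register_forward_hook(print_inputs)

x = torch.ones(2)
toy(x)                # hook saw inputs: (tensor([1., 1.]),)
toy(hidden_states=x)  # hook saw inputs: ()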
Self-attention cannot use register_forward_hook to capture its inputs because LlamaDecoderLayer does not call the self-attention forward pass with any positional arguments:
hidden_states, self_attn_weights = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    cache_position=cache_position,
    position_embeddings=position_embeddings,
    **kwargs,
)

Contrast this with how the forward passes of other modules are called with positional arguments only:
# self attention sublayers
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
attn_output = self.o_proj(attn_output)
# mlp sublayers
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
# non-self attention modules
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.norm(hidden_states)

Since self-attention inputs can't be captured by a hook, I had to monkey patch its forward pass to log its inputs' data type:
for layer_idx, layer in enumerate(transformer_model.layers):
    # Store the original forward method
    original_forward = layer.self_attn.forward
    self.original_forward_methods[layer_idx] = original_forward

    # Define a closure to capture the current layer_idx
    def make_patched_forward(layer_idx, orig_forward):
        def patched_forward(self_attn, *args, **kwargs):
            # Log the hidden_states dtype
            if 'hidden_states' in kwargs and hasattr(kwargs['hidden_states'], 'dtype'):
                self.dtype_logs["log"][f"forward:self_attn:layer_{layer_idx}:activation_input"] = str(kwargs['hidden_states'].dtype)
            # Call the original method as a bound method
            # This ensures 'self_attn' is correctly passed as 'self'
            return orig_forward.__get__(self_attn, type(self_attn))(**kwargs)
        return patched_forward

    # Replace the forward method
    layer.self_attn.forward = make_patched_forward(layer_idx, original_forward).__get__(layer.self_attn, type(layer.self_attn))

patched_forward receives positional arguments *args (of which there are none) and keyword arguments **kwargs (all of the arguments to the self-attention forward pass), and logs the data type of the input to self-attention (hidden_states) under an activation_input key before returning the output of the original forward pass.
A key line is orig_forward.__get__(self_attn, type(self_attn))(**kwargs). As Claude’s comment mentions, this is to avoid using orig_forward(self_attn, **kwargs) which was causing the following error because the first argument, self_attn, was being interpreted as hidden_states whereas it was intended to represent self:
TypeError: LlamaFlashAttention2.forward() got multiple values for argument 'hidden_states'
In short, when you call __get__(obj, type(obj)) on a function, it binds that function as a method to the given object, so you no longer need to pass in self as an argument. This matters here because self_attn.forward is never called with positional arguments. We can then pass the keyword arguments to the bound method, orig_forward.__get__(self_attn, type(self_attn))(**kwargs), and let the model continue using self-attention correctly. See the Descriptor Guide in the Python docs for more information.
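To make the binding behavior concrete, here is a minimal standalone sketch (a made-up Greeter class, unrelated to the callback) showing __get__ turning a plain function into a bound method:

class Greeter:
    def __init__(self, name):
        self.name = name

def greet(self, punctuation="!"):
    return f"Hello, {self.name}{punctuation}"

g = Greeter("Composer")

# Binding with __get__ supplies 'self' automatically, so only the remaining
# (here, keyword) arguments need to be passed.
bound = greet.__get__(g, type(g))
print(bound())                  # Hello, Composer!
print(bound(punctuation="?"))   # Hello, Composer?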
Helper Functions
def parse_index(string):
    """Extract structured information from parameter names"""
    info = {
        'layer_number': None,
        'module': None,
        'layer_name': None,
        'lora_layer': None,
        'training_step': None,
        'entity': None
    }

    # layer = string.split(":")[1]
    # info["layer"] = layer

    layer_number_match = re.search(r'layers\.(\d+)', string)
    if layer_number_match: info['layer_number'] = int(layer_number_match.group(1))

    modules = [
        "embed_tokens",
        "input_layernorm",
        "self_attn",
        "post_attention_layernorm",
        "mlp",
        "norm",
        "lm_head"
    ]

    module_match = re.search(r'(mlp|self_attn|input_layernorm|post_attention_layernorm|embed_tokens|norm|lm_head)', string)
    if module_match: info['module'] = str(modules.index(module_match.group(1))).zfill(2) + '_' + module_match.group(1)

    layer_name_match = re.search(r'(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)', string)
    if layer_name_match: info['layer_name'] = layer_name_match.group(1)

    lora_match = re.search(r'(base_layer|lora_A|lora_B)', string)
    if lora_match: info['lora_layer'] = lora_match.group(1)
    else: info['lora_layer'] = "Not a LoRA Layer"

    training_steps = [
        "fit_start",
        "epoch_start",
        "before_dataloader",
        "after_dataloader",
        "batch_start",
        "before_forward",
        "forward",
        "after_forward",
        "before_loss",
        "after_loss",
        "before_backward",
        "after_backward",
        "before_optim_step",
        "optimizer_step",
        "after_optim_step"
    ]

    training_step = string.split(":")[0]
    info['training_step'] = str(training_steps.index(training_step)).zfill(2) + '_' + training_step

    info['entity'] = string.split(":")[-1]

    return info
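For example, the first logged key shown in the fp32 results below parses as follows (output reproduced from running the function as written):

parse_index("fit_start:embed_tokens.weight:weights")
# {'layer_number': None,
#  'module': '00_embed_tokens',
#  'layer_name': None,
#  'lora_layer': 'Not a LoRA Layer',
#  'training_step': '00_fit_start',
#  'entity': 'weights'}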
def _df(url):
    dtype_data = json.loads(requests.get(url).text)
    df = pd.DataFrame(dtype_data).reset_index()
    df = df.rename(columns={"index": "index", "log": "dtype"})
    parsed_info = df['index'].apply(lambda x: parse_index(x))
    df['layer_number'] = parsed_info.apply(lambda x: x['layer_number'])
    df['module'] = parsed_info.apply(lambda x: x['module'])
    df['layer_name'] = parsed_info.apply(lambda x: x['layer_name'])
    df['lora_layer'] = parsed_info.apply(lambda x: x['lora_layer'])
    df['training_step'] = parsed_info.apply(lambda x: x['training_step'])
    df['entity'] = parsed_info.apply(lambda x: x['entity'])
    return df

Model in fp32 (master_weights_dtype==None)
In this case, master_weights_dtype is not provided in the training YAML file.
url = "https://gist.githubusercontent.com/vishalbakshi/9ade8d501629d4c30e8aecfa1c6f67cf/raw/0c162e2305002fbe57fd2570ade302c3659140a1/dtypes_logs_1ba_fp32.json"
df = _df(url)
df.head()

| | index | dtype | layer_number | module | layer_name | lora_layer | training_step | entity |
|---|---|---|---|---|---|---|---|---|
| 0 | fit_start:embed_tokens.weight:weights | torch.float32 | NaN | 00_embed_tokens | None | Not a LoRA Layer | 00_fit_start | weights |
| 1 | fit_start:layers.0.self_attn.q_proj.base_layer... | torch.float32 | 0.0 | 02_self_attn | q_proj | base_layer | 00_fit_start | weights |
| 2 | fit_start:layers.0.self_attn.q_proj.lora_A.def... | torch.float32 | 0.0 | 02_self_attn | q_proj | lora_A | 00_fit_start | weights |
| 3 | fit_start:layers.0.self_attn.q_proj.lora_B.def... | torch.float32 | 0.0 | 02_self_attn | q_proj | lora_B | 00_fit_start | weights |
| 4 | fit_start:layers.0.self_attn.k_proj.base_layer... | torch.float32 | 0.0 | 02_self_attn | k_proj | base_layer | 00_fit_start | weights |
Data Types by lora_layer
All LoRA layer entities are in fp32.
df.groupby(['lora_layer', 'dtype'])['dtype'].count()

| lora_layer | dtype | count |
|---|---|---|
| Not a LoRA Layer | None | 62 |
| | torch.bfloat16 | 331 |
| | torch.float32 | 2339 |
| | torch.int64 | 1 |
| base_layer | None | 210 |
| | torch.float32 | 2520 |
| lora_A | torch.float32 | 2730 |
| lora_B | torch.float32 | 2730 |
Data Types by entity (Activations, Gradients, Loss, Optimizer States and Weights)
Every entity except activations is in fp32. Some parameters don't have gradients because we are training with LoRA, so only the LoRA parameters receive gradients.
df.groupby(['entity', 'dtype'])['dtype'].count()

| entity | dtype | count |
|---|---|---|
| activation_input | torch.bfloat16 | 60 |
| | torch.float32 | 272 |
| | torch.int64 | 1 |
| activation_output | torch.bfloat16 | 271 |
| | torch.float32 | 62 |
| gradients | None | 272 |
| | torch.float32 | 420 |
| loss | torch.float32 | 1 |
| optimizer_states | torch.float32 | 1260 |
| weights | torch.float32 | 8304 |
Data Types by Composer Training Step
df.groupby(['training_step', 'entity', 'dtype'])['dtype'].count()

| training_step | entity | dtype | count |
|---|---|---|---|
| 00_fit_start | weights | torch.float32 | 692 |
| 01_epoch_start | weights | torch.float32 | 692 |
| 02_before_dataloader | weights | torch.float32 | 692 |
| 03_after_dataloader | weights | torch.float32 | 692 |
| 04_batch_start | weights | torch.float32 | 692 |
| 05_before_forward | weights | torch.float32 | 692 |
| 06_forward | activation_input | torch.bfloat16 | 60 |
| | | torch.float32 | 272 |
| | | torch.int64 | 1 |
| | activation_output | torch.bfloat16 | 271 |
| | | torch.float32 | 62 |
| 07_after_forward | weights | torch.float32 | 692 |
| 08_before_loss | weights | torch.float32 | 692 |
| 09_after_loss | loss | torch.float32 | 1 |
| | weights | torch.float32 | 692 |
| 10_before_backward | weights | torch.float32 | 692 |
| 11_after_backward | gradients | None | 272 |
| | | torch.float32 | 420 |
| 12_before_optim_step | weights | torch.float32 | 692 |
| 13_optimizer_step | optimizer_states | torch.float32 | 1260 |
| 14_after_optim_step | weights | torch.float32 | 692 |
Model in bf16 (master_weights_dtype==bfloat16)
I also logged data types after setting master_weights_dtype in the training YAML to bfloat16.
url = "https://gist.githubusercontent.com/vishalbakshi/ec91a59754633611fd8eb33b59031243/raw/5b83a7ebd5759cf6bd2db2369edf1c73e1fb67cf/dtypes_logs_1ba_bf16.json"
df = _df(url)
df.head()

| | index | dtype | layer_number | module | layer_name | lora_layer | training_step | entity |
|---|---|---|---|---|---|---|---|---|
| 0 | fit_start:embed_tokens.weight:weights | torch.bfloat16 | NaN | 00_embed_tokens | None | Not a LoRA Layer | 00_fit_start | weights |
| 1 | fit_start:layers.0.self_attn.q_proj.base_layer... | torch.bfloat16 | 0.0 | 02_self_attn | q_proj | base_layer | 00_fit_start | weights |
| 2 | fit_start:layers.0.self_attn.q_proj.lora_A.def... | torch.bfloat16 | 0.0 | 02_self_attn | q_proj | lora_A | 00_fit_start | weights |
| 3 | fit_start:layers.0.self_attn.q_proj.lora_B.def... | torch.bfloat16 | 0.0 | 02_self_attn | q_proj | lora_B | 00_fit_start | weights |
| 4 | fit_start:layers.0.self_attn.k_proj.base_layer... | torch.bfloat16 | 0.0 | 02_self_attn | k_proj | base_layer | 00_fit_start | weights |
Data Types by lora_layer
Interestingly, setting master_weights_dtype makes all LoRA layers bfloat16, but some non-LoRA layers' entities are still in fp32.
df.groupby(['lora_layer', 'dtype'])['dtype'].count()

| lora_layer | dtype | count |
|---|---|---|
| Not a LoRA Layer | None | 62 |
| | torch.bfloat16 | 2249 |
| | torch.float32 | 421 |
| | torch.int64 | 1 |
| base_layer | None | 210 |
| | torch.bfloat16 | 2520 |
| lora_A | torch.bfloat16 | 2730 |
| lora_B | torch.bfloat16 | 2730 |
Data Types by entity (Activations, Gradients, Loss, Optimizer States and Weights)
All floating point values are in bfloat16 except for the loss and some of the optimizer states. I’m not sure why some optimizer states are in bf16, even though it says in the Composer docs:
> Store the weights and perform the optimizer step in single precision, enabling the weight update to be done more precisely.
df.groupby(['entity', 'dtype'])['dtype'].count()

| entity | dtype | count |
|---|---|---|
| activation_input | torch.bfloat16 | 332 |
| | torch.int64 | 1 |
| activation_output | torch.bfloat16 | 333 |
| gradients | None | 272 |
| | torch.bfloat16 | 420 |
| loss | torch.float32 | 1 |
| optimizer_states | torch.bfloat16 | 840 |
| | torch.float32 | 420 |
| weights | torch.bfloat16 | 8304 |
Data Types by Composer Training Step
df.groupby(['training_step', 'entity', 'dtype'])['dtype'].count()

| training_step | entity | dtype | count |
|---|---|---|---|
| 00_fit_start | weights | torch.bfloat16 | 692 |
| 01_epoch_start | weights | torch.bfloat16 | 692 |
| 02_before_dataloader | weights | torch.bfloat16 | 692 |
| 03_after_dataloader | weights | torch.bfloat16 | 692 |
| 04_batch_start | weights | torch.bfloat16 | 692 |
| 05_before_forward | weights | torch.bfloat16 | 692 |
| 06_forward | activation_input | torch.bfloat16 | 332 |
| | | torch.int64 | 1 |
| | activation_output | torch.bfloat16 | 333 |
| 07_after_forward | weights | torch.bfloat16 | 692 |
| 08_before_loss | weights | torch.bfloat16 | 692 |
| 09_after_loss | loss | torch.float32 | 1 |
| | weights | torch.bfloat16 | 692 |
| 10_before_backward | weights | torch.bfloat16 | 692 |
| 11_after_backward | gradients | None | 272 |
| | | torch.bfloat16 | 420 |
| 12_before_optim_step | weights | torch.bfloat16 | 692 |
| 13_optimizer_step | optimizer_states | torch.bfloat16 | 840 |
| | | torch.float32 | 420 |
| 14_after_optim_step | weights | torch.bfloat16 | 692 |
Final Thoughts
I absolutely loved this exercise. I learned a ton about callbacks, data types during mixed precision training, and Python fundamentals. Working with LLM-Foundry has opened up a whole universe of learning opportunities as I try to better understand what’s going on under the hood. It’s a gift that keeps giving!
I’m trying to grow my YouTube channel so please give it a visit and subscribe if you want to stay in the loop.