import json
import os
import re
from pathlib import Path

import pandas as pd
import requests
import torch
Logging Data Types for Activations, Gradients, Weights, Optimizer States and Loss during Training with LLM-Foundry
Background
In a previous blog post I shared my first couple of iterations of custom Composer callback used to log data types of different entities (activations, gradients, weights, optimizer states, and loss) during training with LLM-Foundry. In this blog post I’ll share my final callback iteration’s code, some lessons I learned along the way (i.e. LLaMA’s self-attention module doesn’t have positional arguments!) and analyze the logging results to observe entity data types throughout the training loop.
Composer Callback Walkthrough
The data types of entities (activations, gradients, weights, loss, and optimizer states) are logged during training with a custom Composer callback DtypeLogger
passed to the Composer Trainer
. This callback was built up and tested event-by-event using Claude. There is one event handler in the callback for each Composer event from <FIT_START>
to <BATCH_END>
:
# <INIT>
# <BEFORE_LOAD>
# <AFTER_LOAD>
# <FIT_START>
for epoch in range(NUM_EPOCHS):
    # <EPOCH_START>
    while True:
        # <BEFORE_DATALOADER>
        batch = next(dataloader)
        if batch is None:
            break
        inputs, targets = batch
        # <AFTER_DATALOADER>

        # <BATCH_START>
        # <BEFORE_FORWARD>
        outputs = model.forward(inputs)
        # <AFTER_FORWARD>

        # <BEFORE_LOSS>
        loss = model.loss(outputs, targets)
        # <AFTER_LOSS>

        # <BEFORE_BACKWARD>
        loss.backward()
        # <AFTER_BACKWARD>

        optimizer.step()
        optimizer.zero_grad()
        # <BATCH_END>
    # <EPOCH_END>
There are four explicit logging functions:
_log_model_weight_dtypes
_log_gradient_dtypes
_log_optimizer_state_dtypes
_log_loss_dtype
Additionally, activations are logged using register_forward_hook
for all modules except self-attention (more on that below). Self-attention inputs are logged using a monkey-patched forward pass.
class DtypeLogger(Callback):
    """Composer callback that logs the dtype of weights, gradients, activations,
    optimizer states, and the loss at each training event.

    Entries accumulate in ``self.dtype_logs['log']`` keyed as
    ``"<event>:<name>:<entity>"`` and are flushed to
    ``<save_path>/dtype_logs.json`` after every handled event.
    """

    def __init__(self, save_path="/model-checkpoints/dtype_tracking", log_interval=10):
        # Directory the JSON log is written to.
        self.save_path = Path(save_path)
        # Flat {key: dtype-string} log, nested under 'log' for easy DataFrame loading.
        self.dtype_logs = {'log': {}}
        # Per-batch events only log when batch_index % log_interval == 0.
        self.log_interval = log_interval
        # Forward-hook handles, removed again after each logged forward pass.
        self.hooks = []

    def _should_log(self, state: State) -> bool:
        """Return True on batches selected by ``log_interval``."""
        return state.timestamp.batch.value % self.log_interval == 0

    def fit_start(self, state: State, logger: Logger) -> None:
        self._log_model_weight_dtypes(state, "fit_start")
        self._save_logs()

    def epoch_start(self, state: State, logger: Logger) -> None:
        self._log_model_weight_dtypes(state, "epoch_start")
        self._save_logs()

    def before_dataloader(self, state: State, logger: Logger) -> None:
        if self._should_log(state):
            self._log_model_weight_dtypes(state, "before_dataloader")
            self._save_logs()

    def after_dataloader(self, state: State, logger: Logger) -> None:
        if self._should_log(state):
            self._log_model_weight_dtypes(state, "after_dataloader")
            self._save_logs()

    def batch_start(self, state: State, logger: Logger) -> None:
        if self._should_log(state):
            self._log_model_weight_dtypes(state, "batch_start")
            self._save_logs()

    def before_forward(self, state: State, logger: Logger) -> None:
        """Install activation-dtype hooks (and self-attention monkey patches) for this batch."""
        if not self._should_log(state):
            return
        self._log_model_weight_dtypes(state, "before_forward")

        # Clear hooks left over from a previous logged batch.
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

        # Unwrap to the HF model: Composer wrapper -> PEFT base model -> causal LM.
        # NOTE(review): this unwrap chain assumes a PEFT-wrapped LLaMA-style model —
        # confirm against the trainer configuration.
        model = state.model.model.base_model.model
        transformer_model = model.model  # This is the transformer part

        # Original self-attention forwards, restored in after_forward.
        self.original_forward_methods = {}

        def hook_fn(layer_name, module_name):
            """Build a forward hook recording input/output activation dtypes."""
            def _hook(module, inputs, outputs):
                # Positional inputs arrive as a tuple; log the first tensor's dtype.
                if isinstance(inputs, tuple) and len(inputs) > 0:
                    self.dtype_logs["log"][f"forward:{module_name}:{layer_name}:activation_input"] = str(inputs[0].dtype)
                # Outputs may be a lone tensor or a tuple of tensors.
                if isinstance(outputs, torch.Tensor):
                    self.dtype_logs["log"][f"forward:{module_name}:{layer_name}:activation_output"] = str(outputs.dtype)
                elif isinstance(outputs, tuple) and len(outputs) > 0:
                    self.dtype_logs["log"][f"forward:{module_name}:{layer_name}:activation_output"] = str(outputs[0].dtype)
            return _hook

        # Monkey patch self-attention: LlamaDecoderLayer calls self_attn with keyword
        # arguments only, so a forward *pre*/post hook never sees its inputs.
        for layer_idx, layer in enumerate(transformer_model.layers):
            # Store the original forward method so it can be restored later.
            original_forward = layer.self_attn.forward
            self.original_forward_methods[layer_idx] = original_forward

            # Factory closure binds the current layer_idx/orig_forward (avoids
            # the classic late-binding-closure bug).
            def make_patched_forward(layer_idx, orig_forward):
                def patched_forward(self_attn, *args, **kwargs):
                    # Log the hidden_states dtype.
                    if 'hidden_states' in kwargs and hasattr(kwargs['hidden_states'], 'dtype'):
                        self.dtype_logs["log"][f"forward:self_attn:layer_{layer_idx}:activation_input"] = str(kwargs['hidden_states'].dtype)
                    # Re-bind so 'self_attn' is passed as 'self' instead of being
                    # swallowed as 'hidden_states'.
                    return orig_forward.__get__(self_attn, type(self_attn))(**kwargs)
                return patched_forward

            # Replace the forward method with the bound patched version.
            layer.self_attn.forward = make_patched_forward(layer_idx, original_forward).__get__(layer.self_attn, type(layer.self_attn))

        # lm_head (only present on the causal-LM wrapper).
        if hasattr(model, 'lm_head'):
            self.hooks.append(model.lm_head.register_forward_hook(hook_fn("output", "lm_head")))
        # Embedding layer.
        self.hooks.append(transformer_model.embed_tokens.register_forward_hook(hook_fn("embeddings", "embed_tokens")))
        # Per-layer submodules; self_attn outputs are still hook-visible.
        for layer_idx, layer in enumerate(transformer_model.layers):
            layer_tag = f"layer_{layer_idx}"
            self.hooks.append(layer.self_attn.register_forward_hook(hook_fn(layer_tag, "self_attn")))
            for proj in ('q_proj', 'k_proj', 'v_proj', 'o_proj'):
                self.hooks.append(getattr(layer.self_attn, proj).register_forward_hook(hook_fn(layer_tag, proj)))
            self.hooks.append(layer.mlp.register_forward_hook(hook_fn(layer_tag, "mlp")))
            for proj in ('gate_proj', 'up_proj', 'down_proj'):
                self.hooks.append(getattr(layer.mlp, proj).register_forward_hook(hook_fn(layer_tag, proj)))
            self.hooks.append(layer.input_layernorm.register_forward_hook(hook_fn(layer_tag, "input_layernorm")))
            self.hooks.append(layer.post_attention_layernorm.register_forward_hook(hook_fn(layer_tag, "post_attention_layernorm")))
        # Final layer norm.
        self.hooks.append(transformer_model.norm.register_forward_hook(hook_fn("final", "norm")))
        self._save_logs()

    def after_forward(self, state: State, logger: Logger) -> None:
        """Remove hooks and restore the original self-attention forwards."""
        if not self._should_log(state):
            return
        self._log_model_weight_dtypes(state, "after_forward")
        # Restore original forward methods installed in before_forward.
        if hasattr(self, 'original_forward_methods'):
            model = state.model.model.base_model.model
            transformer_model = model.model
            for layer_idx, original_forward in self.original_forward_methods.items():
                transformer_model.layers[layer_idx].self_attn.forward = original_forward
            self.original_forward_methods = {}
        # Clear hooks.
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
        self._save_logs()

    def before_loss(self, state: State, logger: Logger) -> None:
        if self._should_log(state):
            self._log_model_weight_dtypes(state, "before_loss")
            self._save_logs()

    def after_loss(self, state: State, logger: Logger) -> None:
        if self._should_log(state):
            self._log_model_weight_dtypes(state, "after_loss")
            self._log_loss_dtype(state, "after_loss")
            self._save_logs()

    def before_backward(self, state: State, logger: Logger) -> None:
        if self._should_log(state):
            self._log_model_weight_dtypes(state, "before_backward")
            self._save_logs()

    def after_backward(self, state: State, logger: Logger) -> None:
        # Composer's AFTER_BACKWARD fires before optimizer.step(), hence the
        # "before_optim_step"/"optimizer_step" event names used below.
        if self._should_log(state):
            self._log_gradient_dtypes(state, "after_backward")
            self._log_model_weight_dtypes(state, "before_optim_step")
            self._log_optimizer_state_dtypes(state, "optimizer_step")
            self._save_logs()

    def batch_end(self, state: State, logger: Logger) -> None:
        if self._should_log(state):
            # Weight dtypes after optimizer step, to detect precision changes.
            self._log_model_weight_dtypes(state, "after_optim_step")
            self._save_logs()

    def epoch_end(self, state: State, logger: Logger) -> None:
        self._log_model_weight_dtypes(state, "epoch_end")
        self._save_logs()

    def _log_model_weight_dtypes(self, state: State, event_name: str) -> None:
        """Record the dtype of every parameter under the given event name."""
        model = state.model
        for name, param in model.named_parameters():
            # Strip the Composer/PEFT wrapper prefix for readable keys.
            name = name.removeprefix("model.base_model.model.model.")
            self.dtype_logs["log"][f"{event_name}:{name}:weights"] = str(param.dtype)

    def _log_gradient_dtypes(self, state: State, event_name: str) -> None:
        """Record gradient dtypes; frozen (no-grad) params are logged as "None"."""
        model = state.model
        for name, param in model.named_parameters():
            name = name.removeprefix("model.base_model.model.model.")
            if param.grad is not None:
                self.dtype_logs['log'][f"{event_name}:{name}:gradients"] = str(param.grad.dtype)
            else:
                self.dtype_logs['log'][f"{event_name}:{name}:gradients"] = "None"

    def _log_loss_dtype(self, state: State, event_name: str) -> None:
        """Record the dtype of the loss tensor, if present on the state."""
        if hasattr(state, 'loss') and hasattr(state.loss, 'dtype'):
            self.dtype_logs["log"][f"{event_name}:loss"] = str(state.loss.dtype)

    def _log_optimizer_state_dtypes(self, state: State, event_name: str) -> None:
        """Record dtypes of optimizer state tensors (momentum buffers, etc.)."""
        if hasattr(state, 'optimizers') and state.optimizers is not None:
            # Handle a single optimizer or a list of optimizers.
            optimizers = state.optimizers if isinstance(state.optimizers, list) else [state.optimizers]
            for opt_idx, optimizer in enumerate(optimizers):
                opt_state = optimizer.state_dict()
                # 'state' maps param ids to per-parameter state dicts.
                if 'state' in opt_state:
                    for param_id, param_state in opt_state['state'].items():
                        for state_name, state_value in param_state.items():
                            if isinstance(state_value, torch.Tensor):
                                key = f"optimizer_{opt_idx}_param_{param_id}_{state_name}"
                                self.dtype_logs["log"][f"{event_name}:{key}:optimizer_states"] = str(state_value.dtype)

    def _save_logs(self) -> None:
        """Write the accumulated log dict to <save_path>/dtype_logs.json."""
        os.makedirs(self.save_path, exist_ok=True)
        log_file = self.save_path / "dtype_logs.json"
        with open(log_file, 'w') as f:
            json.dump(self.dtype_logs, f, indent=2)
The most involved event handler is before_forward
which involves creating a hook function (hook_fn
) passed to PyTorch’s register_forward_hook
which exposes the positional inputs and outputs of a module’s forward
pass. The hook function modifies self.dtype_logs
directly by storing the data type string of inputs and outputs. hook_fn
is used for all modules except self attention.
Self attention cannot utilize register_forward_hook
because the LlamaDecoderLayer does not call self attention forward pass with any positional arguments:
hidden_states, self_attn_weights = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    cache_position=cache_position,
    position_embeddings=position_embeddings,
    **kwargs,
)
Contrast this with how the forward pass of other modules are called with positional arguments only:
# self attention sublayers
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
attn_output = self.o_proj(attn_output)

# mlp sublayers
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

# non-self attention modules
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.norm(hidden_states)
Since self-attention inputs can’t be captured by a hook I had to monkey patch its forward pass to log its inputs’ data type:
for layer_idx, layer in enumerate(transformer_model.layers):
    # Store the original forward method
    original_forward = layer.self_attn.forward
    self.original_forward_methods[layer_idx] = original_forward

    # Define a closure to capture the current layer_idx
    def make_patched_forward(layer_idx, orig_forward):
        def patched_forward(self_attn, *args, **kwargs):
            # Log the hidden_states dtype
            if 'hidden_states' in kwargs and hasattr(kwargs['hidden_states'], 'dtype'):
                self.dtype_logs["log"][f"forward:self_attn:layer_{layer_idx}:activation_input"] = str(kwargs['hidden_states'].dtype)
            # Call the original method as a bound method
            # This ensures 'self_attn' is correctly passed as 'self'
            return orig_forward.__get__(self_attn, type(self_attn))(**kwargs)
        return patched_forward

    # Replace the forward method
    layer.self_attn.forward = make_patched_forward(layer_idx, original_forward).__get__(layer.self_attn, type(layer.self_attn))
patched_forward
receives positional arguments *args
(of which there are none) and keyword arguments **kwargs
(all of the arguments to the self-attention forward) and logs the data type of the input to self-attention (hidden_states
) under an activation_input
key before returning the outputs of the original forward pass.
A key line is orig_forward.__get__(self_attn, type(self_attn))(**kwargs)
. As Claude’s comment mentions, this is to avoid using orig_forward(self_attn, **kwargs)
which was causing the following error because the first argument, self_attn
, was being interpreted as hidden_states
whereas it was intended to represent self
:
TypeError: LlamaFlashAttention2.forward() got multiple values for argument 'hidden_states'
In short, when you call __get__(obj, type)
on a function it will bind that function as a method to the given object, thus no longer requiring you to pass in self
as an argument. This is critical because self_attn.forward
has no positional arguments. We can then pass in the keyword arguments to the bound method orig_forward.__get__(self_attn, type(self_attn))(**kwargs)
, and let the model continue using self-attention correctly. See the Descriptor Guide in the Python docs for more information.
Helper Functions
def parse_index(string):
    """Extract structured information from a dtype-log key.

    Keys have the shape ``"<training_step>:<parameter name>:<entity>"``.
    Returns a dict with ``layer_number``, ``module``, ``layer_name``,
    ``lora_layer``, ``training_step`` and ``entity``; ``module`` and
    ``training_step`` are prefixed with a zero-padded ordinal so that
    lexicographic sorting matches model order / training-loop order.
    """
    info = {
        'layer_number': None,
        'module': None,
        'layer_name': None,
        'lora_layer': None,
        'training_step': None,
        'entity': None,
    }

    # Transformer layer index, e.g. "layers.12." -> 12.
    layer_number_match = re.search(r'layers\.(\d+)', string)
    if layer_number_match:
        info['layer_number'] = int(layer_number_match.group(1))

    # Top-level modules in model order; the list index gives the sort prefix.
    modules = [
        "embed_tokens",
        "input_layernorm",
        "self_attn",
        "post_attention_layernorm",
        "mlp",
        "norm",
        "lm_head",
    ]
    module_match = re.search(r'(mlp|self_attn|input_layernorm|post_attention_layernorm|embed_tokens|norm|lm_head)', string)
    if module_match:
        info['module'] = str(modules.index(module_match.group(1))).zfill(2) + '_' + module_match.group(1)

    # Linear sublayer inside self_attn / mlp, if any.
    layer_name_match = re.search(r'(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)', string)
    if layer_name_match:
        info['layer_name'] = layer_name_match.group(1)

    # LoRA component, if any.
    lora_match = re.search(r'(base_layer|lora_A|lora_B)', string)
    info['lora_layer'] = lora_match.group(1) if lora_match else "Not a LoRA Layer"

    # Composer event names in training-loop order; index gives the sort prefix.
    # NOTE: an unknown first segment raises ValueError from list.index.
    training_steps = [
        "fit_start",
        "epoch_start",
        "before_dataloader",
        "after_dataloader",
        "batch_start",
        "before_forward",
        "forward",
        "after_forward",
        "before_loss",
        "after_loss",
        "before_backward",
        "after_backward",
        "before_optim_step",
        "optimizer_step",
        "after_optim_step",
    ]
    training_step = string.split(":")[0]
    info['training_step'] = str(training_steps.index(training_step)).zfill(2) + '_' + training_step

    info['entity'] = string.split(":")[-1]

    return info
def _df(url):
    """Download a dtype-log JSON from *url* and return an annotated DataFrame.

    The JSON has a single top-level 'log' mapping of key -> dtype string;
    each key is expanded via parse_index into structured columns
    (layer_number, module, layer_name, lora_layer, training_step, entity).
    """
    dtype_data = json.loads(requests.get(url).text)

    # One row per log key: 'index' holds the key, 'dtype' the dtype string.
    df = pd.DataFrame(dtype_data).reset_index()
    df = df.rename(columns={"index": "index", "log": "dtype"})

    # Parse every key once, then fan the fields out into columns.
    parsed_info = df['index'].apply(parse_index)
    for col in ('layer_number', 'module', 'layer_name',
                'lora_layer', 'training_step', 'entity'):
        # Bind col as a default arg to avoid the late-binding-closure pitfall.
        df[col] = parsed_info.apply(lambda info, c=col: info[c])

    return df
Model in fp32 (master_weights_dtype==None
)
In this case, master_weights_dtype
is not provided in the training YAML file.
url = "https://gist.githubusercontent.com/vishalbakshi/9ade8d501629d4c30e8aecfa1c6f67cf/raw/0c162e2305002fbe57fd2570ade302c3659140a1/dtypes_logs_1ba_fp32.json"
df = _df(url)
df.head()
index | dtype | layer_number | module | layer_name | lora_layer | training_step | entity | |
---|---|---|---|---|---|---|---|---|
0 | fit_start:embed_tokens.weight:weights | torch.float32 | NaN | 00_embed_tokens | None | Not a LoRA Layer | 00_fit_start | weights |
1 | fit_start:layers.0.self_attn.q_proj.base_layer... | torch.float32 | 0.0 | 02_self_attn | q_proj | base_layer | 00_fit_start | weights |
2 | fit_start:layers.0.self_attn.q_proj.lora_A.def... | torch.float32 | 0.0 | 02_self_attn | q_proj | lora_A | 00_fit_start | weights |
3 | fit_start:layers.0.self_attn.q_proj.lora_B.def... | torch.float32 | 0.0 | 02_self_attn | q_proj | lora_B | 00_fit_start | weights |
4 | fit_start:layers.0.self_attn.k_proj.base_layer... | torch.float32 | 0.0 | 02_self_attn | k_proj | base_layer | 00_fit_start | weights |
Data Types by lora_layer
All LoRA layer entities are in fp32.
df.groupby(['lora_layer', 'dtype'])['dtype'].count()
dtype | ||
---|---|---|
lora_layer | dtype | |
Not a LoRA Layer | None | 62 |
torch.bfloat16 | 331 | |
torch.float32 | 2339 | |
torch.int64 | 1 | |
base_layer | None | 210 |
torch.float32 | 2520 | |
lora_A | torch.float32 | 2730 |
lora_B | torch.float32 | 2730 |
Data Types by entity
(Activations, Gradients, Loss, Optimizer States and Weights)
Every entity except activations are in fp32. Some parameters don’t have gradients because we are training with LoRA.
df.groupby(['entity', 'dtype'])['dtype'].count()
dtype | ||
---|---|---|
entity | dtype | |
activation_input | torch.bfloat16 | 60 |
torch.float32 | 272 | |
torch.int64 | 1 | |
activation_output | torch.bfloat16 | 271 |
torch.float32 | 62 | |
gradients | None | 272 |
torch.float32 | 420 | |
loss | torch.float32 | 1 |
optimizer_states | torch.float32 | 1260 |
weights | torch.float32 | 8304 |
Data Types by Composer Training Step
df.groupby(['training_step', 'entity', 'dtype'])['dtype'].count()
dtype | |||
---|---|---|---|
training_step | entity | dtype | |
00_fit_start | weights | torch.float32 | 692 |
01_epoch_start | weights | torch.float32 | 692 |
02_before_dataloader | weights | torch.float32 | 692 |
03_after_dataloader | weights | torch.float32 | 692 |
04_batch_start | weights | torch.float32 | 692 |
05_before_forward | weights | torch.float32 | 692 |
06_forward | activation_input | torch.bfloat16 | 60 |
torch.float32 | 272 | ||
torch.int64 | 1 | ||
activation_output | torch.bfloat16 | 271 | |
torch.float32 | 62 | ||
07_after_forward | weights | torch.float32 | 692 |
08_before_loss | weights | torch.float32 | 692 |
09_after_loss | loss | torch.float32 | 1 |
weights | torch.float32 | 692 | |
10_before_backward | weights | torch.float32 | 692 |
11_after_backward | gradients | None | 272 |
torch.float32 | 420 | ||
12_before_optim_step | weights | torch.float32 | 692 |
13_optimizer_step | optimizer_states | torch.float32 | 1260 |
14_after_optim_step | weights | torch.float32 | 692 |
Model in bf16 (master_weights_dtype==bfloat16
)
I also logged data types after setting master_weights_dtype
in the training YAML to bfloat16
.
url = "https://gist.githubusercontent.com/vishalbakshi/ec91a59754633611fd8eb33b59031243/raw/5b83a7ebd5759cf6bd2db2369edf1c73e1fb67cf/dtypes_logs_1ba_bf16.json"
df = _df(url)
df.head()
index | dtype | layer_number | module | layer_name | lora_layer | training_step | entity | |
---|---|---|---|---|---|---|---|---|
0 | fit_start:embed_tokens.weight:weights | torch.bfloat16 | NaN | 00_embed_tokens | None | Not a LoRA Layer | 00_fit_start | weights |
1 | fit_start:layers.0.self_attn.q_proj.base_layer... | torch.bfloat16 | 0.0 | 02_self_attn | q_proj | base_layer | 00_fit_start | weights |
2 | fit_start:layers.0.self_attn.q_proj.lora_A.def... | torch.bfloat16 | 0.0 | 02_self_attn | q_proj | lora_A | 00_fit_start | weights |
3 | fit_start:layers.0.self_attn.q_proj.lora_B.def... | torch.bfloat16 | 0.0 | 02_self_attn | q_proj | lora_B | 00_fit_start | weights |
4 | fit_start:layers.0.self_attn.k_proj.base_layer... | torch.bfloat16 | 0.0 | 02_self_attn | k_proj | base_layer | 00_fit_start | weights |
Data Type by lora_layer
Interestingly, setting master_weights_dtype
makes all LoRA layers bfloat16 but some non-LoRA layers’ entities are still in fp32.
df.groupby(['lora_layer', 'dtype'])['dtype'].count()
dtype | ||
---|---|---|
lora_layer | dtype | |
Not a LoRA Layer | None | 62 |
torch.bfloat16 | 2249 | |
torch.float32 | 421 | |
torch.int64 | 1 | |
base_layer | None | 210 |
torch.bfloat16 | 2520 | |
lora_A | torch.bfloat16 | 2730 |
lora_B | torch.bfloat16 | 2730 |
Data Types by entity
(Activations, Gradients, Loss, Optimizer States and Weights)
All floating point values are in bfloat16 except for the loss and some of the optimizer states. I’m not sure why some optimizer states are in bf16, even though it says in the Composer docs:
Store the weights and perform the optimizer step in single precision, enabling the weight update to be done more precisely.
df.groupby(['entity', 'dtype'])['dtype'].count()
dtype | ||
---|---|---|
entity | dtype | |
activation_input | torch.bfloat16 | 332 |
torch.int64 | 1 | |
activation_output | torch.bfloat16 | 333 |
gradients | None | 272 |
torch.bfloat16 | 420 | |
loss | torch.float32 | 1 |
optimizer_states | torch.bfloat16 | 840 |
torch.float32 | 420 | |
weights | torch.bfloat16 | 8304 |
Data Type by Composer Training Step
df.groupby(['training_step', 'entity', 'dtype'])['dtype'].count()
dtype | |||
---|---|---|---|
training_step | entity | dtype | |
00_fit_start | weights | torch.bfloat16 | 692 |
01_epoch_start | weights | torch.bfloat16 | 692 |
02_before_dataloader | weights | torch.bfloat16 | 692 |
03_after_dataloader | weights | torch.bfloat16 | 692 |
04_batch_start | weights | torch.bfloat16 | 692 |
05_before_forward | weights | torch.bfloat16 | 692 |
06_forward | activation_input | torch.bfloat16 | 332 |
torch.int64 | 1 | ||
activation_output | torch.bfloat16 | 333 | |
07_after_forward | weights | torch.bfloat16 | 692 |
08_before_loss | weights | torch.bfloat16 | 692 |
09_after_loss | loss | torch.float32 | 1 |
weights | torch.bfloat16 | 692 | |
10_before_backward | weights | torch.bfloat16 | 692 |
11_after_backward | gradients | None | 272 |
torch.bfloat16 | 420 | ||
12_before_optim_step | weights | torch.bfloat16 | 692 |
13_optimizer_step | optimizer_states | torch.bfloat16 | 840 |
torch.float32 | 420 | ||
14_after_optim_step | weights | torch.bfloat16 | 692 |
Final Thoughts
I absolutely loved this exercise. I learned a ton about callbacks, data types during mixed precision training, and Python fundamentals. Working with LLM-Foundry has opened up a whole universe of learning opportunities as I try to better understand what’s going on under the hood. It’s a gift that keeps giving!
I’m trying to grow my YouTube channel so please give it a visit and subscribe if you want to stay in the loop.