TIL: Custom Composer Callback to Push Checkpoints to Hugging Face Hub During Training
Background
In this short TIL blog post, I'm going to share the code I wrote with Claude's help for a custom Composer callback that pushes the model to the Hugging Face Hub every N steps. The point is to decouple evaluation from training: you can evaluate each pushed checkpoint after (or alongside) the run, so evaluation doesn't slow training down.
Custom Composer Callback
import os
import tempfile

from composer.core import Callback, State
from composer.loggers import Logger
from huggingface_hub import HfApi, create_repo


class HFPushCallback(Callback):
    """Pushes the wrapped HF model to the Hugging Face Hub every N batches."""

    def __init__(self, repo_id: str, push_every_n_steps: int = 10):
        self.repo_id = repo_id
        self.push_every_n_steps = push_every_n_steps
        self.token = os.getenv("HF_TOKEN")
        self.hf_api = HfApi(token=self.token)
        create_repo(
            repo_id=self.repo_id,
            token=self.token,
            private=True,
            exist_ok=True,
        )

    def batch_end(self, state: State, logger: Logger) -> None:
        # state.timestamp.batch.value counts batches seen so far
        if state.timestamp.batch.value % self.push_every_n_steps == 0:
            self._push_model(state)

    def _push_model(self, state: State):
        with tempfile.TemporaryDirectory() as temp_dir:
            # save_pretrained lives on the wrapped HF model (see note below)
            state.model.model.save_pretrained(temp_dir)
            self.hf_api.upload_folder(
                folder_path=temp_dir,
                repo_id=self.repo_id,
                commit_message=f"Step {state.timestamp.batch.value}",
            )
Important to note: state.model is the ComposerHFCausalLM wrapper around the Hugging Face model, so you have to go through state.model.model to call save_pretrained.
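To wire the callback into a training run, pass it to Composer's Trainer via the callbacks argument. Here's a minimal sketch, assuming composer_model (a ComposerHFCausalLM) and train_dataloader are defined elsewhere, and using a placeholder repo name:

from composer import Trainer

# Sketch only: composer_model and train_dataloader are assumed to be
# defined elsewhere, and the repo id below is a placeholder.
trainer = Trainer(
    model=composer_model,
    train_dataloader=train_dataloader,
    max_duration="1000ba",  # 1000 batches
    callbacks=[HFPushCallback(repo_id="your-username/your-repo", push_every_n_steps=100)],
)
trainer.fit()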
Running Inference
You can use the following code to run inference on the model, just as you would with any set of PEFT adapters from the Hugging Face Hub.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

model_id = "HuggingFaceTB/SmolLM2-135M"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Load the pushed adapters on top of the base model
model = PeftModel.from_pretrained(
    model,
    "<repo_id>",
    revision="<revision>",
)

prompt = "The best thing about artificial intelligence is "
inputs = tokenizer(prompt, return_tensors="pt")
attention_mask = inputs["attention_mask"]

outputs = model.generate(
    inputs["input_ids"],
    attention_mask=attention_mask,
    pad_token_id=tokenizer.eos_token_id,
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
The revision parameter is the commit ID in your Hugging Face repo. So if you push your model every 100 steps, say, you can pass each checkpoint's commit ID as revision and run your evaluations against it. You can then log those evaluations to your W&B project so that your evaluation log is comparable with your other training logs.
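As a sketch of that workflow, here's a loop over the repo's commits, assuming a hypothetical evaluate() function standing in for your own eval code and the "Step N" commit messages written by the callback above:

import wandb
from huggingface_hub import HfApi
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Hypothetical evaluation loop: evaluate() is a placeholder for your own
# eval code, and <repo_id> is the repo the callback pushed to.
api = HfApi()
run = wandb.init(project="checkpoint-evals")

for commit in api.list_repo_commits("<repo_id>"):
    if not commit.title.startswith("Step"):
        continue  # skip the repo-creation commit
    step = int(commit.title.split()[-1])
    base = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    model = PeftModel.from_pretrained(base, "<repo_id>", revision=commit.commit_id)
    run.log({"step": step, "eval/score": evaluate(model)})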