Description
System Info
- `transformers` version: 4.40.0
- Platform: Linux-5.15.0-100-generic-x86_64-with-glibc2.35
- Python version: 3.11.8
- Huggingface_hub version: 0.21.4
- Safetensors version: 0.4.2
- Accelerate version: 0.28.0
- PyTorch version (GPU?): 2.2.0+cu121 (True)
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Experiments are conducted with the following script:
```python
import matplotlib.pyplot as plt
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_id = "Qwen/Qwen1.5-0.5B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

dataset = load_dataset("mlabonne/guanaco-llama2-1k", split="train[:200]").map(
    lambda x: tokenizer(x["text"], max_length=1024, truncation=True, add_special_tokens=False),
    batched=True,
)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="test_galore",
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    num_train_epochs=6.0,
    logging_steps=10,
    warmup_steps=10,  # warmup_steps=0
    optim="galore_adamw_layerwise",
    optim_args="scale=2.0,update_proj_gap=400",
    optim_target_modules="all-linear",
    gradient_checkpointing=True,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
trainer.train()

steps, losses = [], []
for entry in trainer.state.log_history:
    if "loss" in entry:
        steps.append(entry["step"])
        losses.append(entry["loss"])

plt.figure()
plt.plot(steps, losses)
plt.savefig("loss.png", format="png", dpi=100)
```
With `warmup_steps=0`, the loss converges normally. With `warmup_steps=10`, the loss does not converge (see the attached loss curves).
Expected behavior
The layerwise GaLore optimizer is implemented with per-parameter hooks. The `Trainer` first attaches an optimizer hook to each parameter:
`transformers/src/transformers/trainer.py`, lines 1351 to 1358 at commit `8c12690`
and then attaches another hook to each parameter for the schedulers:
`transformers/src/transformers/optimization.py`, lines 445 to 453 at commit `8c12690`
However, because the scheduler hook is attached after the optimizer hook, the parameter's gradient has already been cleared by the optimizer hook by the time the scheduler hook runs. The condition `param.grad is not None` therefore never holds inside the scheduler hook, so `scheduler_dict[param].step()` is never actually called during training.
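To make the ordering problem concrete, here is a minimal, self-contained sketch (not the actual `transformers` implementation; the single parameter and the `opt`/`sched` names are assumptions for illustration) showing that a second post-accumulate-grad hook registered on the same parameter sees `param.grad is None` once the first hook has stepped the optimizer and cleared the gradient:

```python
# Minimal sketch of the hook-ordering issue, assuming PyTorch >= 2.1
# (register_post_accumulate_grad_hook) and the default zero_grad(set_to_none=True).
import torch

param = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.AdamW([param], lr=1e-3)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: 1.0)

def optimizer_hook(p):
    # Mirrors the layerwise optimizer hook: step, then clear the gradient.
    opt.step()
    opt.zero_grad()  # sets p.grad to None

def scheduler_hook(p):
    # Registered second, so it runs after optimizer_hook has cleared p.grad
    # and this branch is never taken.
    if p.grad is not None:
        print("scheduler step")
        sched.step()

param.register_post_accumulate_grad_hook(optimizer_hook)
param.register_post_accumulate_grad_hook(scheduler_hook)

(param.sum() ** 2).backward()  # "scheduler step" is never printed
print(sched.last_epoch)        # still 0: the scheduler never advanced
```

Since hooks on the same parameter fire in registration order, this is exactly the situation created by the `Trainer` attaching the optimizer hooks before the scheduler hooks.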
Apart from the experiment above, we can also verify this by adding a print statement to the scheduler hook:
```python
def scheduler_hook(param):
    # Since the optimizer hook has been already attached we only need to
    # attach the scheduler hook
    if param.grad is not None:
        print("scheduler step")
        scheduler_dict[param].step()
```
As we can see, this string is never printed during training. Consequently, the learning-rate scheduler unexpectedly has no effect on training when the layerwise optimizer is used.
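One possible way to avoid the ordering dependence, sketched below under the assumption that `optimizer_dict` and `scheduler_dict` map each parameter to its layerwise optimizer and scheduler (this is not an official fix), is to drive both from a single hook and step the scheduler before the gradient is cleared:

```python
# Hypothetical combined hook; optimizer_dict / scheduler_dict are assumed
# per-parameter mappings as used by the layerwise GaLore setup.
def optimizer_and_scheduler_hook(param):
    if param.grad is not None:
        optimizer_dict[param].step()
        scheduler_dict[param].step()  # stepped before the gradient is cleared
        optimizer_dict[param].zero_grad()
```

Alternatively, the `param.grad is not None` guard in the scheduler hook could arguably be dropped, since the hook only fires after a gradient has been accumulated for that parameter.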