IndexError: too many indices for tensor of dimension 2 #3560

Closed
@heroding77

Description

Reminder

  • I have read the README and searched the existing issues.

Reproduction

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 src/train_bash.py \
    --stage sft \
    --do_train True \
    --model_name_or_path xxx \
    --finetuning_type lora \
    --template llama3 \
    --flash_attn fa2 \
    --dataset_dir data \
    --dataset LongAlpaca-12k \
    --cutoff_len 32768 \
    --learning_rate 2e-05 \
    --num_train_epochs 3.0 \
    --max_samples 100000 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type constant_with_warmup \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 140 \
    --warmup_steps 20 \
    --optim adamw_torch \
    --shift_attn True \
    --report_to none \
    --output_dir xxx \
    --fp16 True \
    --lora_rank 8 \
    --lora_alpha 16 \
    --lora_dropout 0.1 \
    --use_dora True \
    --lora_target all \
    --plot_loss True

Expected behavior

When I enable flash_attn fa2 and shift_attn together, training fails with IndexError: too many indices for tensor of dimension 2. Each option works without problems when used on its own. What could be the reason for this?
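
For readers unfamiliar with the message: PyTorch raises this IndexError when a tensor is indexed with more indices than it has dimensions. The sketch below only reproduces that class of error. The tensor shapes, the group size, and the helper name slice_like_4d are illustrative assumptions, not code from the training script or from any library involved.

```python
import torch


def slice_like_4d(mask: torch.Tensor, group_size: int) -> str:
    """Slice a mask as if it were 4-D and report what happens.

    Hypothetical helper for illustration only.
    """
    try:
        # Four indices are valid only for a tensor with at least 4 dimensions.
        mask[:, :, :group_size, :group_size]
    except IndexError as err:
        return str(err)
    return "ok"


# Assumed shapes: a 4-D mask (batch, 1, seq, seq) and a 2-D mask (batch, seq).
mask_4d = torch.ones(2, 1, 8, 8)
mask_2d = torch.ones(2, 8)

result_4d = slice_like_4d(mask_4d, group_size=4)
result_2d = slice_like_4d(mask_2d, group_size=4)

print(result_4d)
print(result_2d)
```

If the combination of the two options leads some code path to receive a 2-D tensor where it expects a 4-D one, this is the kind of failure that would result. Whether that is what happens here is not confirmed by anything in this report; the full traceback would show which indexing expression actually fails.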

System Info

  • transformers version: 4.40.0
  • Platform: Linux-3.10.0-1160.el7.x86_64-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.22.2
  • Safetensors version: 0.4.3
  • Accelerate version: 0.29.3
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.2+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Others

flash_attn version: 2.5.8

    Labels

    solved: This problem has been already solved