
Token shape issue in LLaVA-onevision fine-tuning #38481

Open
@HoinJung


System Info

  • transformers version: 4.52.3
  • Platform: Linux-6.8.0-51-generic-x86_64-with-glibc2.39
  • Python version: 3.12.0
  • Huggingface_hub version: 0.32.2
  • Safetensors version: 0.5.3
  • Accelerate version: 1.7.0
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.7.0+cu126 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA L40S

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

```python
import os
# from datasets import load_dataset
from datasets import load_from_disk
from transformers import AutoTokenizer, AutoProcessor, LlavaOnevisionForConditionalGeneration, TrainingArguments, Trainer
from PIL import Image
import torch
from tqdm import tqdm

train_ds = load_from_disk('mydataset/vlm_hf_dataset')
validation_ds = load_from_disk('mydataset/vlm_hf_dataset_validation')
test_ds = load_from_disk('mydataset/vlm_hf_dataset_test')


# 2. Load model, tokenizer, and processor
model_id = "llava-hf/llava-onevision-qwen2-7b-ov-hf"

processor = AutoProcessor.from_pretrained(model_id)
model = LlavaOnevisionForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map='auto')

# 3. Preprocessing function
def preprocess(example):
    # image_path = os.path.join("./data", example["image_path"])
    image_path = example["image_path"]

    image = Image.open(image_path).convert("RGB")

    # Tokenize input
    prompt = example["question"]
    answer = example["response"]
    full_input = prompt + " " + answer
    processed = processor(text=full_input, images=image, return_tensors="pt",
                          padding='max_length', truncation=True, max_length=1024)
    # print(processed)
    image_sizes = processed['image_sizes'][0]
    input_ids = processed['input_ids'][0]
    attention_mask = processed['attention_mask'][0]

    prompt_ids = processor.tokenizer(prompt, return_tensors="pt").input_ids[0]
    labels = input_ids.clone()
    labels[:len(prompt_ids)] = -100

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "pixel_values": processed['pixel_values'][0],
        "image_sizes": image_sizes,
    }

def save_dataset(raw_dataset, split_name, save_path):
    save_file = os.path.join(save_path, f"{split_name}.pt")
    if os.path.exists(save_file):
        return 
    else:
        processed = []
        for example in tqdm(raw_dataset, desc=f"Preprocessing split {split_name}"):
            processed.append(preprocess(example))
        torch.save(processed, save_file)
# 4. Apply preprocessing
save_dir = './preprocessed_llava_one'
os.makedirs(save_dir, exist_ok=True)
save_dataset(train_ds, 'train', save_dir)
save_dataset(validation_ds, 'validation', save_dir)
save_dataset(test_ds, 'test', save_dir)


class LLAVADataset(torch.utils.data.Dataset):
    def __init__(self, path):
        self.data = torch.load(path)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    input_ids = torch.nn.utils.rnn.pad_sequence(
        [x["input_ids"] for x in batch], batch_first=True, padding_value=processor.tokenizer.pad_token_id
    )
    attention_mask = torch.nn.utils.rnn.pad_sequence(
        [x["attention_mask"] for x in batch], batch_first=True, padding_value=0
    )
    labels = torch.nn.utils.rnn.pad_sequence(
        [x["labels"] for x in batch], batch_first=True, padding_value=-100
    )

    # Handling pixel values with AnyRes strategy
    # pixel_values = torch.stack([x["pixel_values"] for x in batch])
    max_len = max(x["pixel_values"].shape[0] for x in batch)
    padded_pixel_values = []
    for x in batch:
        seq = x["pixel_values"]
        padding_len = max_len - seq.shape[0]
        padding = torch.zeros((padding_len, 3, 384, 384), device=seq.device, dtype=seq.dtype)
        padded_seq = torch.cat((seq, padding), dim=0)
        padded_pixel_values.append(padded_seq)

    pixel_values = torch.stack(padded_pixel_values).to(dtype=torch.float16)
    image_sizes = torch.stack([x["imgae_sizes"] for x in batch]).to(dtype=torch.float16)

    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
        "pixel_values": pixel_values,
        "image_sizes": image_sizes,
    }


processed_train = LLAVADataset(os.path.join(save_dir, 'train.pt'))
processed_validation = LLAVADataset(os.path.join(save_dir, 'validation.pt'))

print(processed_train)
# 6. Training setup
training_args = TrainingArguments(
    output_dir="./llava-finetuned",
    per_device_train_batch_size=2,
    num_train_epochs=3,
    logging_steps=10,
    save_strategy="epoch",
    fp16=False,
    gradient_accumulation_steps=1,
    remove_unused_columns=False,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed_train,
    eval_dataset=processed_validation,
    tokenizer=processor.tokenizer,
    data_collator=collate_fn,
)

# 7. Start training
trainer.train()
```

This is a simple script for fine-tuning the LLaVA-OneVision model, and it fails with the following error:

```
Traceback (most recent call last):
  File "/home/mine/project/finetuning.py", line 148, in <module>                                                                                                                                          
    trainer.train()                              
  File "/home/mine/miniconda3/envs/py312/lib/python3.12/site-packages/transformers/trainer.py", line 2240, in train
    return inner_training_loop(                                                                                                                                                                             
           ^^^^^^^^^^^^^^^^^^^^         
  File "/home/mine/miniconda3/envs/py312/lib/python3.12/site-packages/transformers/trainer.py", line 2555, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)                                                                                                                                    
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mine/miniconda3/envs/py312/lib/python3.12/site-packages/transformers/trainer.py", line 3745, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)                                                                                                                          
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mine/miniconda3/envs/py312/lib/python3.12/site-packages/transformers/trainer.py", line 3810, in compute_loss
    outputs = model(**inputs)                                                                                                                                                                               
              ^^^^^^^^^^^^^^^           
  File "/home/mine/miniconda3/envs/py312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)                                                                                                                                                                 
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mine/miniconda3/envs/py312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)                                                                                                                                                                    
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mine/miniconda3/envs/py312/lib/python3.12/site-packages/accelerate/hooks.py", line 175, in new_forward
    output = module._old_forward(*args, **kwargs)  
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mine/miniconda3/envs/py312/lib/python3.12/site-packages/transformers/utils/generic.py", line 969, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mine/miniconda3/envs/py312/lib/python3.12/site-packages/transformers/models/llava_onevision/modeling_llava_onevision.py", line 829, in forward
    outputs = self.model(          
              ^^^^^^^^^^^           
  File "/home/mine/miniconda3/envs/py312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mine/miniconda3/envs/py312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mine/miniconda3/envs/py312/lib/python3.12/site-packages/transformers/utils/generic.py", line 969, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mine/miniconda3/envs/py312/lib/python3.12/site-packages/transformers/models/llava_onevision/modeling_llava_onevision.py", line 577, in forward
    raise ValueError(              
ValueError: Image features and image tokens do not match: tokens: 0, features 12438
```

When I add the following debug prints inside `modeling_llava_onevision.py`:

print("self.config.image_token_id",self.config.image_token_id)
print("n_image_tokens",n_image_tokens)
print("image_features",image_features.shape)

the output is:

```
self.config.image_token_id 151646
n_image_tokens tensor(0, device='cuda:0')
image_features torch.Size([12438, 3584])
```
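
So the vision tower produces 12438 image features, but the model finds zero occurrences of `image_token_id` (151646) in my `input_ids`. As a sanity check, this is roughly how I would count the image tokens in one preprocessed example (a minimal sketch; it assumes the image placeholder token string is `<image>` and that `train.pt` was written by the preprocessing step above):

```python
import torch
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")

# Load one preprocessed example and count how often the image token id appears.
example = torch.load("./preprocessed_llava_one/train.pt")[0]
image_token_id = processor.tokenizer.convert_tokens_to_ids("<image>")  # expected to be 151646
n_image_tokens = (example["input_ids"] == image_token_id).sum().item()

print("image_token_id:", image_token_id)
print("image tokens in input_ids:", n_image_tokens)
print("pixel_values shape:", tuple(example["pixel_values"].shape))
```

If this already prints 0 before batching, the placeholder would be missing after my preprocessing rather than being lost in the collator.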

I'm not sure whether this comes from the transformers code or from my preprocessing/collate function.
One thing my code does differently is the zero-padding of pixel_values: since the images all have different resolutions, they produce different numbers of patches and are not stackable without padding. I skip resizing the images because I believe the AnyRes strategy in LLaVA-OneVision can handle varying resolutions.

Do I need to change this strategy, or modify some code in transformers?
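
For reference, if the manual zero-padding of pixel_values turns out to be the problem, an alternative I have considered (a rough sketch only; it assumes the processor can batch a list of texts and images and pad `pixel_values` across the batch itself, and it reuses `processor` and the raw examples from the script above) would be to run the processor inside the collator instead of ahead of time:

```python
from PIL import Image

def processor_collate_fn(examples):
    # Hypothetical collator: let the processor handle batching and padding of
    # both input_ids and pixel_values instead of stacking pre-computed tensors.
    images = [Image.open(ex["image_path"]).convert("RGB") for ex in examples]
    texts = [ex["question"] + " " + ex["response"] for ex in examples]
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # Ignore padding positions in the loss (prompt masking omitted for brevity).
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels
    return batch
```

That would keep the padding logic inside the library, but I am not sure it is necessary if AnyRes is supposed to work with my current approach.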

Expected behavior

Training should start and run normally, without the image-token mismatch error.
