FAIL: Using cache and enabling back past_key_values cache #58

Open

msedalatzadeh opened this issue May 6, 2025 · 2 comments

@msedalatzadeh

I tried various models with transformers.js, and those that support past_key_values do not actually handle it. I ran into several issues:

  1. The default past_key_values returned by generate() are GPU-buffer tensors, while ONNX Runtime requires CPU tensors as input.
  2. After downloading the past_key_values to the CPU with the tensor's downloader method, running generation again hits dimension-inconsistency errors. Basically, input_ids, attention_mask, and position_ids have to be fed into model.generate(), and every shape I tried failed (see the sketch after this list):
  • Assume past_key_value.dims[2] = past_length and input_ids.dims[1] = full_length. I tried every combination of each input having length past_length, full_length, full_length - past_length, or simply 1 (a single token). None worked.
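
For reference, the shape convention that decoder-only models usually expect during incremental decoding with a KV cache is sketched below. This is the general convention, not something verified against transformers.js; buildIncrementalInputs is a hypothetical helper, and the import path is an assumption.

import { Tensor } from '@huggingface/transformers'; // assumed import path

// Convention for incremental decoding with past_len cached tokens and
// new_len uncached tokens:
//   input_ids:      [1, new_len]             only the new tokens
//   attention_mask: [1, past_len + new_len]  ones over past AND new positions
//   position_ids:   [1, new_len]             values past_len .. past_len + new_len - 1
function buildIncrementalInputs(new_token_ids, past_len) {
  const new_len = new_token_ids.length;
  const input_ids = new Tensor(
    'int64',
    BigInt64Array.from(new_token_ids, (t) => BigInt(t)),
    [1, new_len]
  );
  const attention_mask = new Tensor(
    'int64',
    BigInt64Array.from({ length: past_len + new_len }, () => 1n),
    [1, past_len + new_len]
  );
  const position_ids = new Tensor(
    'int64',
    BigInt64Array.from({ length: new_len }, (_, i) => BigInt(past_len + i)),
    [1, new_len]
  );
  return { input_ids, attention_mask, position_ids };
}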

Please share a working example of transformers.js with past_key_values enabled.

Here is my code:

// Tokenize the full chat history for this turn
const full_inputs = tokenizer.apply_chat_template(messages, {
  add_generation_prompt: true,
  return_dict: true
});


// past_gpu_kv holds the past_key_values returned by a previous generate() call;
// each entry wraps an ONNX Runtime tensor that lives in a GPU buffer.
const past_kv = {};
for (const key in past_gpu_kv) {
  if (past_gpu_kv[key]?.ort_tensor) {
    past_kv[key] = await convertToCPUTensor(past_gpu_kv[key].ort_tensor);
  }
}

const { past_key_values, sequences } = await model.generate({
  ...inputs, // presumably the output of buildInputsForGenerate() below
  past_key_values: past_kv,
  use_cache: true,
  do_sample: false, // note: top_k / temperature have no effect when not sampling
  top_k: 3,
  temperature: 0.2,
  max_new_tokens: 1024,
  streamer,
  stopping_criteria,
  return_dict_in_generate: true, // required to get past_key_values back
});
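
For the next turn, the cache returned by generate() would then be stashed so that buildInputsForGenerate() below can find it. A minimal sketch, reusing this issue's own past_key_values_cache and modelKey names:

// Keep the (still GPU-resident) cache around for the next call.
past_key_values_cache[modelKey] = past_key_values;
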
async function convertToCPUTensor(ortTensor) {
  if (!ortTensor || typeof ortTensor.downloader !== 'function') {
    throw new Error('Invalid ort_tensor: missing downloader method');
  }

  // Download the data from the GPU buffer
  const rawData = await ortTensor.downloader(); // typically a Float16Array or Float32Array

  // Copy float16 data into a fresh Float16Array; other dtypes pass through as-is.
  // (Float16Array is a recent JavaScript addition and may be missing in older runtimes.)
  let data = rawData;
  const dtype = ortTensor.type;

  if (dtype === 'float16') {
    data = Float16Array.from(rawData);
  }

  return new Tensor(dtype, data, ortTensor.dims);
}
function buildInputsForGenerate(full_inputs, past_key_values_cache, modelKey) {
  const input_ids_tensor = full_inputs.input_ids;

  // No cache for this model yet: run with the full prompt.
  if (!past_key_values_cache[modelKey]) {
    return full_inputs;
  }

  const seq_len = input_ids_tensor.dims[1];
  if (seq_len === 0) {
    throw new Error("input_ids is empty; can't slice last token.");
  }

  // Use the cached key dims to get the cached sequence length
  const past = past_key_values_cache[modelKey];
  const past_len = past['past_key_values.0.key'].dims[2];
  const new_len = seq_len - past_len;

  // Slice out just the last token: dim 0 range [0, 1), dim 1 range [seq_len - 1, seq_len)
  const input_ids = input_ids_tensor.slice([0, 1], [seq_len - 1, seq_len]);

  // All-ones mask over seq_len + 1 positions
  const attention_mask_length = seq_len + 1;
  const attention_mask = new Tensor(
    "int64",
    BigInt64Array.from([
      //...Array(past_len).fill(BigInt(0)),   // (earlier attempt: mask out past tokens)
      ...Array(attention_mask_length).fill(BigInt(1)),
    ]),
    [1, attention_mask_length]
  );

  // Positions for the new_len new tokens: past_len .. seq_len - 1
  const position_ids = new Tensor(
    "int64",
    BigInt64Array.from([...Array(new_len).keys()].map(i => BigInt(past_len + i))),
    [1, new_len]
  );

  return {
    input_ids,
    attention_mask,
    position_ids,
  };
}
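
Putting the pieces together, the per-turn flow is (using the names above; this is the flow that currently fails):

const full_inputs = tokenizer.apply_chat_template(messages, {
  add_generation_prompt: true,
  return_dict: true,
});
const inputs = buildInputsForGenerate(full_inputs, past_key_values_cache, modelKey);
// ...then convert the cached tensors to CPU (past_kv) and call model.generate() as shown above.
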
@PaUl1481980

BigInt64Array.from([...Array(new_len).keys()].map(i => BigInt(past_len + i)))
