[RFC] PyTorch Custom Operators & Multi-Backend Support #1545

Closed
matthewdouglas opened this issue Feb 27, 2025 · 3 comments
Labels
Cross Platform High Priority (first issues that will be worked on) RFC request for comments on proposed library improvements

Comments

@matthewdouglas
Member

Purpose

We intend to integrate PyTorch Custom Operators as the primary mechanism for dispatching to device-specific operator implementations. An initial scaffolding of this is presented in PR #1544. This RFC will serve as a guideline to collect community feedback and refine our development plans moving forward.

Why?

  • Registering operators with torch.library allows us to take advantage of the existing device dispatch mechanisms in PyTorch (see the sketch after this list).
  • We can treat calls into our CUDA kernels, or other low-level backend implementations, as opaque for improved torch.compile support.
  • We can provide naive implementations of operators in pure PyTorch as a fallback option. This may additionally serve as a secondary CPU baseline, as per [RFC] Cross-Platform Refactor: CPU-only implementation #1021.
  • This helps simplify development of additional backends, while taking an idiomatic, modern PyTorch approach.
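As a minimal sketch of this pattern (the operator name and math here are hypothetical, used only for illustration; the real operators live in #1544): a schema is defined once, each backend registers its own kernel, and callers always go through the dispatcher.

import torch

# Hypothetical operator used only to illustrate the registration pattern.
torch.library.define("bitsandbytes::scale_rows", "(Tensor A, Tensor scale) -> Tensor")

@torch.library.impl("bitsandbytes::scale_rows", "CPU")
def _scale_rows_cpu(A, scale):
    # Naive pure-PyTorch fallback; a CUDA registration would instead call
    # into the compiled kernel through the C extension.
    return A * scale.unsqueeze(-1)

# The same call site works for every backend that registers a kernel:
out = torch.ops.bitsandbytes.scale_rows(torch.randn(4, 8), torch.ones(4))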

What about the multi-backend-refactor branch?

We plan to stop further development on that branch once #1544 is merged. After that point, the expectation is that backends will be implemented using the new custom operator registration mechanism. We expect to be able to reuse much of the existing implementations in the refactoring process.

Our goal is to aggressively mainline our in-tree backends, while additionally enabling out-of-tree backends. We will expand on this topic in the near future.

Supersedure

This RFC is intended to supersede the topics covered in related RFCs that remained open as of this writing:

Related Issues

Related issues and discussions include:

Additionally, this relates to the following issues and discussions which have been closed:

Relevant contributors

The following contributors may have particular interest and feedback on this topic:

@Titus-von-Koeller
@christoph-koehncke
@jiqing-feng
@pnunna93
@akx
@rickardp
@ji-huazhong
@SlightwindSec

@matthewdouglas matthewdouglas added Cross Platform RFC request for comments on proposed library improvements labels Feb 27, 2025
@matthewdouglas matthewdouglas pinned this issue Feb 27, 2025
@matthewdouglas matthewdouglas added the High Priority (first issues that will be worked on) label Feb 28, 2025
@matthewdouglas matthewdouglas added this to the v0.46.0 milestone Mar 3, 2025
@jiqing-feng
Contributor

Hi @matthewdouglas. That sounds pretty cool. One thing I want to make sure of: we don't need to build any C++ code for CPU/XPU, so we won't need CMake in the CPU/XPU backend. I assume that won't break your design, right?

@matthewdouglas
Member Author

@jiqing-feng That's perfectly fine. We don't expect that every backend will need to build a library, or that any library it builds will necessarily need to expose the same API.

For the CPU/XPU case we'll guard around what the cextension module does. Right now there is still a CPU version of the C library that exposes a couple of functions: cquantize_blockwise_cpu_fp32 and cdequantize_blockwise_cpu_fp32. Those correspond to the operator definitions in this RFC:

torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)")

torch.library.define(
    "bitsandbytes::dequantize_blockwise",
    "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype) -> Tensor",
)
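To keep these operators opaque to torch.compile, each schema also gets a fake (meta) implementation that only propagates shapes and dtypes. A rough sketch for the dequantize schema above, assuming the output has the same shape as A (the actual registration in #1544 may differ):

import torch

@torch.library.register_fake("bitsandbytes::dequantize_blockwise")
def _(A, absmax, code, blocksize, dtype):
    # Shape/dtype propagation only: no real computation happens here, which is
    # what lets torch.compile treat the operator call as opaque.
    return torch.empty_like(A, dtype=dtype)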

I have not benchmarked it, but I expect we'll be able to implement it in pure PyTorch, similar to what was already done here:

import torch
import torch.nn.functional as F

def dequant_8bit(A, offset, quant_state):
    # Dequantize 8-bit blockwise-quantized data: look up each byte in the
    # code book, then rescale per 256-element block and add the offset.
    assert A.dtype == torch.uint8
    absmax = quant_state.code[A.reshape(-1).int()]
    blocks = absmax.shape[-1] // 256
    res = absmax.shape[-1] % 256
    if res != 0:
        # Pad to a multiple of the block size before the per-block rescale.
        absmax = F.pad(absmax, (0, 256 - res), mode="constant", value=0)
    absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(-1)
    absmax = absmax[: blocks * 256 + res]
    absmax = absmax.reshape(A.shape)
    absmax += offset
    return absmax
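For illustration, a pure-PyTorch kernel in the same spirit could be registered against the bitsandbytes::dequantize_blockwise schema above. This is only a sketch, not the implementation from #1544, and it assumes A.numel() is an exact multiple of blocksize:

import torch

@torch.library.impl("bitsandbytes::dequantize_blockwise", "CPU")
def _dequantize_blockwise_cpu(A, absmax, code, blocksize, dtype):
    # Look up each uint8 value in the 256-entry code book, then scale each
    # block by its absmax. Assumes A.numel() divides evenly by blocksize.
    out = code[A.reshape(-1).int()]
    out = out.view(-1, blocksize) * absmax.view(-1, 1)
    return out.to(dtype).reshape(A.shape)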

@matthewdouglas
Member Author

I'm going to close this out as this is implemented and merged with #1544.

General note: the custom operators in bitsandbytes are still subject to revision, and optimizers have not been implemented yet, but they are stable enough for internal use and initial backend integrations.
