TPU v6e-1 + PyTorch 2.6 [TPU] does not work #8960
Comments
Given that this issue is only present on … A potential workaround is to use software version … instead of tpu-ubuntu2204-base.
@qihqi This is not related to the machine. I relaunched a TPU v6e-1 in another region, this time US-Central-B, with the same tpu-ubuntu2204-base image. It appears this image has completely broken hardware/kernel support for TPU v6e-1. This is the only image I want to use for PyTorch since I am not going to be using … This time I also tried torch 2.8 nightly with torch_xla[tpu] 2.8 nightly.
(vm311) qubitium@t1v-n-fed9844f-w-0:~$ PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"
/home/qubitium/vm311/lib/python3.11/site-packages/torch_xla/core/xla_model.py:90: UserWarning: `devkind` argument is deprecated and will be removed in a future release.
warnings.warn("`devkind` argument is deprecated and will be removed in a "
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/home/qubitium/vm311/lib/python3.11/site-packages/torch_xla/core/xla_model.py", line 93, in get_xla_supported_devices
xla_devices = _DEVICES.value
^^^^^^^^^^^^^^
File "/home/qubitium/vm311/lib/python3.11/site-packages/torch_xla/utils/utils.py", line 29, in value
self._value = self._gen_fn()
^^^^^^^^^^^^^^
File "/home/qubitium/vm311/lib/python3.11/site-packages/torch_xla/core/xla_model.py", line 27, in <lambda>
_DEVICES = xu.LazyProperty(lambda: torch_xla._XLAC._xla_get_devices())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Bad StatusOr access: INTERNAL: Failed to get global TPU topology.
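The one-liner above can be wrapped in a small script so the same check is repeatable across VMs and zones. This is a minimal sketch, not an official diagnostic: it runs the same probe in a fresh interpreter with `PJRT_DEVICE=TPU` and labels the outcome by matching substrings from the traceback. The helper names `classify_probe` and `run_probe` and the result labels are hypothetical.

```python
import os
import subprocess
import sys

# Same probe as the shell one-liner above.
PROBE = (
    "import torch_xla.core.xla_model as xm; "
    "print(xm.get_xla_supported_devices('TPU'))"
)

# Substring of the RuntimeError message from the traceback above.
TOPOLOGY_ERROR = "Failed to get global TPU topology"


def classify_probe(returncode: int, stdout: str, stderr: str) -> str:
    """Map the probe's exit status and output to a coarse, hypothetical label."""
    if returncode == 0:
        return "ok: " + stdout.strip()
    if "No module named" in stderr:
        return "torch_xla not importable"
    if TOPOLOGY_ERROR in stderr:
        return "runtime cannot see the TPU topology"
    return "other failure"


def run_probe(timeout: float = 300.0) -> str:
    """Run the probe in a fresh interpreter with PJRT_DEVICE=TPU and classify it."""
    env = dict(os.environ, PJRT_DEVICE="TPU")
    proc = subprocess.run(
        [sys.executable, "-c", PROBE],
        env=env,
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return classify_probe(proc.returncode, proc.stdout, proc.stderr)
```

On the affected VM, `print(run_probe())` would be expected to report `runtime cannot see the TPU topology` for the failure shown above, since the traceback is written to stderr. Running it on a working setup gives a baseline for comparison.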
@RaviVijay Please ping your Google Cloud contacts. This is a pretty bad regression for Cloud TPU if the latest (and only Ubuntu) image can't even see a TPU v6e using xla/torch. I suspect this has more to do with the gcloud image/kernel as it relates to TPU v6e than with the actual code in xla/torch.
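To back the image/kernel suspicion with data, a report could include the kernel release, OS image, Python version and installed package versions from each VM. A minimal standard-library sketch; the distribution names `libtpu` and `libtpu-nightly` are assumptions about how the TPU runtime wheel is published, and any name that cannot be found is simply reported as `not installed`.

```python
import platform
from importlib import metadata

# Distributions to report. "libtpu" and "libtpu-nightly" are assumed names for
# the TPU runtime wheel; whichever is absent is reported as not installed.
PACKAGES = ("torch", "torch_xla", "libtpu", "libtpu-nightly")


def installed_version(dist: str) -> str:
    """Return the installed version of a distribution, or 'not installed'."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return "not installed"


def os_pretty_name() -> str:
    """Read PRETTY_NAME from /etc/os-release (Python 3.10+), else 'unknown'."""
    try:
        return platform.freedesktop_os_release().get("PRETTY_NAME", "unknown")
    except (AttributeError, OSError):
        # AttributeError: Python < 3.10; OSError: no os-release file found.
        return "unknown"


def environment_report() -> dict:
    """Collect kernel, OS image, Python and package versions for a bug report."""
    report = {
        "kernel": platform.release(),
        "os": os_pretty_name(),
        "python": platform.python_version(),
    }
    for dist in PACKAGES:
        report[dist] = installed_version(dist)
    return report


for key, value in environment_report().items():
    print(f"{key}: {value}")
```

Running the same script on a failing VM and a working one gives a side-by-side comparison of what actually differs between them.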
Ubuntu 22.04 image
Python 3.11
Google Cloud TPU 6e-1
Using the PyTorch TPU guide from: https://cloud.google.com/tpu/docs/run-calculation-pytorch
Everything installs correctly, but execution can't find the TPU. Is TPU v6e supported in PyTorch 2.6.0 from Google?
# install guide from Google Docs:
sudo apt-get update
sudo apt-get install libopenblas-dev -y
pip install numpy
pip install torch torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html
I am running an active TPU v6e-1 instance, not a CPU compute instance.

The same setup works on v5e-4.
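In the `pip install` line above, the `~=2.6.0` specifier applies only to `torch_xla[tpu]`; `torch` itself is unpinned. It may be worth confirming which versions were actually resolved on each VM. A minimal sketch, assuming the installed distributions are named `torch` and `torch_xla`. `compatible_release` is a hypothetical, simplified reading of PEP 440's `~=` operator (`~=2.6.0` means `>=2.6.0, ==2.6.*`) that ignores local labels such as `+cpu` and pre-release/dev suffixes, so it is not a substitute for the `packaging` library.

```python
import re
from importlib import metadata


def release_tuple(version: str) -> tuple:
    """Leading numeric release segment, e.g. '2.6.0+cpu' -> (2, 6, 0).

    Simplification: local labels (+...) and pre/dev suffixes are ignored.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def compatible_release(version: str, spec: str) -> bool:
    """Simplified PEP 440 '~=' check: '~=2.6.0' means >=2.6.0 and ==2.6.*."""
    have = release_tuple(version)
    want = release_tuple(spec)
    if len(want) < 2:
        raise ValueError("'~=' needs at least two release components")
    prefix = want[:-1]
    # Pad short versions with zeros so '2.6' compares like '2.6.0'.
    padded = have + (0,) * (len(want) - len(have))
    return padded[: len(prefix)] == prefix and padded >= want


def check_install(spec: str = "2.6.0") -> dict:
    """Report each distribution's version and whether it satisfies '~=spec'."""
    result = {}
    for dist in ("torch", "torch_xla"):
        try:
            version = metadata.version(dist)
        except metadata.PackageNotFoundError:
            result[dist] = "not installed"
            continue
        status = "ok" if compatible_release(version, spec) else "outside ~=" + spec
        result[dist] = f"{version} ({status})"
    return result


print(check_install())
```

If `torch` and `torch_xla` land on different minor versions, that would be worth noting in the report, although the output alone does not show that it causes the topology error.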