
Commit 3f34789

Vargol authored and psychedelicious committed

fix import ordering, remove code I reverted that the resync added back

1 parent 4e237a2 · commit 3f34789

2 files changed: 4 additions & 17 deletions

invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py

Lines changed: 4 additions & 4 deletions
@@ -1,8 +1,9 @@
-from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from typing import Any
 
 import torch
 
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
+
 
 class CachedModelOnlyFullLoad:
     """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
@@ -78,8 +79,7 @@ def full_load_to_vram(self) -> int:
                 new_state_dict[k] = v.to(self._compute_device, copy=True)
             self._model.load_state_dict(new_state_dict, assign=True)
 
-
-        check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight")
+        check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
         if isinstance(check_for_gguf, GGMLTensor):
             old_value = torch.__future__.get_overwrite_module_params_on_conversion()
             torch.__future__.set_overwrite_module_params_on_conversion(True)
@@ -103,7 +103,7 @@ def full_unload_from_vram(self) -> int:
         if self._cpu_state_dict is not None:
            self._model.load_state_dict(self._cpu_state_dict, assign=True)
 
-        check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight")
+        check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
         if isinstance(check_for_gguf, GGMLTensor):
             old_value = torch.__future__.get_overwrite_module_params_on_conversion()
             torch.__future__.set_overwrite_module_params_on_conversion(True)
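For context, both GGUF branches above follow a save/set/restore pattern around PyTorch's torch.__future__ flag: read the current value with get_overwrite_module_params_on_conversion(), enable it, perform the conversion, then restore. The hunks are truncated right after the flag is enabled, so the sketch below assumes the conversion that follows is a module .to() call; move_gguf_model is a hypothetical helper (not InvokeAI code), and the try/finally is an extra safety measure the linear original does not use.

    import torch

    def move_gguf_model(model: torch.nn.Module, device: torch.device) -> None:
        # Remember the current flag so it can be restored afterwards.
        old_value = torch.__future__.get_overwrite_module_params_on_conversion()
        torch.__future__.set_overwrite_module_params_on_conversion(True)
        try:
            # With the flag enabled, nn.Module conversion ops replace parameter
            # objects outright instead of copying into .data in place, which
            # keeps wrapper tensor subclasses such as GGMLTensor intact.
            model.to(device)
        finally:
            # Restore the previous value even if the move raises.
            torch.__future__.set_overwrite_module_params_on_conversion(old_value)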

invokeai/backend/quantization/gguf/ggml_tensor.py

Lines changed: 0 additions & 13 deletions
@@ -119,19 +119,6 @@ def size(self, dim: int | None = None):
             return self.tensor_shape[dim]
         return self.tensor_shape
 
-    @overload
-    def to(self, *args, **kwargs) -> torch.Tensor: ...
-
-    def to(self, *args, **kwargs):
-        for func_arg in args:
-            if isinstance(func_arg, torch.dtype) and func_arg != self.quantized_data.dtype:
-                raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")
-        if "dtype" in kwargs.keys():
-            if kwargs["dtype"] != self.quantized_data.dtype:
-                raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")
-        self.quantized_data = self.quantized_data.to(*args, **kwargs)
-        return self
-
     @property
     def shape(self) -> torch.Size:  # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason.
         """The shape of the tensor after dequantization. I.e. the shape that will be used in any math ops."""
