@@ -1,8 +1,9 @@
-from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from typing import Any
 
 import torch
 
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
+
 
 class CachedModelOnlyFullLoad:
     """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
@@ -78,8 +79,7 @@ def full_load_to_vram(self) -> int:
             new_state_dict[k] = v.to(self._compute_device, copy=True)
         self._model.load_state_dict(new_state_dict, assign=True)
 
-
-        check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight")
+        check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
         if isinstance(check_for_gguf, GGMLTensor):
             old_value = torch.__future__.get_overwrite_module_params_on_conversion()
             torch.__future__.set_overwrite_module_params_on_conversion(True)
@@ -103,7 +103,7 @@ def full_unload_from_vram(self) -> int:
         if self._cpu_state_dict is not None:
             self._model.load_state_dict(self._cpu_state_dict, assign=True)
 
-        check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight")
+        check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
         if isinstance(check_for_gguf, GGMLTensor):
             old_value = torch.__future__.get_overwrite_module_params_on_conversion()
             torch.__future__.set_overwrite_module_params_on_conversion(True)
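For context on the `torch.__future__` calls: `set_overwrite_module_params_on_conversion(True)` makes module conversions such as `.to()` assign new tensors as the module's parameters instead of writing through `param.data` in place, which matters for tensor subclasses like GGMLTensor. The restore step falls outside these hunks, so the following is only a minimal sketch of the save/toggle/restore pattern; it assumes the flag guards a subsequent `.to()` move, and the helper name is hypothetical, not part of the InvokeAI code.

```python
import torch


def _move_module_overwriting_params(module: torch.nn.Module, device: torch.device) -> None:
    # Hypothetical helper illustrating the save/toggle/restore pattern from
    # the diff above.
    old_value = torch.__future__.get_overwrite_module_params_on_conversion()
    torch.__future__.set_overwrite_module_params_on_conversion(True)
    try:
        # With the flag enabled, .to() swaps in the converted tensors as the
        # module's parameters rather than mutating `param.data` in place, a
        # path that can misbehave for tensor subclasses such as GGMLTensor.
        module.to(device)
    finally:
        # Restore the process-wide flag so unrelated conversions keep
        # PyTorch's default behavior.
        torch.__future__.set_overwrite_module_params_on_conversion(old_value)
```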