File tree 4 files changed +10
-7
lines changed
4 files changed +10
-7
lines changed Original file line number Diff line number Diff line change 5
5
import gc
6
6
from pathlib import Path
7
7
8
- from safetensors .torch import save_file
9
8
import torch
10
9
import torch .nn as nn
11
10
import torch .nn .functional as F
@@ -92,6 +91,8 @@ def __init__(
92
91
93
92
def from_pretrained (self , file_path : str ):
94
93
if self .is_vllm and platform .system ().lower () == "linux" :
94
+ from safetensors .torch import save_file
95
+
95
96
from .velocity .llm import LLM
96
97
from .velocity .post_model import PostModel
97
98
@@ -104,7 +105,7 @@ def from_pretrained(self, file_path: str):
104
105
gpt .gpt .save_pretrained (vllm_folder / "gpt" )
105
106
post_model = (
106
107
PostModel (
107
- int (self .gpt .config .hidden_size ),
108
+ int (gpt .gpt .config .hidden_size ),
108
109
self .num_audio_tokens ,
109
110
self .num_text_tokens ,
110
111
)
Original file line number Diff line number Diff line change @@ -101,7 +101,12 @@ conda activate chattts
101
101
pip install -r requirements.txt
102
102
```
103
103
104
- #### Optional: Install TransformerEngine if using NVIDIA GPU (Linux only)
104
+ #### Optional: Install vLLM (Linux only)
105
+ ``` bash
106
+ pip install safetensors vllm==0.2.7 torchaudio
107
+ ```
108
+
109
+ #### Unrecommended Optional: Install TransformerEngine if using NVIDIA GPU (Linux only)
105
110
> [ !Note]
106
111
> The installation process is very slow.
107
112
@@ -113,7 +118,7 @@ pip install -r requirements.txt
113
118
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
114
119
```
115
120
116
- #### Optional: Install FlashAttention-2 (mainly NVIDIA GPU)
121
+ #### Unrecommended Optional: Install FlashAttention-2 (mainly NVIDIA GPU)
117
122
> [ !Note]
118
123
> See supported devices at the [ Hugging Face Doc] ( https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 ) .
119
124
Original file line number Diff line number Diff line change @@ -14,5 +14,3 @@ WeTextProcessing; sys_platform == 'linux'
14
14
nemo_text_processing ; sys_platform == 'linux'
15
15
av
16
16
pydub
17
- safetensors
18
- vllm >= 0.2.7 ; sys_platform == 'linux'
Original file line number Diff line number Diff line change 28
28
"transformers>=4.41.1" ,
29
29
"vector_quantize_pytorch" ,
30
30
"vocos" ,
31
- "safetensors" ,
32
31
],
33
32
platforms = "any" ,
34
33
classifiers = [
You can’t perform that action at this time.
0 commit comments