Spaces:
Runtime error
Runtime error
use new fashion way to invoke autocast
Browse files
model.py
CHANGED
|
@@ -215,9 +215,9 @@ class SALMONN(nn.Module):
|
|
| 215 |
embeds = torch.cat([bos_embeds, prompt_left_embeds, speech_embeds, prompt_right_embeds], dim=1)
|
| 216 |
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
|
| 217 |
|
| 218 |
-
from torch.
|
| 219 |
|
| 220 |
-
with autocast(
|
| 221 |
output = self.llama_model.generate(
|
| 222 |
inputs_embeds=embeds,
|
| 223 |
max_length=max_length,
|
|
|
|
| 215 |
embeds = torch.cat([bos_embeds, prompt_left_embeds, speech_embeds, prompt_right_embeds], dim=1)
|
| 216 |
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
|
| 217 |
|
| 218 |
+
from torch.amp import autocast
|
| 219 |
|
| 220 |
+
with autocast("cuda", dtype=torch.float16):
|
| 221 |
output = self.llama_model.generate(
|
| 222 |
inputs_embeds=embeds,
|
| 223 |
max_length=max_length,
|