Convert tile to torch.cuda.FloatTensor if it's not already of that type
Browse files
opensora/models/ae/videobase/causal_vae/modeling_causalvae.py
CHANGED
|
@@ -610,6 +610,11 @@ class CausalVAEModel(VideoBaseAE_PL):
|
|
| 610 |
i : i + self.tile_latent_min_size,
|
| 611 |
j : j + self.tile_latent_min_size,
|
| 612 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
tile = self.post_quant_conv(tile)
|
| 614 |
decoded = self.decoder(tile)
|
| 615 |
row.append(decoded)
|
|
|
|
| 610 |
i : i + self.tile_latent_min_size,
|
| 611 |
j : j + self.tile_latent_min_size,
|
| 612 |
]
|
| 613 |
+
|
| 614 |
+
# Convert tile to torch.cuda.FloatTensor if it's not already of that type
|
| 615 |
+
if tile.dtype != torch.float32:
|
| 616 |
+
tile = tile.float()
|
| 617 |
+
|
| 618 |
tile = self.post_quant_conv(tile)
|
| 619 |
decoded = self.decoder(tile)
|
| 620 |
row.append(decoded)
|