import os
from functools import partial
from dataclasses import dataclass

import torch
import numpy as np
from einops import rearrange
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import BaseOutput

from ..modules.ae_modules import Encoder, Decoder
from ..modules.ae_dualref_modules import VideoDecoder
from ..utils import instantiate_from_config

@dataclass
class DecoderOutput(BaseOutput):
    """
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Decoded output sample of the model. Output of the last layer of the model.
    """

    sample: torch.FloatTensor

@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent_dist: "DiagonalGaussianDistribution"
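
## Note: diffusers' BaseOutput subclasses support both attribute and index
## access, e.g. `out.latent_dist` and `out[0]` return the same field, so the
## encode/decode results below can be unpacked either way.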

class AutoencoderKL(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        ddconfig,
        embed_dim,
        image_key="image",
        input_dim=4,
        use_checkpoint=False,
    ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        self.input_dim = input_dim
        self.use_checkpoint = use_checkpoint

    def encode(self, x, return_hidden_states=False, **kwargs):
        if return_hidden_states:
            h, hidden = self.encoder(x, return_hidden_states)
            moments = self.quant_conv(h)
            posterior = DiagonalGaussianDistribution(moments)
            return AutoencoderKLOutput(latent_dist=posterior), hidden
        else:
            h = self.encoder(x)
            moments = self.quant_conv(h)
            posterior = DiagonalGaussianDistribution(moments)
            return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z, **kwargs):
        if len(kwargs) == 0:  ## use the original decoder in AutoencoderKL
            z = self.post_quant_conv(z)
        dec = self.decoder(z, **kwargs)  ## change for SVD decoder by adding **kwargs
        return dec

    def forward(self, input, sample_posterior=True, **additional_decode_kwargs):
        forward_temp = partial(self._forward, sample_posterior=sample_posterior, **additional_decode_kwargs)
        if self.use_checkpoint:
            ## torch.utils.checkpoint.checkpoint takes the wrapped function's own
            ## arguments; the earlier (inputs, params, flag) call matched the
            ## signature of a custom checkpoint helper, not the torch one.
            return checkpoint(forward_temp, input, use_reentrant=False)
        return forward_temp(input)

    def _forward(self, input, sample_posterior=True, **additional_decode_kwargs):
        posterior = self.encode(input).latent_dist
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z, **additional_decode_kwargs)
        ## print(input.shape, dec.shape) torch.Size([16, 3, 256, 256]) torch.Size([16, 3, 256, 256])
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if x.dim() == 5 and self.input_dim == 4:
            b, c, t, h, w = x.shape
            self.b = b
            self.t = t
            x = rearrange(x, 'b c t h w -> (b t) c h w')
        return x

    def get_last_layer(self):
        return self.decoder.conv_out.weight
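
## Example usage (a minimal sketch; the `ddconfig` values below are illustrative
## assumptions following the common SD-style VAE layout, not a config shipped
## with this repository):
##
##   ddconfig = dict(
##       double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
##       ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
##       attn_resolutions=(), dropout=0.0,
##   )
##   vae = AutoencoderKL(ddconfig, embed_dim=4)
##   x = torch.randn(1, 3, 256, 256)
##   posterior = vae.encode(x).latent_dist        # DiagonalGaussianDistribution
##   recon = vae.decode(posterior.sample())       # (1, 3, 256, 256)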

class AutoencoderKL_Dualref(AutoencoderKL):
    def __init__(
        self,
        ddconfig,
        embed_dim,
        image_key="image",
        input_dim=4,
        use_checkpoint=False,
    ):
        super().__init__(ddconfig, embed_dim, image_key, input_dim, use_checkpoint)
        ## replace the image decoder with the dual-reference video decoder
        self.decoder = VideoDecoder(**ddconfig)

    def _forward(self, input, batch_size, sample_posterior=True, **additional_decode_kwargs):
        ## `batch_size` must be supplied as a keyword argument to `forward`,
        ## which binds it here via functools.partial.
        posterior, hidden_states = self.encode(input, return_hidden_states=True)
        hidden_states_first_last = []
        ### use only the first and last hidden states
        for hid in hidden_states:
            hid = rearrange(hid, '(b t) c h w -> b c t h w', b=batch_size)
            hid_new = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2)
            hidden_states_first_last.append(hid_new)
        if sample_posterior:
            z = posterior.latent_dist.sample()
        else:
            z = posterior.latent_dist.mode()
        dec = self.decode(z, ref_context=hidden_states_first_last, **additional_decode_kwargs)
        ## print(input.shape, dec.shape) torch.Size([16, 3, 256, 256]) torch.Size([16, 3, 256, 256])
        return dec, posterior
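
## Example usage (a minimal sketch, reusing the illustrative `ddconfig` above;
## `batch_size` is routed to `_forward` as a keyword argument, and any further
## keywords would be passed through to `VideoDecoder`, whose exact signature
## is not shown in this file):
##
##   vae = AutoencoderKL_Dualref(ddconfig, embed_dim=4)
##   b, t = 2, 16
##   frames = torch.randn(b * t, 3, 256, 256)     # frames flattened as (b t) c h w
##   dec, posterior = vae(frames, batch_size=b)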