import torch
from diffusers import DPMSolverMultistepScheduler, UniPCMultistepScheduler
from typing import List


def AdamBmixer(order, ets, b=1):
    # Adams-Bashforth-style linear combination of the most recent gradients
    # in `ets`, generalized by the momentum parameter `b` (b = 1 recovers the
    # classical Adams-Bashforth coefficients).
    cur_order = min(order, len(ets))
    if cur_order == 1:
        prime = b * ets[-1]
    elif cur_order == 2:
        prime = ((2 + b) * ets[-1] - (2 - b) * ets[-2]) / 2
    elif cur_order == 3:
        prime = ((18 + 5 * b) * ets[-1] - (24 - 8 * b) * ets[-2] + (6 - 1 * b) * ets[-3]) / 12
    elif cur_order == 4:
        prime = ((46 + 9 * b) * ets[-1] - (78 - 19 * b) * ets[-2] + (42 - 5 * b) * ets[-3] - (10 - b) * ets[-4]) / 24
    elif cur_order == 5:
        prime = ((1650 + 251 * b) * ets[-1] - (3420 - 646 * b) * ets[-2]
                 + (2880 - 264 * b) * ets[-3] - (1380 - 106 * b) * ets[-4]
                 + (270 - 19 * b) * ets[-5]) / 720
    else:
        raise NotImplementedError
    prime = prime / b
    return prime
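

# Illustrative sanity check (not part of the original file): at b = 1 the
# mixer should reduce to the classical Adams-Bashforth update, e.g. order 2
# gives (3 * ets[-1] - ets[-2]) / 2.
def _check_adamb_reduces_to_ab2():
    e0, e1 = torch.tensor(1.0), torch.tensor(2.0)
    prime = AdamBmixer(2, [e0, e1], b=1)
    assert torch.allclose(prime, (3 * e1 - e0) / 2)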


class PLMSWithHBScheduler:
    """
    PLMS with Polyak's Heavy Ball momentum (HB) for diffusion ODEs.
    We implement it as a wrapper for schedulers in diffusers
    (https://github.com/huggingface/diffusers).
    When order is an integer, this method is equivalent to PLMS without momentum.
    """

    def __init__(self, scheduler, order):
        self.scheduler = scheduler
        self.ets = []
        self.update_order(order)
        self.mixer = AdamBmixer

    def update_order(self, order):
        # A fractional order such as 2.5 means: use the next-higher-order
        # (here 3rd-order) mixer with momentum beta equal to the fractional part.
        self.order = int(order) + 1 if order % 1 > 0 else int(order)
        self.beta = order % 1 if order % 1 > 0 else 1
        self.vel = None

    def clear(self):
        self.ets = []
        self.vel = None

    def update_ets(self, val):
        # Keep a sliding window of the most recent `order` gradients.
        self.ets.append(val)
        if len(self.ets) > self.order:
            self.ets.pop(0)

    def _step_with_momentum(self, grads):
        self.update_ets(grads)
        prime = self.mixer(self.order, self.ets, 1.0)
        self.vel = (1 - self.beta) * self.vel + self.beta * prime
        return self.vel
    def step(
        self,
        grads: torch.FloatTensor,
        timestep: int,
        latents: torch.FloatTensor,
        output_mode: str = "scale",
    ):
        if self.vel is None:
            self.vel = grads
        if hasattr(self.scheduler, 'sigmas'):
            # Schedulers that expose sigmas: integrate the ODE in sigma space.
            step_index = (self.scheduler.timesteps == timestep).nonzero().item()
            sigma = self.scheduler.sigmas[step_index]
            sigma_next = self.scheduler.sigmas[step_index + 1]
            del_g = sigma_next - sigma
            update_val = self._step_with_momentum(grads)
            return latents + del_g * update_val
        elif isinstance(self.scheduler, DPMSolverMultistepScheduler):
            # Integrate in the scaled space x / sqrt(alpha_bar), using the
            # noise-to-signal ratio sqrt(1 - alpha_bar) / sqrt(alpha_bar)
            # as the integration variable.
            step_index = (self.scheduler.timesteps == timestep).nonzero().item()
            current_timestep = self.scheduler.timesteps[step_index]
            prev_timestep = 0 if step_index == len(self.scheduler.timesteps) - 1 else self.scheduler.timesteps[step_index + 1]
            alpha_prod_t = self.scheduler.alphas_cumprod[current_timestep]
            alpha_bar_prev = self.scheduler.alphas_cumprod[prev_timestep]
            s0 = torch.sqrt(alpha_prod_t)
            s_1 = torch.sqrt(alpha_bar_prev)
            g0 = torch.sqrt(1 - alpha_prod_t) / s0
            g_1 = torch.sqrt(1 - alpha_bar_prev) / s_1
            del_g = g_1 - g0
            update_val = self._step_with_momentum(grads)
            if output_mode in ["scale"]:
                return (latents / s0 + del_g * update_val) * s_1
            elif output_mode in ["back"]:
                return latents + del_g * update_val * s_1
            elif output_mode in ["front"]:
                return latents + del_g * update_val * s0
            else:
                return latents + del_g * update_val
        else:
            raise NotImplementedError
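

# Illustrative usage sketch (an assumption, not part of the original file):
# drive the wrapper with dummy noise predictions in place of a real UNet.
# A fractional order of 2.5 means 3rd-order PLMS with momentum beta = 0.5.
def _demo_plms_hb():
    scheduler = DPMSolverMultistepScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(10)
    wrapper = PLMSWithHBScheduler(scheduler, order=2.5)
    latents = torch.randn(1, 4, 8, 8)
    for t in scheduler.timesteps:
        noise_pred = torch.randn_like(latents)  # stands in for unet(latents, t)
        latents = wrapper.step(noise_pred, t, latents)
    return latents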


class GHVBScheduler(PLMSWithHBScheduler):
    """
    Generalizing Polyak's Heavy Ball momentum (GHVB) for diffusion ODEs.
    We implement it as a wrapper for schedulers in diffusers
    (https://github.com/huggingface/diffusers).
    When order is an integer, this method is equivalent to PLMS without momentum.
    """

    def _step_with_momentum(self, grads):
        # GHVB applies momentum to the raw gradient first, then mixes the
        # damped velocities with the beta-dependent coefficients.
        self.vel = (1 - self.beta) * self.vel + self.beta * grads
        self.update_ets(self.vel)
        prime = self.mixer(self.order, self.ets, self.beta)
        return prime
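

# Illustrative sanity check (an assumption, not part of the original file):
# a fractional order of 1.5 selects the 2nd-order mixer with beta = 0.5; an
# integer order yields beta = 1, so the (1 - beta) damping term vanishes and
# GHVB reduces to plain PLMS.
def _check_ghvb_order_semantics():
    ghvb = GHVBScheduler(DPMSolverMultistepScheduler(), order=1.5)
    assert ghvb.order == 2 and abs(ghvb.beta - 0.5) < 1e-12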


class PLMSWithNTScheduler(PLMSWithHBScheduler):
    """
    PLMS with Nesterov momentum (NT) for diffusion ODEs.
    We implement it as a wrapper for schedulers in diffusers
    (https://github.com/huggingface/diffusers).
    When order is an integer, this method is equivalent to PLMS without momentum.
    """

    def _step_with_momentum(self, grads):
        self.update_ets(grads)
        prime = self.mixer(self.order, self.ets, 1.0)                # update v^{(2)}
        self.vel = (1 - self.beta) * self.vel + self.beta * prime    # update v^{(1)}
        update_val = (1 - self.beta) * self.vel + self.beta * prime  # update x
        return update_val


class MomentumDPMSolverMultistepScheduler(DPMSolverMultistepScheduler):
    """
    DPM-Solver++2M with HB momentum.
    Currently supports only algorithm_type = "dpmsolver++" and solver_type = "midpoint".
    When beta = 1.0, this method is equivalent to DPM-Solver++2M without momentum.
    """

    def initialize_momentum(self, beta):
        self.vel = None
        self.beta = beta

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.FloatTensor],
        timestep_list: List[int],
        prev_timestep: int,
        sample: torch.FloatTensor,
    ) -> torch.FloatTensor:
        t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
        m0, m1 = model_output_list[-1], model_output_list[-2]
        lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)
        if self.config.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                # Apply Heavy Ball momentum to the combined difference term.
                diff = D0 + 0.5 * D1
                if self.vel is None:
                    self.vel = diff
                else:
                    self.vel = (1 - self.beta) * self.vel + self.beta * diff
                x_t = (
                    (sigma_t / sigma_s0) * sample
                    - (alpha_t * (torch.exp(-h) - 1.0)) * self.vel
                )
            elif self.config.solver_type == "heun":
                raise NotImplementedError(
                    f"{self.config.algorithm_type} with {self.config.solver_type} is currently not supported."
                )
        elif self.config.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            raise NotImplementedError(
                f"{self.config.algorithm_type} with {self.config.solver_type} is currently not supported."
            )
        return x_t
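

# Illustrative usage sketch (an assumption, not part of the original file):
# swap the momentum scheduler into an existing diffusers pipeline. `pipe` is
# a placeholder for any pipeline whose scheduler config is compatible with
# DPM-Solver++2M.
def _use_momentum_dpm(pipe, beta=0.8):
    scheduler = MomentumDPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config, algorithm_type="dpmsolver++", solver_type="midpoint"
    )
    scheduler.initialize_momentum(beta=beta)  # beta = 1.0 recovers vanilla DPM-Solver++2M
    pipe.scheduler = scheduler
    return pipe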


class MomentumUniPCMultistepScheduler(UniPCMultistepScheduler):
    """
    UniPC with HB momentum.
    Currently supports only self.predict_x0 = True.
    When beta = 1.0, this method is equivalent to UniPC without momentum.
    """

    def initialize_momentum(self, beta):
        self.vel_p = None  # predictor velocity
        self.vel_c = None  # corrector velocity
        self.beta = beta

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.FloatTensor,
        prev_timestep: int,
        sample: torch.FloatTensor,
        order: int,
    ) -> torch.FloatTensor:
        timestep_list = self.timestep_list
        model_output_list = self.model_outputs
        s0, t = timestep_list[-1], prev_timestep
        m0 = model_output_list[-1]
        x = sample
        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t
        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
        h = lambda_t - lambda_s0
        device = sample.device
        rks = []
        D1s = []
        for i in range(1, order):
            si = timestep_list[-(i + 1)]
            mi = model_output_list[-(i + 1)]
            lambda_si = self.lambda_t[si]
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)
        rks.append(1.0)
        rks = torch.tensor(rks, device=device)
        R = []
        b = []
        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1
        factorial_i = 1
        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i
        R = torch.stack(R)
        b = torch.tensor(b, device=device)
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None
        if self.predict_x0:
            if D1s is not None:
                pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
            else:
                pred_res = 0
            # Apply Heavy Ball momentum to the predictor update direction.
            val = (h_phi_1 * m0 + B_h * pred_res) / sigma_t / h_phi_1
            if self.vel_p is None:
                self.vel_p = val
            else:
                self.vel_p = (1 - self.beta) * self.vel_p + self.beta * val
            x_t = sigma_t * (x / sigma_s0 - alpha_t * self.vel_p * h_phi_1)
        else:
            raise NotImplementedError
        x_t = x_t.to(x.dtype)
        return x_t
    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.FloatTensor,
        this_timestep: int,
        last_sample: torch.FloatTensor,
        this_sample: torch.FloatTensor,
        order: int,
    ) -> torch.FloatTensor:
        timestep_list = self.timestep_list
        model_output_list = self.model_outputs
        s0, t = timestep_list[-1], this_timestep
        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output
        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
        h = lambda_t - lambda_s0
        device = this_sample.device
        rks = []
        D1s = []
        for i in range(1, order):
            si = timestep_list[-(i + 1)]
            mi = model_output_list[-(i + 1)]
            lambda_si = self.lambda_t[si]
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)
        rks.append(1.0)
        rks = torch.tensor(rks, device=device)
        R = []
        b = []
        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1
        factorial_i = 1
        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i
        R = torch.stack(R)
        b = torch.tensor(b, device=device)
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None
        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b)
        if self.predict_x0:
            if D1s is not None:
                corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
            # Apply Heavy Ball momentum to the corrector update direction.
            val = (h_phi_1 * m0 + B_h * (corr_res + rhos_c[-1] * D1_t)) / sigma_t / h_phi_1
            if self.vel_c is None:
                self.vel_c = val
            else:
                self.vel_c = (1 - self.beta) * self.vel_c + self.beta * val
            x_t = sigma_t * (x / sigma_s0 - alpha_t * self.vel_c * h_phi_1)
        else:
            raise NotImplementedError
        x_t = x_t.to(x.dtype)
        return x_t
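

# Illustrative usage sketch (an assumption, not part of the original file):
# UniPC with HB momentum as a drop-in scheduler. Requires predict_x0 = True,
# which is the diffusers default for UniPCMultistepScheduler.
def _use_momentum_unipc(pipe, beta=0.9):
    scheduler = MomentumUniPCMultistepScheduler.from_config(pipe.scheduler.config)
    scheduler.initialize_momentum(beta=beta)  # beta = 1.0 recovers vanilla UniPC
    pipe.scheduler = scheduler
    return pipe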