import argparse
import datetime
import os
import json
import torch
import torchvision.transforms as transforms
from torchvision.transforms import functional as F
import spaces
from huggingface_hub import snapshot_download
import gradio as gr
from diffusers import DDIMScheduler
from lvdm.models.unet import UNetModel
from lvdm.models.autoencoder import AutoencoderKL, AutoencoderKL_Dualref
from lvdm.models.condition import FrozenOpenCLIPEmbedder, FrozenOpenCLIPImageEmbedderV2, Resampler
from lvdm.models.layer_controlnet import LayerControlNet
from lvdm.pipelines.pipeline_animation import AnimationPipeline
from lvdm.utils import generate_gaussian_heatmap, save_videos_grid, save_videos_with_traj
from einops import rearrange
import cv2
import decord
from PIL import Image
import numpy as np
from scipy.interpolate import PchipInterpolator

SAVE_DIR = "outputs"
os.makedirs(SAVE_DIR, exist_ok=True)
LENGTH = 16
WIDTH = 512
HEIGHT = 320
LAYER_CAPACITY = 4
DEVICE = "cuda"
WEIGHT_DTYPE = torch.bfloat16
PIPELINE = None
GENERATOR = None

os.makedirs("checkpoints", exist_ok=True)
snapshot_download(
    "Yuppie1204/LayerAnimate-Mix",
    local_dir="checkpoints/LayerAnimate-Mix",
)

TEXT_ENCODER = FrozenOpenCLIPEmbedder().eval()
IMAGE_ENCODER = FrozenOpenCLIPImageEmbedderV2().eval()
default_path = "checkpoints/LayerAnimate-Mix"
scheduler = DDIMScheduler.from_pretrained(default_path, subfolder="scheduler")
image_projector = Resampler.from_pretrained(default_path, subfolder="image_projector").eval()
vae, vae_dualref = None, None
if "I2V" in default_path or "Mix" in default_path:
    vae = AutoencoderKL.from_pretrained(default_path, subfolder="vae").eval()
if "Interp" in default_path or "Mix" in default_path:
    vae_dualref = AutoencoderKL_Dualref.from_pretrained(default_path, subfolder="vae_dualref").eval()
unet = UNetModel.from_pretrained(default_path, subfolder="unet").eval()
layer_controlnet = LayerControlNet.from_pretrained(default_path, subfolder="layer_controlnet").eval()
PIPELINE = AnimationPipeline(
    vae=vae, vae_dualref=vae_dualref, text_encoder=TEXT_ENCODER, image_encoder=IMAGE_ENCODER, image_projector=image_projector,
    unet=unet, layer_controlnet=layer_controlnet, scheduler=scheduler
).to(device=DEVICE, dtype=WEIGHT_DTYPE)
if "Interp" in default_path or "Mix" in default_path:
    PIPELINE.vae_dualref.decoder.to(dtype=torch.float32)
TRANSFORMS = transforms.Compose([
    transforms.Resize(min(HEIGHT, WIDTH)),
    transforms.CenterCrop((HEIGHT, WIDTH)),
])

SAMPLE_GRID = np.meshgrid(np.linspace(0, WIDTH - 1, 10, dtype=int), np.linspace(0, HEIGHT - 1, 10, dtype=int))
SAMPLE_GRID = np.stack(SAMPLE_GRID, axis=-1).reshape(-1, 1, 2)
SAMPLE_GRID = np.repeat(SAMPLE_GRID, LENGTH, axis=1)  # [N, F, 2]
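# Note (added comment): SAMPLE_GRID is a uniform 10x10 grid of (x, y) pixel coordinates repeated across
# all LENGTH frames. `run` uses it to synthesize stationary pseudo-trajectories for layers marked as
# static, so the trajectory branch still receives a dense "no motion" signal for those layers.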
def set_seed(seed, device):
    np.random.seed(seed)
    torch.manual_seed(seed)
    return torch.Generator(device).manual_seed(seed)

def set_model(pretrained_model_path):
    global PIPELINE
    scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    image_projector = Resampler.from_pretrained(pretrained_model_path, subfolder="image_projector").eval()
    vae, vae_dualref = None, None
    if "I2V" in pretrained_model_path or "Mix" in pretrained_model_path:
        vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").eval()
    if "Interp" in pretrained_model_path or "Mix" in pretrained_model_path:
        vae_dualref = AutoencoderKL_Dualref.from_pretrained(pretrained_model_path, subfolder="vae_dualref").eval()
    unet = UNetModel.from_pretrained(pretrained_model_path, subfolder="unet").eval()
    layer_controlnet = LayerControlNet.from_pretrained(pretrained_model_path, subfolder="layer_controlnet").eval()
    PIPELINE.update(
        vae=vae, vae_dualref=vae_dualref, text_encoder=TEXT_ENCODER, image_encoder=IMAGE_ENCODER, image_projector=image_projector,
        unet=unet, layer_controlnet=layer_controlnet, scheduler=scheduler
    )
    PIPELINE.to(device=DEVICE, dtype=WEIGHT_DTYPE)
    if "Interp" in pretrained_model_path or "Mix" in pretrained_model_path:
        PIPELINE.vae_dualref.decoder.to(dtype=torch.float32)
    return pretrained_model_path

def upload_image(image):
    image = TRANSFORMS(image)
    return image
# This Space runs on ZeroGPU (`spaces` is imported above), so the generation entry point requests a GPU
# per call via the `spaces.GPU` decorator.
@spaces.GPU
def run(input_image, input_image_end, pretrained_model_path, seed,
        prompt, n_prompt, num_inference_steps, guidance_scale,
        *layer_args):
    generator = set_seed(seed, DEVICE)
    global layer_tracking_points
    args_layer_tracking_points = [layer_tracking_points[i].value for i in range(LAYER_CAPACITY)]
    args_layer_masks = layer_args[:LAYER_CAPACITY]
    args_layer_masks_end = layer_args[LAYER_CAPACITY : 2 * LAYER_CAPACITY]
    args_layer_controls = layer_args[2 * LAYER_CAPACITY : 3 * LAYER_CAPACITY]
    args_layer_scores = list(layer_args[3 * LAYER_CAPACITY : 4 * LAYER_CAPACITY])
    args_layer_sketches = layer_args[4 * LAYER_CAPACITY : 5 * LAYER_CAPACITY]
    args_layer_valids = layer_args[5 * LAYER_CAPACITY : 6 * LAYER_CAPACITY]
    args_layer_statics = layer_args[6 * LAYER_CAPACITY : 7 * LAYER_CAPACITY]
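    # Note (added comment): `layer_args` is a flat tuple of 7 groups of LAYER_CAPACITY values each, in the
    # same order as the `run_button.click` inputs below: masks, end-frame masks, control types, motion
    # scores, sketch videos, validity flags, and static flags.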
    for layer_idx in range(LAYER_CAPACITY):
        if args_layer_controls[layer_idx] != "score":
            args_layer_scores[layer_idx] = -1
        if args_layer_statics[layer_idx]:
            args_layer_scores[layer_idx] = 0

    mode = "i2v"
    image1 = F.to_tensor(input_image) * 2 - 1
    frame_tensor = image1[None].to(DEVICE)  # [F, C, H, W]
    if input_image_end is not None:
        mode = "interpolate"
        image2 = F.to_tensor(input_image_end) * 2 - 1
        frame_tensor2 = image2[None].to(DEVICE)
        frame_tensor = torch.cat([frame_tensor, frame_tensor2], dim=0)
    frame_tensor = frame_tensor[None]
    if mode == "interpolate":
        layer_masks = torch.zeros((1, LAYER_CAPACITY, 2, 1, HEIGHT, WIDTH), dtype=torch.bool)
    else:
        layer_masks = torch.zeros((1, LAYER_CAPACITY, 1, 1, HEIGHT, WIDTH), dtype=torch.bool)
    for layer_idx in range(LAYER_CAPACITY):
        if args_layer_masks[layer_idx] is not None:
            mask = F.to_tensor(args_layer_masks[layer_idx]) > 0.5
            layer_masks[0, layer_idx, 0] = mask
        if args_layer_masks_end[layer_idx] is not None and mode == "interpolate":
            mask = F.to_tensor(args_layer_masks_end[layer_idx]) > 0.5
            layer_masks[0, layer_idx, 1] = mask
    layer_masks = layer_masks.to(DEVICE)
    layer_regions = layer_masks * frame_tensor[:, None]

    layer_validity = torch.tensor([args_layer_valids], dtype=torch.bool, device=DEVICE)
    motion_scores = torch.tensor([args_layer_scores], dtype=WEIGHT_DTYPE, device=DEVICE)
    layer_static = torch.tensor([args_layer_statics], dtype=torch.bool, device=DEVICE)

    sketch = torch.ones((1, LAYER_CAPACITY, LENGTH, 3, HEIGHT, WIDTH), dtype=WEIGHT_DTYPE)
    for layer_idx in range(LAYER_CAPACITY):
        sketch_path = args_layer_sketches[layer_idx]
        if sketch_path is not None:
            video_reader = decord.VideoReader(sketch_path)
            assert len(video_reader) == LENGTH, "The length of the input sketch sequence should match the video length."
            video_frames = video_reader.get_batch(range(LENGTH)).asnumpy()
            sketch_values = [F.to_tensor(TRANSFORMS(Image.fromarray(frame))) for frame in video_frames]
            sketch_values = torch.stack(sketch_values) * 2 - 1
            sketch[0, layer_idx] = sketch_values
    sketch = sketch.to(DEVICE)

    heatmap = torch.zeros((1, LAYER_CAPACITY, LENGTH, 3, HEIGHT, WIDTH), dtype=WEIGHT_DTYPE)
    heatmap[:, :, :, 0] -= 1
    trajectory = []
    traj_layer_index = []
    for layer_idx in range(LAYER_CAPACITY):
        tracking_points = args_layer_tracking_points[layer_idx]
        if args_layer_statics[layer_idx]:
            # generate pseudo trajectory for static layers
            temp_layer_mask = layer_masks[0, layer_idx, 0, 0].cpu().numpy()
            valid_flag = temp_layer_mask[SAMPLE_GRID[:, 0, 1], SAMPLE_GRID[:, 0, 0]]
            valid_grid = SAMPLE_GRID[valid_flag]  # [N, F, 2]
            trajectory.extend(list(valid_grid))
            traj_layer_index.extend([layer_idx] * valid_grid.shape[0])
        else:
            for temp_track in tracking_points:
                if len(temp_track) > 1:
                    x = [point[0] for point in temp_track]
                    y = [point[1] for point in temp_track]
                    t = np.linspace(0, 1, len(temp_track))
                    fx = PchipInterpolator(t, x)
                    fy = PchipInterpolator(t, y)
                    t_new = np.linspace(0, 1, LENGTH)
                    x_new = fx(t_new)
                    y_new = fy(t_new)
                    temp_traj = np.stack([x_new, y_new], axis=-1).astype(np.float32)
                    trajectory.append(temp_traj)
                    traj_layer_index.append(layer_idx)
                elif len(temp_track) == 1:
                    trajectory.append(np.array(temp_track * LENGTH))
                    traj_layer_index.append(layer_idx)
    trajectory = np.stack(trajectory)
    trajectory = np.transpose(trajectory, (1, 0, 2))
    traj_layer_index = np.array(traj_layer_index)

    heatmap = generate_gaussian_heatmap(trajectory, WIDTH, HEIGHT, traj_layer_index, LAYER_CAPACITY, offset=True)
    heatmap = rearrange(heatmap, "f n c h w -> (f n) c h w")
    graymap, offset = heatmap[:, :1], heatmap[:, 1:]
    graymap = graymap / 255.
    rad = torch.sqrt(offset[:, 0:1]**2 + offset[:, 1:2]**2)
    rad_max = torch.max(rad)
    epsilon = 1e-5
    offset = offset / (rad_max + epsilon)
    graymap = graymap * 2 - 1
    heatmap = torch.cat([graymap, offset], dim=1)
    heatmap = rearrange(heatmap, '(f n) c h w -> n f c h w', n=LAYER_CAPACITY)
    heatmap = heatmap[None]
    heatmap = heatmap.to(DEVICE)
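    # Note (added comment): in the assembled trajectory conditioning above, channel 0 is the per-layer
    # heatmap rescaled from [0, 255] to [-1, 1], and channels 1-2 are the (dx, dy) offsets normalized by
    # the maximum offset magnitude, giving a [1, LAYER_CAPACITY, LENGTH, 3, H, W] tensor.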
    sample = PIPELINE(
        prompt,
        LENGTH,
        HEIGHT,
        WIDTH,
        frame_tensor,
        layer_masks=layer_masks,
        layer_regions=layer_regions,
        layer_static=layer_static,
        motion_scores=motion_scores,
        sketch=sketch,
        trajectory=heatmap,
        layer_validity=layer_validity,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        guidance_rescale=0.7,
        negative_prompt=n_prompt,
        num_videos_per_prompt=1,
        eta=1.0,
        generator=generator,
        fps=24,
        mode=mode,
        weight_dtype=WEIGHT_DTYPE,
        output_type="tensor",
    ).videos

    output_video_path = os.path.join(SAVE_DIR, "video.mp4")
    save_videos_grid(sample, output_video_path, fps=8)

    output_video_traj_path = os.path.join(SAVE_DIR, "video_with_traj.mp4")
    vis_traj_flag = np.zeros(trajectory.shape[1], dtype=bool)
    for traj_idx in range(trajectory.shape[1]):
        if not args_layer_statics[traj_layer_index[traj_idx]]:
            vis_traj_flag[traj_idx] = True
    vis_traj = torch.from_numpy(trajectory[:, vis_traj_flag])
    save_videos_with_traj(sample[0], vis_traj, output_video_traj_path, fps=8, line_width=7, circle_radius=10)
    return output_video_path, output_video_traj_path
def update_layer_region(image, layer_mask):
    if image is None or layer_mask is None:
        return None, False
    layer_mask_tensor = (F.to_tensor(layer_mask) > 0.5).float()
    image = F.to_tensor(image)
    layer_region = image * layer_mask_tensor
    layer_region = F.to_pil_image(layer_region)
    layer_region.putalpha(layer_mask)
    return layer_region, True
def control_layers(control_type):
    # Seven outputs, in order: the motion-score box, the four trajectory buttons, the trajectory file
    # upload, and the sketch video (matching the `outputs` list wired to `layer_controls[i].change`).
    if control_type == "score":
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    elif control_type == "trajectory":
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
    else:
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
def visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask):
    first_mask_tensor = (F.to_tensor(first_mask) > 0.5).float()
    first_frame = F.to_tensor(first_frame)
    first_region = first_frame * first_mask_tensor
    first_region = F.to_pil_image(first_region)
    first_region.putalpha(first_mask)
    transparent_background = first_region.convert('RGBA')

    if last_frame is not None and last_mask is not None:
        last_mask_tensor = (F.to_tensor(last_mask) > 0.5).float()
        last_frame = F.to_tensor(last_frame)
        last_region = last_frame * last_mask_tensor
        last_region = F.to_pil_image(last_region)
        last_region.putalpha(last_mask)
        transparent_background_end = last_region.convert('RGBA')

    width, height = transparent_background.size
    transparent_layer = np.zeros((height, width, 4))
    for track in tracking_points:
        if len(track) > 1:
            for i in range(len(track) - 1):
                start_point = np.array(track[i], dtype=np.int32)
                end_point = np.array(track[i + 1], dtype=np.int32)
                vx = end_point[0] - start_point[0]
                vy = end_point[1] - start_point[1]
                arrow_length = max(np.sqrt(vx**2 + vy**2), 1)
                if i == len(track) - 2:
                    cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
                else:
                    cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2)
        elif len(track) == 1:
            cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
    transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
    trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
    if last_frame is not None and last_mask is not None:
        trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer)
    else:
        trajectory_map_end = None
    return trajectory_map, trajectory_map_end
def add_drag(layer_idx):
    global layer_tracking_points
    tracking_points = layer_tracking_points[layer_idx].value
    tracking_points.append([])
    return

def delete_last_drag(layer_idx, first_frame, first_mask, last_frame, last_mask):
    global layer_tracking_points
    tracking_points = layer_tracking_points[layer_idx].value
    tracking_points.pop()
    trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
    return trajectory_map, trajectory_map_end

def delete_last_step(layer_idx, first_frame, first_mask, last_frame, last_mask):
    global layer_tracking_points
    tracking_points = layer_tracking_points[layer_idx].value
    tracking_points[-1].pop()
    trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
    return trajectory_map, trajectory_map_end

def add_tracking_points(layer_idx, first_frame, first_mask, last_frame, last_mask, evt: gr.SelectData):  # SelectData is a subclass of EventData
    print(f"You selected {evt.value} at {evt.index} from {evt.target}")
    global layer_tracking_points
    tracking_points = layer_tracking_points[layer_idx].value
    tracking_points[-1].append(evt.index)
    trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
    return trajectory_map, trajectory_map_end

def reset_states(layer_idx, first_frame, first_mask, last_frame, last_mask):
    global layer_tracking_points
    layer_tracking_points[layer_idx].value = [[]]
    tracking_points = layer_tracking_points[layer_idx].value
    trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
    return trajectory_map, trajectory_map_end

def upload_tracking_points(tracking_path, layer_idx, first_frame, first_mask, last_frame, last_mask):
    if tracking_path is None:
        layer_region, _ = update_layer_region(first_frame, first_mask)
        layer_region_end, _ = update_layer_region(last_frame, last_mask)
        return layer_region, layer_region_end
    global layer_tracking_points
    with open(tracking_path, "r") as f:
        tracking_points = json.load(f)
    layer_tracking_points[layer_idx].value = tracking_points
    trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
    return trajectory_map, trajectory_map_end
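# Note (added comment): the trajectory JSON loaded by `upload_tracking_points` above is expected to hold a
# list of trajectories, one per drag path, each a list of [x, y] pixel coordinates in the 512x320 frame
# (the same structure that `add_tracking_points` builds from clicks), e.g. [[[100, 160], [220, 170]]].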
def reset_all_controls():
    global layer_tracking_points
    outputs = []
    # Reset tracking points states
    for layer_idx in range(LAYER_CAPACITY):
        layer_tracking_points[layer_idx].value = [[]]
    # Reset global components
    outputs.extend([
        "an anime scene.",  # text prompt
        "",                 # negative text prompt
        50,                 # inference steps
        7.5,                # guidance scale
        42,                 # seed
        None,               # input image
        None,               # input image end
        None,               # output video
        None,               # output video with trajectory
    ])
    # Reset layer controls visibility
    outputs.extend([None] * LAYER_CAPACITY)      # layer masks
    outputs.extend([None] * LAYER_CAPACITY)      # layer masks end
    outputs.extend([None] * LAYER_CAPACITY)      # layer regions
    outputs.extend([None] * LAYER_CAPACITY)      # layer regions end
    outputs.extend(["sketch"] * LAYER_CAPACITY)  # layer controls
    outputs.extend([gr.update(visible=False, value=-1) for _ in range(LAYER_CAPACITY)])    # layer score controls
    outputs.extend([gr.update(visible=False) for _ in range(4 * LAYER_CAPACITY)])          # layer trajectory controls (4 buttons per layer)
    outputs.extend([gr.update(visible=False, value=None) for _ in range(LAYER_CAPACITY)])  # layer trajectory files
    outputs.extend([None] * LAYER_CAPACITY)      # layer sketch controls
    outputs.extend([False] * LAYER_CAPACITY)     # layer validity
    outputs.extend([False] * LAYER_CAPACITY)     # layer statics
    return outputs
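# Note (added comment): the order of the values returned by `reset_all_controls` must stay in sync with
# the component list passed as `outputs` to `clear_button.click(...)` in the UI wiring below.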
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("""<h1 align="center">LayerAnimate: Layer-level Control for Animation</h1><br>""")
        gr.Markdown("""Gradio demo for <a href='https://arxiv.org/abs/2501.08295'><b>LayerAnimate: Layer-level Control for Animation</b></a>.<br>
The GitHub repo can be found at https://github.com/IamCreateAI/LayerAnimate<br>
The template is inspired by Framer.""")
        gr.Image(label="LayerAnimate: Layer-level Control for Animation", value="__assets__/figs/demos.gif", height=540, width=960)
        gr.Markdown("""## Usage: <br>
1. Select a pretrained model via the "Pretrained Model" dropdown in the right column.<br>
2. Upload frames in the right column.<br>
&emsp;2.1. Upload the first frame.<br>
&emsp;2.2. (Optional) Upload the last frame.<br>
3. Input layer-level controls in the left column.<br>
&emsp;3.1. Upload layer mask images for each layer, which can be obtained from tools such as https://huggingface.co/spaces/yumyum2081/SAM2-Image-Predictor.<br>
&emsp;3.2. Choose a control type from "score" (motion score), "trajectory", and "sketch".<br>
&emsp;3.3. For trajectory control, you can draw trajectories on layer regions.<br>
&emsp;&emsp;3.3.1. Click "Add New Trajectory" to add a new trajectory.<br>
&emsp;&emsp;3.3.2. Click "Reset" to reset all trajectories.<br>
&emsp;&emsp;3.3.3. Click "Delete Last Step" to delete the most recently clicked control point.<br>
&emsp;&emsp;3.3.4. Click "Delete Last Trajectory" to delete the most recent trajectory.<br>
&emsp;&emsp;3.3.5. Alternatively, upload a trajectory file in JSON format; examples are provided below.<br>
&emsp;3.4. For sketch control, you can upload a sketch video.<br>
4. We provide four layers for you to control, and it is not necessary to use all of them.<br>
5. Click the "Run" button to generate videos.<br>
6. **Note: Remember to click the "Clear" button to clear all controls before switching to another example.**<br>
""")
        layer_indices = [gr.Number(value=i, visible=False) for i in range(LAYER_CAPACITY)]
        layer_tracking_points = [gr.State([[]]) for _ in range(LAYER_CAPACITY)]
        layer_masks = []
        layer_masks_end = []
        layer_regions = []
        layer_regions_end = []
        layer_controls = []
        layer_score_controls = []
        layer_traj_controls = []
        layer_traj_files = []
        layer_sketch_controls = []
        layer_statics = []
        layer_valids = []
        with gr.Row():
            with gr.Column(scale=1):
                for layer_idx in range(LAYER_CAPACITY):
                    with gr.Accordion(label=f"Layer {layer_idx + 1}", open=True if layer_idx == 0 else False):
                        gr.Markdown("""<div align="center"><b>Layer Masks</b></div>""")
                        gr.Markdown("**Note**: Layer mask for the last frame is not required in I2V mode.")
                        with gr.Row():
                            with gr.Column():
                                layer_masks.append(gr.Image(
                                    label="Layer Mask for First Frame",
                                    height=320,
                                    width=512,
                                    image_mode="L",
                                    type="pil",
                                ))
                            with gr.Column():
                                layer_masks_end.append(gr.Image(
                                    label="Layer Mask for Last Frame",
                                    height=320,
                                    width=512,
                                    image_mode="L",
                                    type="pil",
                                ))
                        gr.Markdown("""<div align="center"><b>Layer Regions</b></div>""")
                        with gr.Row():
                            with gr.Column():
                                layer_regions.append(gr.Image(
                                    label="Layer Region for First Frame",
                                    height=320,
                                    width=512,
                                    image_mode="RGBA",
                                    type="pil",
                                    # value=Image.new("RGBA", (512, 320), (255, 255, 255, 0)),
                                ))
                            with gr.Column():
                                layer_regions_end.append(gr.Image(
                                    label="Layer Region for Last Frame",
                                    height=320,
                                    width=512,
                                    image_mode="RGBA",
                                    type="pil",
                                    # value=Image.new("RGBA", (512, 320), (255, 255, 255, 0)),
                                ))
                        layer_controls.append(
                            gr.Radio(["score", "trajectory", "sketch"], label="Choose A Control Type", value="sketch")
                        )
                        layer_score_controls.append(
                            gr.Number(label="Motion Score", value=-1, visible=False)
                        )
                        layer_traj_controls.append(
                            [
                                gr.Button(value="Add New Trajectory", visible=False),
                                gr.Button(value="Reset", visible=False),
                                gr.Button(value="Delete Last Step", visible=False),
                                gr.Button(value="Delete Last Trajectory", visible=False),
                            ]
                        )
                        layer_traj_files.append(
                            gr.File(label="Trajectory File", visible=False)
                        )
                        layer_sketch_controls.append(
                            gr.Video(label="Sketch", height=320, width=512, visible=True)
                        )
                        layer_controls[layer_idx].change(
                            fn=control_layers,
                            inputs=layer_controls[layer_idx],
                            outputs=[layer_score_controls[layer_idx], *layer_traj_controls[layer_idx], layer_traj_files[layer_idx], layer_sketch_controls[layer_idx]]
                        )
                        with gr.Row():
                            layer_valids.append(gr.Checkbox(label="Valid", info="Is the layer valid?"))
                            layer_statics.append(gr.Checkbox(label="Static", info="Is the layer static?"))
            with gr.Column(scale=1):
                pretrained_model_path = gr.Dropdown(
                    label="Pretrained Model",
                    choices=[
                        "checkpoints/LayerAnimate-Mix",
                    ],
                    value="checkpoints/LayerAnimate-Mix",
                )
                text_prompt = gr.Textbox(label="Text Prompt", value="an anime scene.")
                text_n_prompt = gr.Textbox(label="Negative Text Prompt", value="")
                with gr.Row():
                    num_inference_steps = gr.Number(label="Inference Steps", value=50, minimum=1, maximum=1000)
                    guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
                    seed = gr.Number(label="Seed", value=42)
                with gr.Row():
                    input_image = gr.Image(
                        label="First Frame",
                        height=320,
                        width=512,
                        type="pil",
                    )
                    input_image_end = gr.Image(
                        label="Last Frame",
                        height=320,
                        width=512,
                        type="pil",
                    )
                run_button = gr.Button(value="Run")
                with gr.Row():
                    output_video = gr.Video(
                        label="Output Video",
                        height=320,
                        width=512,
                    )
                    output_video_traj = gr.Video(
                        label="Output Video with Trajectory",
                        height=320,
                        width=512,
                    )
                clear_button = gr.Button(value="Clear")
                with gr.Row():
                    gr.Markdown("""
## Citation
```bibtex
@article{yang2025layeranimate,
  author  = {Yang, Yuxue and Fan, Lue and Lin, Zuzeng and Wang, Feng and Zhang, Zhaoxiang},
  title   = {LayerAnimate: Layer-level Control for Animation},
  journal = {arXiv preprint arXiv:2501.08295},
  year    = {2025},
}
```
""")
        pretrained_model_path.input(set_model, pretrained_model_path, pretrained_model_path)
        input_image.upload(upload_image, input_image, input_image)
        input_image_end.upload(upload_image, input_image_end, input_image_end)

        for i in range(LAYER_CAPACITY):
            layer_masks[i].upload(upload_image, layer_masks[i], layer_masks[i])
            layer_masks[i].change(update_layer_region, [input_image, layer_masks[i]], [layer_regions[i], layer_valids[i]])
            layer_masks_end[i].upload(upload_image, layer_masks_end[i], layer_masks_end[i])
            layer_masks_end[i].change(update_layer_region, [input_image_end, layer_masks_end[i]], [layer_regions_end[i], layer_valids[i]])
            layer_traj_controls[i][0].click(add_drag, layer_indices[i], None)
            layer_traj_controls[i][1].click(
                reset_states,
                [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
                [layer_regions[i], layer_regions_end[i]]
            )
            layer_traj_controls[i][2].click(
                delete_last_step,
                [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
                [layer_regions[i], layer_regions_end[i]]
            )
            layer_traj_controls[i][3].click(
                delete_last_drag,
                [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
                [layer_regions[i], layer_regions_end[i]]
            )
            layer_traj_files[i].change(
                upload_tracking_points,
                [layer_traj_files[i], layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
                [layer_regions[i], layer_regions_end[i]]
            )
            layer_regions[i].select(
                add_tracking_points,
                [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
                [layer_regions[i], layer_regions_end[i]]
            )
            layer_regions_end[i].select(
                add_tracking_points,
                [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
                [layer_regions[i], layer_regions_end[i]]
            )
        run_button.click(
            run,
            [input_image, input_image_end, pretrained_model_path, seed, text_prompt, text_n_prompt, num_inference_steps, guidance_scale,
             *layer_masks, *layer_masks_end, *layer_controls, *layer_score_controls, *layer_sketch_controls, *layer_valids, *layer_statics],
            [output_video, output_video_traj]
        )
        clear_button.click(
            reset_all_controls,
            [],
            [
                text_prompt, text_n_prompt, num_inference_steps, guidance_scale, seed,
                input_image, input_image_end, output_video, output_video_traj,
                *layer_masks, *layer_masks_end, *layer_regions, *layer_regions_end,
                *layer_controls, *layer_score_controls, *[button for temp_layer_controls in layer_traj_controls for button in temp_layer_controls], *layer_traj_files,
                *layer_sketch_controls, *layer_valids, *layer_statics
            ]
        )
        examples = gr.Examples(
            examples=[
                [
                    "__assets__/demos/demo_3/first_frame.jpg", "__assets__/demos/demo_3/last_frame.jpg",
                    "score", "__assets__/demos/demo_3/layer_0.jpg", "__assets__/demos/demo_3/layer_0_last.jpg", 0.4, None, None, True, False,
                    "score", "__assets__/demos/demo_3/layer_1.jpg", "__assets__/demos/demo_3/layer_1_last.jpg", 0.2, None, None, True, False,
                    "trajectory", "__assets__/demos/demo_3/layer_2.jpg", "__assets__/demos/demo_3/layer_2_last.jpg", -1, "__assets__/demos/demo_3/trajectory.json", None, True, False,
                    "sketch", "__assets__/demos/demo_3/layer_3.jpg", "__assets__/demos/demo_3/layer_3_last.jpg", -1, None, "__assets__/demos/demo_3/sketch.mp4", True, False,
                    52
                ],
                [
                    "__assets__/demos/demo_4/first_frame.jpg", None,
                    "score", "__assets__/demos/demo_4/layer_0.jpg", None, 0.0, None, None, True, True,
                    "trajectory", "__assets__/demos/demo_4/layer_1.jpg", None, -1, "__assets__/demos/demo_4/trajectory.json", None, True, False,
                    "sketch", "__assets__/demos/demo_4/layer_2.jpg", None, -1, None, "__assets__/demos/demo_4/sketch.mp4", True, False,
                    "score", None, None, -1, None, None, False, False,
                    42
                ],
                [
                    "__assets__/demos/demo_5/first_frame.jpg", None,
                    "sketch", "__assets__/demos/demo_5/layer_0.jpg", None, -1, None, "__assets__/demos/demo_5/sketch.mp4", True, False,
                    "trajectory", "__assets__/demos/demo_5/layer_1.jpg", None, -1, "__assets__/demos/demo_5/trajectory.json", None, True, False,
                    "score", None, None, -1, None, None, False, False,
                    "score", None, None, -1, None, None, False, False,
                    47
                ],
            ],
            inputs=[
                input_image, input_image_end,
                layer_controls[0], layer_masks[0], layer_masks_end[0], layer_score_controls[0], layer_traj_files[0], layer_sketch_controls[0], layer_valids[0], layer_statics[0],
                layer_controls[1], layer_masks[1], layer_masks_end[1], layer_score_controls[1], layer_traj_files[1], layer_sketch_controls[1], layer_valids[1], layer_statics[1],
                layer_controls[2], layer_masks[2], layer_masks_end[2], layer_score_controls[2], layer_traj_files[2], layer_sketch_controls[2], layer_valids[2], layer_statics[2],
                layer_controls[3], layer_masks[3], layer_masks_end[3], layer_score_controls[3], layer_traj_files[3], layer_sketch_controls[3], layer_valids[3], layer_statics[3],
                seed
            ],
        )

    demo.launch()