zzl committed
Commit 2bbc3ee · Parent: 9fd9429

release

Files changed:
- app.py +6 -7
- demo_img.py +34 -10
- demo_vid.py +20 -2
- utils.py +18 -0
app.py
CHANGED
@@ -14,19 +14,18 @@ with gr.Blocks(css='style.css') as demo:
         <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
             <a href="https://paper99.github.io" style="color:blue;">Zhen Li</a><sup>1*</sup>,
             <a href="https://github.com/NK-CS-ZZL" style="color:blue;">Zuo-Liang Zhu</a><sup>1*</sup>,
-            <a href="https://github.com/hlh981029" style="color:blue;">Ling-Hao Han</a><sup>1
-            <a href="https://houqb.github.io" style="color:blue;">Qibin Hou</a><sup>1
-            <a href="https://github.com" style="color:blue;">Chun-Le Guo</a><sup>1
-            <a href="https://mmcheng.net" style="color:blue;">Ming-Ming Cheng</a><sup>1
+            <a href="https://github.com/hlh981029" style="color:blue;">Ling-Hao Han</a><sup>1</sup>,
+            <a href="https://houqb.github.io" style="color:blue;">Qibin Hou</a><sup>1</sup>,
+            <a href="https://github.com" style="color:blue;">Chun-Le Guo</a><sup>1</sup>,
+            <a href="https://mmcheng.net" style="color:blue;">Ming-Ming Cheng</a><sup>1</sup>,
         </h2>
         <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
-            <sup>1</sup>Nankai University <sup>*</sup> represents the equal contribution
+            <sup>1</sup>Nankai University <sup>*</sup> represents the equal contribution.
         </h2>
         <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
             [<a href="https://arxiv.org/abs/2303.13439" style="color:blue;">arXiv</a>]
             [<a href="https://github.com/MCG-NKU/AMT" style="color:blue;">GitHub</a>]
-            [<a href="https://
-            [<a href="https://github.com/MCG-NKU/AMT" style="color:blue;">Replicate</a>]
+            [<a href="https://colab.research.google.com/drive/1IeVO5BmLouhRh6fL2z_y18kgubotoaBq?usp=sharing" style="color:blue;">Colab</a>]
         </h2>
         <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
     """
demo_img.py
CHANGED
@@ -7,8 +7,11 @@ from huggingface_hub import hf_hub_download
 from networks.amts import Model as AMTS
 from networks.amtl import Model as AMTL
 from networks.amtg import Model as AMTG
-from utils import
-
+from utils import (
+    img2tensor, tensor2img,
+    InputPadder,
+    check_dim_and_resize
+)
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model_dict = {
     'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG
@@ -23,22 +26,43 @@ def img2vid(model_type, img0, img1, frame_ratio, iters):
     model.eval()
     img0_t = img2tensor(img0).to(device)
     img1_t = img2tensor(img1).to(device)
-    padder = InputPadder(img0_t.shape, 16)
-    img0_t, img1_t = padder.pad(img0_t, img1_t)
     inputs = [img0_t, img1_t]
+
+    if device == 'cpu':
+        # Do not resize in cpu mode
+        anchor_resolution = 8192*8192
+        anchor_memory = 1
+        anchor_memory_bias = 0
+        vram_avail = 1
+    elif device == 'cuda':
+        anchor_resolution = 1024 * 512
+        anchor_memory = 1500 * 1024**2
+        anchor_memory_bias = 2500 * 1024**2
+        vram_avail = torch.cuda.get_device_properties(device).total_memory
     embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
 
+    inputs = check_dim_and_resize(inputs)
+    h, w = inputs[0].shape[-2:]
+    scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
+    scale = 1 if scale > 1 else scale
+    scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
+    if scale < 1:
+        print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
+    padding = int(16 / scale)
+    padder = InputPadder(inputs[0].shape, padding)
+    inputs = padder.pad(*inputs)
+
     for i in range(iters):
         print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
-        outputs = [
+        outputs = [inputs[0]]
         for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
+            in_0 = in_0.to(device)
+            in_1 = in_1.to(device)
             with torch.no_grad():
-                imgt_pred = model(in_0, in_1, embt, eval=True)['imgt_pred']
-
-            in_1 = padder.unpad(in_1)
-            outputs += [imgt_pred, in_1]
+                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
+            outputs += [imgt_pred.cpu(), in_1.cpu()]
         inputs = outputs
-
+    outputs = padder.unpad(*outputs)
     out_path = 'results'
     size = outputs[0].shape[2:][::-1]
     writer = cv2.VideoWriter(f'{out_path}/demo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), frame_ratio, size)
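
Note (not part of the commit): a minimal sketch of the VRAM-based auto-scaling arithmetic that this hunk adds to img2vid. The 16 GB GPU and the 1920x1080 frame size below are assumptions chosen only to make the numbers concrete.

import numpy as np

# Anchor constants copied from the cuda branch of the commit.
anchor_resolution = 1024 * 512
anchor_memory = 1500 * 1024**2
anchor_memory_bias = 2500 * 1024**2
vram_avail = 16 * 1024**3      # assumption: a 16 GB GPU
h, w = 1080, 1920              # assumption: a 1080p input frame

# Larger inputs or less free VRAM push the scale below 1.
scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
scale = 1 if scale > 1 else scale
# Snap scale to the form 16/k so that padding = int(16 / scale) is an exact integer k.
scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
padding = int(16 / scale)
print(f"scale = {scale:.3f}, padding = {padding}")   # with these assumptions: scale = 0.889, padding = 18
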
demo_vid.py
CHANGED
@@ -27,7 +27,25 @@ def vid2vid(model_type, video, iters):
     inputs = []
     h = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
     w = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-
+    if device == 'cpu':
+        # Do not resize in cpu mode
+        anchor_resolution = 8192*8192
+        anchor_memory = 1
+        anchor_memory_bias = 0
+        vram_avail = 1
+    elif device == 'cuda':
+        anchor_resolution = 1024 * 512
+        anchor_memory = 1500 * 1024**2
+        anchor_memory_bias = 2500 * 1024**2
+        vram_avail = torch.cuda.get_device_properties(device).total_memory
+
+    scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
+    scale = 1 if scale > 1 else scale
+    scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
+    if scale < 1:
+        print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
+    padding = int(16 / scale)
+    padder = InputPadder(inputs[0].shape, padding)
     while True:
         ret, frame = vcap.read()
         if ret is False:
@@ -43,7 +61,7 @@ def vid2vid(model_type, video, iters):
     outputs = [inputs[0]]
     for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
         with torch.no_grad():
-            imgt_pred = model(in_0, in_1, embt, eval=True)['imgt_pred']
+            imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
         imgt_pred = padder.unpad(imgt_pred)
         in_1 = padder.unpad(in_1)
         outputs += [imgt_pred, in_1]
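
Side note on the loop that both demos share: one pass over zip(inputs[:-1], inputs[1:]) turns n frames into 2n - 1 by inserting one interpolated frame between every adjacent pair. A toy sketch follows, with placeholder strings instead of tensors; interpolate_once and mid() are illustrative names, not part of the commit.

# Toy sketch of the pairwise interpolation pass; mid() stands in for the model call.
def interpolate_once(frames):
    outputs = [frames[0]]
    for in_0, in_1 in zip(frames[:-1], frames[1:]):
        mid = f"mid({in_0},{in_1})"   # placeholder for model(in_0, in_1, embt, ...)
        outputs += [mid, in_1]
    return outputs

print(interpolate_once(["f0", "f1", "f2"]))
# ['f0', 'mid(f0,f1)', 'f1', 'mid(f1,f2)', 'f2']  -> n frames become 2*n - 1
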
utils.py
CHANGED
@@ -227,3 +227,21 @@ def warp(img, flow):
     grid_ = (grid + flow_).permute(0, 2, 3, 1)
     output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True)
     return output
+
+def check_dim_and_resize(tensor_list):
+    shape_list = []
+    for t in tensor_list:
+        shape_list.append(t.shape[2:])
+
+    if len(set(shape_list)) > 1:
+        desired_shape = shape_list[0]
+        print(f'Inconsistent size of input video frames. All frames will be resized to {desired_shape}')
+
+        resize_tensor_list = []
+        for t in tensor_list:
+            resize_tensor_list.append(torch.nn.functional.interpolate(t, size=tuple(desired_shape), mode='bilinear'))
+
+        tensor_list = resize_tensor_list
+
+    return tensor_list
+
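
A minimal usage sketch for the new helper (the dummy tensor shapes are chosen only for illustration):

import torch
from utils import check_dim_and_resize

# Two dummy frames with mismatched spatial sizes; the helper resizes every
# frame to the shape of the first one (360x640 here) and prints a notice.
frames = [torch.rand(1, 3, 360, 640), torch.rand(1, 3, 352, 640)]
frames = check_dim_and_resize(frames)
print([tuple(f.shape[-2:]) for f in frames])   # [(360, 640), (360, 640)]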