Carlex22222 commited on
Commit
67137e5
·
verified ·
1 Parent(s): b6d86f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -59
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py (VERSÃO FINAL COMPLETA PARA GRADIO)
2
 
3
  import gradio as gr
4
  import os
@@ -10,108 +10,111 @@ from pathlib import Path
10
  from huggingface_hub import snapshot_download
11
 
12
  # --- BLOCO DE CONFIGURAÇÃO E DOWNLOAD DE MODELO ---
13
-
14
- # Diretório de trabalho principal e diretório do código SeedVR
15
  APP_DIR = "/app"
16
  SEEDVR_DIR = os.path.join(APP_DIR, "SeedVR")
17
-
18
- # Usamos /tmp, um diretório com permissão de escrita garantida, para modelos e arquivos temporários.
19
  MODEL_CACHE_DIR = "/tmp/models"
20
  CKPTS_DIR = os.path.join(MODEL_CACHE_DIR, "ckpts")
21
 
22
- # Garante que o diretório de checkpoints exista antes de qualquer coisa
23
  os.makedirs(CKPTS_DIR, exist_ok=True)
24
 
25
- # Verifica se um arquivo de modelo chave já existe para evitar redownloads a cada reinício.
26
  if not Path(CKPTS_DIR).joinpath("seedvr2_ema_3b.pth").exists():
27
- print("Baixando os checkpoints do modelo para /tmp/models/ckpts... Isso pode levar alguns minutos.")
28
  snapshot_download(
29
  repo_id="ByteDance-Seed/SeedVR2-3B",
30
  local_dir=CKPTS_DIR,
31
- local_dir_use_symlinks=False, # Parâmetro obsoleto, mas mantido por segurança
32
- allow_patterns=["*.pth", "*.pt"] # Baixa apenas os arquivos de modelo necessários
33
  )
34
  print("Download do modelo concluído.")
35
  else:
36
- print("Checkpoints do modelo já existem em /tmp. Pulando o download.")
37
 
38
  # --------------------------------------------------------------------
39
 
40
  def run_inference(video_path, seed, res_h, res_w):
41
- """
42
- Função principal que será chamada pela interface Gradio.
43
- Ela executa o script torchrun em um subprocesso e transmite os logs em tempo real.
44
- """
45
  if video_path is None:
46
  raise gr.Error("Por favor, faça o upload de um arquivo de vídeo ou imagem.")
47
-
48
- # Cria diretórios temporários únicos para esta execução em /tmp
49
  job_id = str(uuid.uuid4())
50
  input_dir = os.path.join("/tmp", "temp_inputs", job_id)
51
  output_dir = os.path.join("/tmp", "temp_outputs", job_id)
52
  os.makedirs(input_dir, exist_ok=True)
53
  os.makedirs(output_dir, exist_ok=True)
54
-
55
- # O Gradio nos dá um caminho temporário. Copiamos o arquivo para nosso diretório de trabalho.
56
  shutil.copy(video_path, input_dir)
57
-
58
  log_output = ""
59
-
 
 
60
  try:
61
- # O script de inferência é executado a partir de SEEDVR_DIR, então os caminhos precisam ser relativos
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  input_folder_relative = os.path.relpath(input_dir, SEEDVR_DIR)
63
  output_folder_relative = os.path.relpath(output_dir, SEEDVR_DIR)
 
 
 
64
 
65
  command = [
66
  "torchrun", "--nproc-per-node=4",
67
- "projects/inference_seedvr2_3b.py",
68
  "--video_path", input_folder_relative,
69
  "--output_dir", output_folder_relative,
70
  "--seed", str(seed),
71
  "--res_h", str(res_h),
72
  "--res_w", str(res_w),
73
- # Argumento crucial que diz ao script onde encontrar os modelos baixados
74
- "--ckpt_dir", CKPTS_DIR
75
  ]
76
-
77
- # Força o Python a não usar buffer de saída, garantindo logs em tempo real
78
  env = os.environ.copy()
79
  env["PYTHONUNBUFFERED"] = "1"
80
-
81
- log_output += f"Executando comando: {' '.join(command)}\n\n"
82
- yield None, None, log_output # Limpa saídas antigas e mostra o comando na caixa de logs
83
-
84
  process = subprocess.Popen(
85
- command,
86
- cwd=SEEDVR_DIR,
87
- stdout=subprocess.PIPE,
88
- stderr=subprocess.STDOUT,
89
- text=True,
90
- encoding='utf-8',
91
- env=env
92
  )
93
-
94
- # Loop para capturar e transmitir a saída do subprocesso em tempo real para a UI
95
  while True:
96
  output = process.stdout.readline()
97
- if output == '' and process.poll() is not None:
98
- break
99
  if output:
100
  log_output += output
101
- # O yield atualiza a caixa de logs da interface Gradio
102
  yield None, None, log_output
103
-
104
- return_code = process.poll()
105
- if return_code != 0:
106
- raise gr.Error(f"A inferência falhou com o código de saída {return_code}. Verifique os logs para detalhes.")
107
 
108
  output_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.png', '.jpg', '.jpeg'))]
109
  if not output_files:
110
- raise gr.Error("A inferência foi concluída, mas nenhum arquivo de saída foi encontrado.")
111
-
112
  result_path = os.path.join(output_dir, output_files[0])
113
-
114
- # Determina se a saída é imagem ou vídeo e a retorna para o componente correto
115
  media_type, _ = mimetypes.guess_type(result_path)
116
  if media_type and media_type.startswith("image"):
117
  yield result_path, None, log_output
@@ -119,14 +122,14 @@ def run_inference(video_path, seed, res_h, res_w):
119
  yield None, result_path, log_output
120
 
121
  finally:
122
- # Limpa o diretório de entrada temporário após a conclusão ou falha
123
  shutil.rmtree(input_dir, ignore_errors=True)
124
- # O diretório de saída é mantido para que o Gradio possa exibir o resultado.
125
- # Gradio gerencia a limpeza desses arquivos de saída.
 
126
 
127
- # --- Interface Gráfica Gradio ---
128
  with gr.Blocks(css="footer {display: none !important}") as demo:
129
- gr.Markdown("# 🚀 Interface de Inferência para SeedVR2")
130
  gr.Markdown("Faça o upload de um vídeo ou imagem, ajuste os parâmetros e clique em 'Executar'.")
131
 
132
  with gr.Row():
@@ -137,18 +140,18 @@ with gr.Blocks(css="footer {display: none !important}") as demo:
137
  res_h = gr.Number(value=720, label="Altura da Saída (res_h)")
138
  res_w = gr.Number(value=1280, label="Largura da Saída (res_w)")
139
  run_button = gr.Button("Executar", variant="primary")
140
-
141
  with gr.Column(scale=2):
142
  output_image = gr.Image(label="Saída de Imagem")
143
  output_video = gr.Video(label="Saída de Vídeo")
144
  log_box = gr.Textbox(label="Logs em Tempo Real", lines=15, autoscroll=True, interactive=False)
145
-
146
  run_button.click(
147
  fn=run_inference,
148
  inputs=[input_media, seed, res_h, res_w],
149
  outputs=[output_image, output_video, log_box]
150
  )
151
-
152
 
153
 
154
  demo.queue(max_size=10).launch()
 
1
+ # app.py (VERSÃO FINAL COM MONKEY PATCHING)
2
 
3
  import gradio as gr
4
  import os
 
10
  from huggingface_hub import snapshot_download
11
 
12
  # --- BLOCO DE CONFIGURAÇÃO E DOWNLOAD DE MODELO ---
 
 
13
  APP_DIR = "/app"
14
  SEEDVR_DIR = os.path.join(APP_DIR, "SeedVR")
 
 
15
  MODEL_CACHE_DIR = "/tmp/models"
16
  CKPTS_DIR = os.path.join(MODEL_CACHE_DIR, "ckpts")
17
 
 
18
  os.makedirs(CKPTS_DIR, exist_ok=True)
19
 
 
20
  if not Path(CKPTS_DIR).joinpath("seedvr2_ema_3b.pth").exists():
21
+ print("Baixando os checkpoints do modelo para /tmp/models/ckpts...")
22
  snapshot_download(
23
  repo_id="ByteDance-Seed/SeedVR2-3B",
24
  local_dir=CKPTS_DIR,
25
+ local_dir_use_symlinks=False,
26
+ allow_patterns=["*.pth", "*.pt"]
27
  )
28
  print("Download do modelo concluído.")
29
  else:
30
+ print("Checkpoints do modelo já existem em /tmp.")
31
 
32
  # --------------------------------------------------------------------
33
 
34
  def run_inference(video_path, seed, res_h, res_w):
 
 
 
 
35
  if video_path is None:
36
  raise gr.Error("Por favor, faça o upload de um arquivo de vídeo ou imagem.")
37
+
 
38
  job_id = str(uuid.uuid4())
39
  input_dir = os.path.join("/tmp", "temp_inputs", job_id)
40
  output_dir = os.path.join("/tmp", "temp_outputs", job_id)
41
  os.makedirs(input_dir, exist_ok=True)
42
  os.makedirs(output_dir, exist_ok=True)
43
+
 
44
  shutil.copy(video_path, input_dir)
45
+
46
  log_output = ""
47
+
48
+ # --- LÓGICA DO MONKEY PATCHING ---
49
+ patched_script_path = os.path.join("/tmp", f"inference_patched_{job_id}.py")
50
  try:
51
+ original_script_path = os.path.join(SEEDVR_DIR, "projects", "inference_seedvr2_3b.py")
52
+ with open(original_script_path, 'r') as f:
53
+ script_content = f.read()
54
+
55
+ # Define os caminhos hardcoded a serem substituídos
56
+ default_dit_path = "'./ckpts/seedvr2_ema_3b.pth'"
57
+ default_vae_path = "'./ckpts/ema_vae.pth'" # Assumindo que o VAE também é carregado assim no original
58
+
59
+ # Define os novos caminhos que apontam para nosso diretório em /tmp
60
+ patched_dit_path = f"'{os.path.join(CKPTS_DIR, 'seedvr2_ema_3b.pth')}'"
61
+ patched_vae_path = f"'{os.path.join(CKPTS_DIR, 'ema_vae.pth')}'"
62
+
63
+ # Aplica o "patch" substituindo o texto
64
+ script_content = script_content.replace(default_dit_path, patched_dit_path)
65
+ # Tenta substituir o caminho do VAE também, se existir
66
+ script_content = script_content.replace(default_vae_path, patched_vae_path)
67
+
68
+ # Salva o script modificado em um arquivo temporário
69
+ with open(patched_script_path, 'w') as f:
70
+ f.write(script_content)
71
+ print(f"Script de inferência 'remendado' e salvo em: {patched_script_path}")
72
+
73
+ # ------------------------------------
74
+
75
  input_folder_relative = os.path.relpath(input_dir, SEEDVR_DIR)
76
  output_folder_relative = os.path.relpath(output_dir, SEEDVR_DIR)
77
+
78
+ # O torchrun deve executar o script remendado
79
+ patched_script_relative_path = os.path.relpath(patched_script_path, SEEDVR_DIR)
80
 
81
  command = [
82
  "torchrun", "--nproc-per-node=4",
83
+ patched_script_relative_path, # <-- USA O SCRIPT MODIFICADO
84
  "--video_path", input_folder_relative,
85
  "--output_dir", output_folder_relative,
86
  "--seed", str(seed),
87
  "--res_h", str(res_h),
88
  "--res_w", str(res_w),
 
 
89
  ]
90
+
 
91
  env = os.environ.copy()
92
  env["PYTHONUNBUFFERED"] = "1"
93
+
94
+ log_output += f"Executando comando com script remendado: {' '.join(command)}\n\n"
95
+ yield None, None, log_output
96
+
97
  process = subprocess.Popen(
98
+ command, cwd=SEEDVR_DIR, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
99
+ text=True, encoding='utf-8', env=env
 
 
 
 
 
100
  )
101
+
 
102
  while True:
103
  output = process.stdout.readline()
104
+ if output == '' and process.poll() is not None: break
 
105
  if output:
106
  log_output += output
 
107
  yield None, None, log_output
108
+
109
+ if process.poll() != 0:
110
+ raise gr.Error(f"A inferência falhou. Verifique os logs.")
 
111
 
112
  output_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.png', '.jpg', '.jpeg'))]
113
  if not output_files:
114
+ raise gr.Error("Nenhum arquivo de saída foi encontrado.")
115
+
116
  result_path = os.path.join(output_dir, output_files[0])
117
+
 
118
  media_type, _ = mimetypes.guess_type(result_path)
119
  if media_type and media_type.startswith("image"):
120
  yield result_path, None, log_output
 
122
  yield None, result_path, log_output
123
 
124
  finally:
 
125
  shutil.rmtree(input_dir, ignore_errors=True)
126
+ # Limpa o script temporário que criamos
127
+ if os.path.exists(patched_script_path):
128
+ os.remove(patched_script_path)
129
 
130
+ # --- Interface Gráfica Gradio (sem alterações) ---
131
  with gr.Blocks(css="footer {display: none !important}") as demo:
132
+ gr.Markdown("# 🚀 Interface de Inferência para SeedVR2 (com Monkey Patching)")
133
  gr.Markdown("Faça o upload de um vídeo ou imagem, ajuste os parâmetros e clique em 'Executar'.")
134
 
135
  with gr.Row():
 
140
  res_h = gr.Number(value=720, label="Altura da Saída (res_h)")
141
  res_w = gr.Number(value=1280, label="Largura da Saída (res_w)")
142
  run_button = gr.Button("Executar", variant="primary")
143
+
144
  with gr.Column(scale=2):
145
  output_image = gr.Image(label="Saída de Imagem")
146
  output_video = gr.Video(label="Saída de Vídeo")
147
  log_box = gr.Textbox(label="Logs em Tempo Real", lines=15, autoscroll=True, interactive=False)
148
+
149
  run_button.click(
150
  fn=run_inference,
151
  inputs=[input_media, seed, res_h, res_w],
152
  outputs=[output_image, output_video, log_box]
153
  )
154
+
155
 
156
 
157
  demo.queue(max_size=10).launch()