Update app.py
app.py
CHANGED
@@ -368,75 +368,103 @@ def main(device, segment_type):
 
 @spaces.GPU
 def generate_image(prompt1, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, style):
-    try:
+    # try:
+    path1 = lorapath_man[man]
+    path2 = lorapath_woman[woman]
+    pipe_concept.unload_lora_weights()
+    pipe.unload_lora_weights()
+    pipe_list = build_model_lora(pipe_concept, path1 + "|" + path2, lorapath_styles[style], condition, args, pipe)
+
+    if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
+        styleL = True
+    else:
+        styleL = False
+
+    input_list = [prompt1]
+    condition_list = [condition_img1]
+    output_list = []
+
+    width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
+
+    kwargs = {
+        'height': height,
+        'width': width,
+    }
+
+    for prompt, condition_img in zip(input_list, condition_list):
+        if prompt!='':
+            input_prompt = []
+            p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
+            if styleL:
+                p = styles[style] + p
+            input_prompt.append([p.replace("{prompt}", prompt), p.replace("{prompt}", prompt)])
+            if styleL:
+                input_prompt.append([(styles[style] + local_prompt1, character_man.get(man)[1]),
+                                     (styles[style] + local_prompt2, character_woman.get(woman)[1])])
+            else:
+                input_prompt.append([(local_prompt1, character_man.get(man)[1]),
+                                     (local_prompt2, character_woman.get(woman)[1])])
+
+            if condition == 'Human pose' and condition_img is not None:
+                index = ratio_list.index(
+                    min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
+                resolution = resolution_list[index]
+                width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
+                kwargs['height'] = height
+                kwargs['width'] = width
+                condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
+                spatial_condition = get_humanpose(condition_img)
+            elif condition == 'Canny Edge' and condition_img is not None:
+                index = ratio_list.index(
+                    min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
+                resolution = resolution_list[index]
+                width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
+                kwargs['height'] = height
+                kwargs['width'] = width
+                condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
+                spatial_condition = get_cannyedge(condition_img)
+            elif condition == 'Depth' and condition_img is not None:
+                index = ratio_list.index(
+                    min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
+                resolution = resolution_list[index]
+                width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
+                kwargs['height'] = height
+                kwargs['width'] = width
+                condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
+                spatial_condition = get_depth(condition_img)
+            else:
+                spatial_condition = None
+
+            kwargs['spatial_condition'] = spatial_condition
+            controller.reset()
+            image = sample_image(
+                pipe,
+                input_prompt=input_prompt,
+                concept_models=pipe_concept,
+                input_neg_prompt=[negative_prompt] * len(input_prompt),
+                generator=torch.Generator(device).manual_seed(seed),
+                controller=controller,
+                stage=1,
+                lora_list=pipe_list,
+                styleL=styleL,
+                **kwargs)
+
+            controller.reset()
+            if pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
+                mask1 = predict_mask(detect_model, sam, image[0], 'man', args.segment_type, confidence=0.15,
+                                     threshold=0.5)
+            else:
+                mask1 = None
+
+            if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
+                mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.15,
+                                     threshold=0.5)
+            else:
+                mask2 = None
+
+            if mask1 is None and mask2 is None:
+                output_list.append(image[1])
+            else:
                 image = sample_image(
                     pipe,
                     input_prompt=input_prompt,
@@ -444,47 +472,19 @@ def main(device, segment_type):
                     input_neg_prompt=[negative_prompt] * len(input_prompt),
                     generator=torch.Generator(device).manual_seed(seed),
                     controller=controller,
-                    stage=
+                    stage=2,
+                    region_masks=[mask1, mask2],
                     lora_list=pipe_list,
                     styleL=styleL,
                     **kwargs)
-                if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
-                    mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.15,
-                                         threshold=0.5)
-                else:
-                    mask2 = None
-
-                if mask1 is None and mask2 is None:
-                    output_list.append(image[1])
-                else:
-                    image = sample_image(
-                        pipe,
-                        input_prompt=input_prompt,
-                        concept_models=pipe_concept,
-                        input_neg_prompt=[negative_prompt] * len(input_prompt),
-                        generator=torch.Generator(device).manual_seed(seed),
-                        controller=controller,
-                        stage=2,
-                        region_masks=[mask1, mask2],
-                        lora_list=pipe_list,
-                        styleL=styleL,
-                        **kwargs)
-                    output_list.append(image[1])
-            else:
-                output_list.append(None)
-        output_list.append(spatial_condition)
-        return output_list
-    except:
-        print("error")
-        return
+                output_list.append(image[1])
+        else:
+            output_list.append(None)
+    output_list.append(spatial_condition)
+    return output_list
+    # except:
+    # print("error")
+    # return
 
 def get_local_value_man(input):
     return character_man[input][0]