import gradio as gr
from transformers import (
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    AutoTokenizer,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    pipeline,
)
import torch
from PIL import Image
from datasets import load_dataset
import soundfile as sf
import random
import re
import string
import spaces
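
# PictoVerse pipeline: caption an uploaded image (ViT-GPT2), expand the caption
# into a short story (GPT-2), and read the story aloud (SpeechT5 + HiFi-GAN).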
# --- IMAGE CAPTIONING ---
def load_caption_model():
    model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    return model

feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# gr.NO_RELOAD: skip the heavyweight model load when Gradio hot-reloads the script
if gr.NO_RELOAD:
    llm_model = load_caption_model()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
llm_model.to(device)

# beam-search decoding: captions capped at 16 tokens, 4 beams
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def predict_step(image_paths):
    images = []
    for image_path in image_paths:
        i_image = Image.open(image_path)
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")
        images.append(i_image)
    # batch-encode the images and generate captions with beam search
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    output_ids = llm_model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds
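
# Illustrative usage (the caption shown is hypothetical, not a fixed output):
#   predict_step(["temp_image.jpg"])  # -> ["a dog sitting on a couch"]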
# --- TEXT TO SPEECH ---
# load the processor (tokenizes the input text for SpeechT5)
def load_processor():
    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    return processor

# load the acoustic model (text -> spectrogram)
def load_speech_model():
    speech_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
    return speech_model

# load the vocoder, which turns the spectrogram into an audible waveform
def load_vocoder():
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
    return vocoder

# we load this dataset to get the speaker embeddings (x-vectors)
def load_embeddings_dataset():
    embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    return embeddings_dataset
# speaker ids from the embeddings dataset
speakers = {
    'awb': 0,     # Scottish male
    'bdl': 1138,  # US male
    'clb': 2271,  # US female
    'jmk': 3403,  # Canadian male
    'ksp': 4535,  # Indian male
    'rms': 5667,  # US male
    'slt': 6799,  # US female
}
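
# Each value above is a row index into the cmu-arctic-xvectors validation split:
# e.g. embeddings_dataset[6799]["xvector"] is the 512-dim x-vector for speaker 'slt'.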
def save_text_to_speech(text, speaker=None):
    # preprocess text
    inputs = processor(text=text, return_tensors="pt").to(device)
    if speaker is not None:
        # load the x-vector containing the speaker's voice characteristics from the dataset
        speaker_embeddings = torch.tensor(embeddings_dataset[speaker]["xvector"]).unsqueeze(0).to(device)
    else:
        # random vector, meaning a random voice
        speaker_embeddings = torch.randn((1, 512)).to(device)
    # generate speech with the acoustic model and vocoder
    speech = speech_model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
    if speaker is not None:
        # if we have a speaker, use the speaker's ID in the filename
        output_filename = f"{speaker}-{'-'.join(text.split()[:6])}.wav"
    else:
        # if we don't have a speaker, use a random string in the filename
        random_str = ''.join(random.sample(string.ascii_letters + string.digits, k=5))
        output_filename = f"{random_str}-{'-'.join(text.split()[:6])}.wav"
    # save the generated speech to a file with a 16 kHz sampling rate;
    # WAV is used because soundfile's MP3 support depends on the installed libsndfile
    sf.write(output_filename, speech.cpu().numpy(), samplerate=16000)
    # return the filename for reference
    return output_filename
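
# Illustrative usage (speakers["slt"] == 6799, so the filename comes out as):
#   save_text_to_speech("Once upon a time", speaker=speakers["slt"])
#   # -> "6799-Once-upon-a-time.wav"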
def load_text_generator():
    gen = pipeline('text-generation', model='gpt2')  # GPT-2 continues the caption into a story
    return gen
if gr.NO_RELOAD:
    processor = load_processor()
    speech_model = load_speech_model()
    vocoder = load_vocoder()
    embeddings_dataset = load_embeddings_dataset()
    gen = load_text_generator()
def gradio_predict(image):
    if image is None:
        return ""
    image_path = "temp_image.jpg"
    image.save(image_path)  # save the uploaded image temporarily
    prediction = predict_step([image_path])
    return prediction[0].capitalize() if prediction else "Prediction failed."
def remove_last_incomplete_sentence(text):
    # find all sentences ending with ., !, or ?
    sentences = re.findall(r'[^.!?]*[.!?]', text, re.DOTALL)
    # if there's no complete sentence, return the original text
    if not sentences:
        return text
    # join the complete sentences, dropping any trailing fragment
    cleaned_text = ''.join(sentences).strip()
    return cleaned_text
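
# Example: remove_last_incomplete_sentence("It rained. The dog ran") -> "It rained."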
# on ZeroGPU Spaces, @spaces.GPU() allocates a GPU for the duration of the call
@spaces.GPU()
def get_story(pred):
    # continue the caption into a short story (max_length counts prompt + new tokens)
    gen_text = gen(pred, max_length=100)[0]
    cleaned_text = remove_last_incomplete_sentence(gen_text['generated_text'])
    output_filename_2 = save_text_to_speech(cleaned_text, speaker=speakers["slt"])
    return cleaned_text, output_filename_2
# --- FRONT END ---
DESCRIPTION = """# PictoVerse
### Dive into the multiverse of storytelling with PictoVerse, where every image unveils an array of parallel dimensions.
PictoVerse crafts captivating narratives from your photos, each set in a distinct universe of its own.
"""
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type='pil', label="Image")
            clear_button = gr.Button("Clear")
        with gr.Column(scale=4):
            output_text = gr.Textbox(label="Prediction")
            gen_text = gr.Textbox(label="Generated Story")
            output_filename_2 = gr.Audio(label='Audio')
            button1 = gr.Button("Generate Story and Audio")
    # events: new image -> caption; button -> story + audio; clear -> reset everything
    button1.click(fn=get_story, inputs=output_text, outputs=[gen_text, output_filename_2])
    input_image.change(fn=gradio_predict, inputs=input_image, outputs=output_text)
    clear_button.click(lambda: (None, "", "", None), inputs=[], outputs=[input_image, output_text, gen_text, output_filename_2])

demo.launch()