import random
import re
import string

import gradio as gr
import soundfile as sf
import spaces
import torch
from datasets import load_dataset
from PIL import Image
from transformers import (
    AutoTokenizer,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    SpeechT5Processor,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    pipeline,
)

# --- IMAGE CAPTIONING ---
def load_caption_model():
    model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    return model

feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the captioning model once; gr.NO_RELOAD keeps heavy model loading
# out of Gradio's hot-reload cycle.
if gr.NO_RELOAD:
    caption_model = load_caption_model()
    caption_model.to(device)

max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
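# Beam search (num_beams=4) makes caption decoding deterministic, and
# max_length=16 keeps the captions short.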

def predict_step(image_paths):
    images = []
    for image_path in image_paths:
        i_image = Image.open(image_path)
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")
        images.append(i_image)

    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = caption_model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds
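
# Hypothetical usage (the caption is illustrative):
#   predict_step(["temp_image.jpg"])  # -> e.g. ["a dog sitting on a couch"]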

# --- TEXT TO SPEECH ---
# Load the processor, which tokenizes input text for SpeechT5.
def load_processor():
    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    return processor

# Load the text-to-speech model.
def load_speech_model():
    speech_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
    return speech_model

# Load the vocoder, which turns the model's spectrogram output into an audible waveform.
def load_vocoder():
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
    return vocoder

# Load the dataset that provides the speaker embeddings (x-vectors).
def load_embeddings_dataset():
    embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    return embeddings_dataset

# Speaker IDs (row indices) in the embeddings dataset.
speakers = {
    "awb": 0,     # Scottish male
    "bdl": 1138,  # US male
    "clb": 2271,  # US female
    "jmk": 3403,  # Canadian male
    "ksp": 4535,  # Indian male
    "rms": 5667,  # US male
    "slt": 6799,  # US female
}
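
# Each ID indexes a row of the x-vector dataset, e.g.:
#   embeddings_dataset[speakers["slt"]]["xvector"]  # a 512-dim speaker embedding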

def save_text_to_speech(text, speaker=None):
    # Preprocess the text.
    inputs = processor(text=text, return_tensors="pt").to(device)
    if speaker is not None:
        # Load the x-vector containing the speaker's voice characteristics.
        speaker_embeddings = torch.tensor(embeddings_dataset[speaker]["xvector"]).unsqueeze(0).to(device)
    else:
        # A random vector, meaning a random voice.
        speaker_embeddings = torch.randn((1, 512)).to(device)

    # Generate speech with the model and vocoder.
    speech = speech_model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)

    if speaker is not None:
        # With a known speaker, use the speaker's ID in the filename.
        output_filename = f"{speaker}-{'-'.join(text.split()[:6])}.wav"
    else:
        # Otherwise, use a random string in the filename.
        random_str = "".join(random.sample(string.ascii_letters + string.digits, k=5))
        output_filename = f"{random_str}-{'-'.join(text.split()[:6])}.wav"

    # Save the generated speech at a 16 kHz sampling rate. Note: soundfile's MP3
    # support depends on the installed libsndfile, so WAV is the safer target here.
    sf.write(output_filename, speech.cpu().numpy(), samplerate=16000)
    # Return the filename for reference.
    return output_filename
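
# Hypothetical usage (the filename embeds the first six words of the text):
#   save_text_to_speech("Once upon a time", speaker=speakers["slt"])
#   # -> "6799-Once-upon-a-time.wav"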

# Load the GPT-2 text-generation pipeline used to expand captions into stories.
def load_text_generator():
    gen = pipeline("text-generation", model="gpt2")
    return gen

# Load the TTS stack and the text generator once (skipped on Gradio hot reload).
if gr.NO_RELOAD:
    processor = load_processor()
    speech_model = load_speech_model()
    vocoder = load_vocoder()
    embeddings_dataset = load_embeddings_dataset()
    gen = load_text_generator()

# Note: `spaces` is imported above, which suggests a ZeroGPU Space; there, the
# handlers that touch the GPU need @spaces.GPU (a no-op on regular hardware).
@spaces.GPU
def gradio_predict(image):
    if image is None:
        return ""
    image_path = "temp_image.jpg"
    image.save(image_path)  # Save the uploaded image temporarily.
    prediction = predict_step([image_path])
    return prediction[0].capitalize() if prediction else "Prediction failed."

def remove_last_incomplete_sentence(text):
    # Find all sentences ending with ., !, or ?
    sentences = re.findall(r"[^.!?]*[.!?]", text, re.DOTALL)
    # If no complete sentence is found, return the original text.
    if not sentences:
        return text
    # Join the complete sentences.
    cleaned_text = "".join(sentences).strip()
    return cleaned_text
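
# Illustrative behavior: the trailing fragment after the last ., !, or ? is dropped.
#   remove_last_incomplete_sentence("It rained. The sun came out")  # -> "It rained."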

@spaces.GPU  # GPU-backed handler (see the note above gradio_predict)
def get_story(pred):
    gen_text = gen(pred, max_length=100)[0]
    cleaned_text = remove_last_incomplete_sentence(gen_text["generated_text"])
    output_filename_2 = save_text_to_speech(cleaned_text, speaker=speakers["slt"])
    return cleaned_text, output_filename_2

# --- FRONT END ---
DESCRIPTION = """# PictoVerse
### Dive into the multiverse of storytelling with PictoVerse, where every image unveils an array of parallel dimensions.
PictoVerse crafts captivating narratives from your photos, each set in a distinct universe of its own.
"""

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Image")
            clear_button = gr.Button("Clear")
        with gr.Column(scale=4):
            output_text = gr.Textbox(label="Prediction")
            gen_text = gr.Textbox(label="Generated Story")
            output_audio = gr.Audio(label="Audio")
            button1 = gr.Button("Generate Story and Audio")

    button1.click(fn=get_story, inputs=output_text, outputs=[gen_text, output_audio])
    input_image.change(fn=gradio_predict, inputs=input_image, outputs=output_text)
    clear_button.click(
        lambda: (None, "", "", None),
        inputs=[],
        outputs=[input_image, output_text, gen_text, output_audio],
    )

demo.launch()
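
# Likely runtime dependencies for this Space (an assumption; versions not pinned):
#   gradio, spaces, transformers, torch, datasets, soundfile, pillow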