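"""PictoVerse: a Gradio app that captions an uploaded image with ViT-GPT2,
expands the caption into a short story with GPT-2, and narrates the result
with SpeechT5 text-to-speech."""
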
import gradio as gr
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from transformers import pipeline
import torch
from PIL import Image
from datasets import load_dataset
import soundfile as sf
import random
import re
import string
import spaces

# --- IMAGE CAPTIONING ---
def load_caption_model():
    # ViT encoder + GPT-2 decoder fine-tuned for image captioning
    model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    return model

feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

if gr.NO_RELOAD:  # skip re-loading the model when Gradio hot-reloads during development
    llm_model = load_caption_model()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
llm_model.to(device)

# beam-search decoding settings for caption generation
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def predict_step(image_paths):
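    """Generate captions for a list of image file paths."""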
    images = []
    for image_path in image_paths:
        i_image = Image.open(image_path)
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")

        images.append(i_image)

    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = llm_model.generate(pixel_values, **gen_kwargs)

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds

# --- TEXT TO SPEECH ---

# load the processor
def load_processor():
    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    return processor

# load the model
def load_speech_model():
    speech_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
    return speech_model

# load the vocoder, which turns the generated spectrogram into an audible waveform
def load_vocoder():
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
    return vocoder

# we load this dataset to get the speaker embeddings
def load_embeddings_dataset():
    embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    return embeddings_dataset

# speaker ids from the embeddings dataset
speakers = {
    'awb': 0,     # Scottish male
    'bdl': 1138,  # US male
    'clb': 2271,  # US female
    'jmk': 3403,  # Canadian male
    'ksp': 4535,  # Indian male
    'rms': 5667,  # US male
    'slt': 6799   # US female
}
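# index into the speaker-embeddings dataset; get_story below uses speakers["slt"]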


def save_text_to_speech(text, speaker=None):
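    """Synthesize speech for `text` with SpeechT5, save it as a 16 kHz WAV file,
    and return the output filename."""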
    # preprocess text
    inputs = processor(text=text, return_tensors="pt").to(device)
    if speaker is not None:
        # load xvector containing speaker's voice characteristics from a dataset
        speaker_embeddings = torch.tensor(embeddings_dataset[speaker]["xvector"]).unsqueeze(0).to(device)
    else:
        # random vector, meaning a random voice
        speaker_embeddings = torch.randn((1, 512)).to(device)
    # generate speech with the models
    speech = speech_model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
    if speaker is not None:
        # if we have a speaker, we use the speaker's ID in the filename
        output_filename = f"{speaker}-{'-'.join(text.split()[:6])}.wav"
    else:
        # if we don't have a speaker, we use a random string in the filename
        random_str = ''.join(random.sample(string.ascii_letters + string.digits, k=5))
        output_filename = f"{random_str}-{'-'.join(text.split()[:6])}.wav"
    # save the generated speech as a 16 kHz WAV file
    # (WAV rather than MP3: soundfile's MP3 support depends on the installed libsndfile version)
    sf.write(output_filename, speech.cpu().numpy(), samplerate=16000)
    # return the filename for reference
    return output_filename

def load_text_generator():
    gen = pipeline('text-generation', model='gpt2') # uses GPT-2
    return gen

if gr.NO_RELOAD:
    processor = load_processor()
    speech_model = load_speech_model()
    vocoder = load_vocoder()
    embeddings_dataset = load_embeddings_dataset()
    gen = load_text_generator()

def gradio_predict(image):
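    """Caption the uploaded PIL image and return the caption text."""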
    if image is None:
        return ""
    image_path = "temp_image.jpg"
    image.save(image_path)  # save the uploaded image to a temporary file
    prediction = predict_step([image_path])
    return prediction[0].capitalize() if prediction else "Prediction failed."

def remove_last_incomplete_sentence(text):
    # Find all sentences ending with ., !, or ?
    sentences = re.findall(r'[^.!?]*[.!?]', text, re.DOTALL)
    
    # If there's no complete sentence found, return the original text
    if not sentences:
        return text
    
    # Join the complete sentences
    cleaned_text = ''.join(sentences).strip()
    
    return cleaned_text

@spaces.GPU()
def get_story(pred):
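    """Expand the caption into a short story with GPT-2, narrate it, and
    return the story text along with the audio filename."""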
    gen_text = gen(pred, max_length=100)[0]
    cleaned_text = remove_last_incomplete_sentence(gen_text['generated_text'])
    output_filename_2 = save_text_to_speech(cleaned_text, speaker=speakers["slt"])
    return cleaned_text, output_filename_2

# --- FRONT END ---

DESCRIPTION = """# PictoVerse
### Dive into the multiverse of storytelling with PictoVerse, where every image unveils an array of parallel dimensions.
PictoVerse crafts captivating narratives from your photos, each set in a distinct universe of its own.
"""

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type='pil', label="Image")
            clear_button = gr.Button("Clear")
        with gr.Column(scale=4):
            output_text = gr.Textbox(label="Prediction")
            gen_text = gr.Textbox(label="Generated Story")
            output_filename_2 = gr.Audio(label='Audio')
            button1 = gr.Button("Generate Story and Audio")
            button1.click(fn=get_story, inputs=output_text, outputs=[gen_text, output_filename_2])
            

    input_image.change(fn=gradio_predict, inputs=input_image, outputs=output_text)
    clear_button.click(lambda: (None, "", "", None), inputs=[], outputs=[input_image, output_text, gen_text, output_filename_2])


demo.launch()