File size: 1,929 Bytes
bcdf9fa
7687458
 
679a158
2b3dbdf
bcdf9fa
7687458
 
 
bcdf9fa
679a158
bcdf9fa
 
 
 
2b3dbdf
a23e411
f5b2793
fb0bb52
bcdf9fa
7687458
f5b2793
 
bf33a33
 
 
 
 
bcdf9fa
 
 
 
 
679a158
bcdf9fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b3dbdf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import spaces
from prompts.base_instruction import basic_instruction

# Load the model and tokenizer
# NOTE(review): both are loaded once at import time from the HF Hub repo
# "braindeck/text2text" (weights live under the "checkpoints/model" subfolder).
# trust_remote_code=True executes code shipped with the repo — acceptable here
# only because the repo is the app author's own; device_map="auto" lets
# accelerate place the bf16 weights on the available GPU(s)/CPU.
tokenizer = AutoTokenizer.from_pretrained("braindeck/text2text", trust_remote_code=True, subfolder="checkpoints/model")
model = AutoModelForCausalLM.from_pretrained("braindeck/text2text", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto", subfolder="checkpoints/model")

@spaces.GPU
def generate_response(prompt):
    """Run the fine-tuned model on *prompt* and return the decoded reply.

    The prompt is wrapped via the project's chat template, decoded greedily
    (do_sample=False, up to 512 new tokens), and any leading reasoning block
    terminated by "</think>" is stripped from the returned text.
    """
    messages = basic_instruction(prompt, "braindeck/text2text")
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    prompt_len = input_ids.shape[1]

    generated = model.generate(input_ids, max_new_tokens=512, do_sample=False)

    # generate() echoes the prompt tokens; keep only the newly produced ones.
    reply = tokenizer.decode(generated[0, prompt_len:], skip_special_tokens=True)

    # Drop the model's hidden reasoning, if the closing marker is present.
    marker = "</think>"
    pos = reply.find(marker)
    if pos != -1:
        reply = reply[pos + len(marker):].lstrip()

    return reply

# Create the Gradio interface.
# NOTE(review): component creation order inside the `with` blocks defines the
# rendered layout (prompt row, then button row, then output row) — do not
# reorder without intending a visual change.
with gr.Blocks() as demo:
    gr.Markdown("# Fine-tuned Text-to-Text Generation")
    gr.Markdown("Enter a prompt and the model will generate a response.")

    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", lines=4, placeholder="Enter your prompt here...")

    with gr.Row():
        generate_button = gr.Button("Generate")

    with gr.Row():
        # interactive=False: output box is display-only for the user.
        response_output = gr.Textbox(label="Response", lines=8, interactive=False)

    # Wire the button to the model call: prompt text in, generated text out.
    generate_button.click(
        fn=generate_response,
        inputs=prompt_input,
        outputs=response_output
    )

# Launch the app only when run as a script (not when imported, e.g. by Spaces).
if __name__ == "__main__":
    demo.launch()