# NOTE: removed non-code extraction artifacts (file-size banner, git blame hashes, line-number gutter).
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import spaces
from prompts.base_instruction import basic_instruction
# Load the model and tokenizer
# trust_remote_code=True: the Hub repo ships custom model/tokenizer code that runs at load time.
# subfolder="checkpoints/model": weights/tokenizer files live under this path inside the repo.
tokenizer = AutoTokenizer.from_pretrained("braindeck/text2text", trust_remote_code=True, subfolder="checkpoints/model")
# bfloat16 halves memory vs fp32; device_map="auto" lets accelerate place the weights (GPU when available).
model = AutoModelForCausalLM.from_pretrained("braindeck/text2text", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto", subfolder="checkpoints/model")
@spaces.GPU
def generate_response(prompt):
    """Generate a model completion for *prompt*, stripping any reasoning block.

    The prompt is wrapped into a chat message list by ``basic_instruction``,
    rendered through the model's chat template, and decoded greedily
    (``do_sample=False``) for up to 512 new tokens. If the completion contains
    a ``</think>`` marker, only the text after it is returned.
    """
    messages = basic_instruction(prompt, "braindeck/text2text")
    input_ids = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    prompt_len = input_ids.shape[1]
    output_ids = model.generate(input_ids, max_new_tokens=512, do_sample=False)
    # Keep only the newly generated tokens — everything after the prompt.
    completion = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
    # Drop the think block, if present: keep only the text after "</think>".
    _, marker, tail = completion.partition("</think>")
    return tail.lstrip() if marker else completion
# Create the Gradio interface.
# Fix: the original source ended with a stray " |" after demo.launch(), which is a
# syntax error; it has been removed.
with gr.Blocks() as demo:
    gr.Markdown("# Fine-tuned Text-to-Text Generation")
    gr.Markdown("Enter a prompt and the model will generate a response.")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", lines=4, placeholder="Enter your prompt here...")
    with gr.Row():
        generate_button = gr.Button("Generate")
    with gr.Row():
        response_output = gr.Textbox(label="Response", lines=8, interactive=False)
    # Wire the button to the generation function: one textbox in, one textbox out.
    generate_button.click(
        fn=generate_response,
        inputs=prompt_input,
        outputs=response_output,
    )

if __name__ == "__main__":
    demo.launch()