Spaces:
Runtime error
Runtime error
| import argparse | |
| import os | |
| import re | |
| import threading | |
| import time | |
| from datetime import datetime, timedelta | |
| import torch | |
| from threading import Thread, Event | |
| from PIL import Image, ImageDraw | |
| import gradio as gr | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForCausalLM, | |
| TextIteratorStreamer, | |
| ) | |
| from typing import List | |
| import spaces | |
# Process-wide cancellation flag: set by the "Stop Now" handler and checked by the
# streaming loop of predict(), which also clears it when a new prediction starts.
stop_event = threading.Event()
def delete_old_files(
    directories=("./outputs", "./gradio_tmp"),
    max_age_minutes=10,
    interval_seconds=600,
    max_sweeps=None,
):
    """Periodically delete stale files from the demo's scratch directories.

    Each sweep removes every regular file directly inside ``directories`` whose
    modification time is older than ``max_age_minutes``, then sleeps for
    ``interval_seconds``. With the defaults this loops forever and is meant to run
    on a daemon thread.

    Failures are confined to the item that caused them, so one bad entry never
    ends the cleanup loop:
    - a directory that does not exist (yet) or cannot be listed is skipped;
    - a file that vanishes or cannot be removed mid-sweep is retried next sweep.

    Args:
        directories: Directories whose top-level regular files are swept.
        max_age_minutes: Files last modified longer ago than this are deleted.
        interval_seconds: Pause between two sweeps.
        max_sweeps: Return after this many sweeps (at least one sweep always
            runs, with no sleep after the last one); ``None`` loops forever.
    """
    sweeps_done = 0
    while True:
        cutoff = datetime.now() - timedelta(minutes=max_age_minutes)
        for directory in directories:
            try:
                filenames = os.listdir(directory)
            except OSError:
                # e.g. the output directory is only created once the first
                # annotated image is saved.
                continue
            for filename in filenames:
                file_path = os.path.join(directory, filename)
                try:
                    if not os.path.isfile(file_path):
                        continue
                    file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
                    if file_mtime < cutoff:
                        os.remove(file_path)
                except OSError:
                    # The file disappeared or is not removable right now.
                    continue
        sweeps_done += 1
        if max_sweeps is not None and sweeps_done >= max_sweeps:
            return
        time.sleep(interval_seconds)
# Kick off the periodic scratch-file cleanup in the background as soon as the
# module is imported; daemon=True so this thread never keeps the process alive.
_cleanup_thread = threading.Thread(target=delete_old_files, daemon=True)
_cleanup_thread.start()
def draw_boxes_on_image(
    image: Image.Image,
    boxes: List[List[float]],
    save_path: str,
    *,
    outline: str = "red",
    width: int = 3,
):
    """Draw rectangles for normalized boxes onto ``image`` and save the result.

    ``image`` is drawn on in place before being saved to ``save_path``.

    Args:
        image: Image to annotate.
        boxes: Boxes as ``[x_a, y_a, x_b, y_b]`` fractions of the image width and
            height (predict() obtains them by dividing model coordinates by 1000).
            The two corners may be given in either order.
        save_path: File the annotated image is written to.
        outline: Rectangle outline color.
        width: Rectangle line width in pixels.
    """
    draw = ImageDraw.Draw(image)
    for box in boxes:
        # Model output does not guarantee the first corner is the top-left one, and
        # Pillow rejects rectangles whose second coordinate precedes the first.
        x_min, x_max = sorted((int(box[0] * image.width), int(box[2] * image.width)))
        y_min, y_max = sorted((int(box[1] * image.height), int(box[3] * image.height)))
        draw.rectangle([x_min, y_min, x_max, y_max], outline=outline, width=width)
    image.save(save_path)
def preprocess_messages(history, img_path, platform_str, format_str):
    """Build the text prompt and load the screenshot for the current round.

    Args:
        history: Chatbot history as ``[user_task, model_reply]`` pairs; the task of
            the last pair is the one being asked now.
        img_path: Filesystem path of the uploaded screenshot.
        platform_str: Platform line appended to the prompt.
        format_str: Answer-format instruction appended to the prompt.

    Returns:
        ``(query, image)``:
        - ``query`` holds the current task, then every "Grounded Operation" line
          found in earlier replies as numbered history steps, then
          ``platform_str`` and ``format_str``;
        - ``image`` is the screenshot converted to RGB.

    Raises:
        gr.Error: If no screenshot was uploaded.
    """
    grounded_pattern = re.compile(r"Grounded Operation:\s*(.*)")
    history = history or []

    history_steps = []
    for _task, model_msg in history:
        # The round being generated has an empty reply; None is tolerated defensively.
        match = grounded_pattern.search(model_msg or "")
        if match:
            history_steps.append(match.group(1))
    history_str = "\nHistory steps: " + "".join(
        f"\n{i}. {step}" for i, step in enumerate(history_steps)
    )

    task = history[-1][0] if history else "No task provided"
    query = f"Task: {task}{history_str}\n{platform_str}{format_str}"

    if not img_path:
        raise gr.Error("Please upload a screenshot first.")
    # The context manager closes the file; convert() returns an independent copy.
    with Image.open(img_path) as screenshot:
        image = screenshot.convert("RGB")
    return query, image
def predict(history, max_length, img_path, platform_str, format_str, output_dir):
    """Stream the model's answer for the latest round and annotate returned boxes.

    Generator used as a Gradio event handler: every yield is a
    ``(history, annotated_image_path_or_None)`` pair refreshing the chatbot and
    the annotated-image component.

    Args:
        history: Chatbot history as ``[user_task, model_reply]`` pairs. The last
            pair is the round just appended by ``user`` and receives the
            streamed reply text.
        max_length: Forwarded to ``model.generate(max_length=...)``.
        img_path: Filesystem path of the uploaded screenshot.
        platform_str: Platform line appended to the prompt.
        format_str: Answer-format instruction appended to the prompt.
        output_dir: Directory where annotated screenshots are written.

    If ``stop_event`` is set while streaming, generation is halted, the current
    round is removed from ``history``, and the generator ends.

    Exceptions raised by ``model.generate`` in the worker thread are re-raised
    here once the stream ends.

    NOTE(review): ``stop_event`` is module-global, so "Stop Now" in one browser
    session also stops concurrent generations of other sessions.
    """
    from transformers import StoppingCriteria, StoppingCriteriaList

    class _StopOnEvent(StoppingCriteria):
        """Makes ``model.generate`` finish as soon as ``stop_event`` is set."""

        def __call__(self, input_ids, scores, **kwargs):
            return stop_event.is_set()

    # Reset the stop flag at the start of every prediction.
    stop_event.clear()
    if not history:
        yield history, None
        return
    # `user` has already appended this round's [task, ""] pair, so the history
    # length *before* this round is one less; a stop rolls back to that length.
    prev_len = len(history) - 1

    query, image = preprocess_messages(history, img_path, platform_str, format_str)
    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "image": image, "content": query}],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "position_ids": inputs["position_ids"],
        "images": inputs["images"],
        "streamer": streamer,
        "max_length": max_length,
        # Sampling restricted to the top-1 token: deterministic, greedy in effect.
        "do_sample": True,
        "top_k": 1,
        # Lets "Stop Now" end generation itself instead of merely abandoning the
        # stream while the model keeps running until max_length.
        "stopping_criteria": StoppingCriteriaList([_StopOnEvent()]),
    }

    generation_errors = []

    def _generate():
        # Run generation; on failure record the error and close the stream so the
        # consumer loop below ends now instead of waiting for the streamer timeout.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            generation_errors.append(exc)
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    for new_token in streamer:
        if stop_event.is_set():
            break
        if new_token:
            history[-1][1] += new_token
            yield history, None

    # Checked after the loop too: the stop may land after the last streamed chunk.
    if stop_event.is_set():
        # Roll back the current round (task and partial reply).
        del history[prev_len:]
        yield history, None
        worker.join()
        return

    worker.join()
    if generation_errors:
        raise generation_errors[0]

    # Finished without a stop request: look for box=[[x1,y1,x2,y2]] coordinates.
    response = history[-1][1]
    box_pattern = r"box=\[\[?\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]?\]"
    matches = re.findall(box_pattern, response)
    if not matches:
        yield history, None
        return

    # Divide by 1000 to get fractions of the image size (presumably the model
    # emits coordinates on a 0-1000 scale).
    boxes = [[int(value) / 1000 for value in match] for match in matches]
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(img_path))[0]
    round_num = sum(1 for user_msg, bot_msg in history if user_msg and bot_msg)
    output_path = os.path.join(output_dir, f"{base_name}_{round_num}.png")
    with Image.open(img_path) as screenshot:
        annotated = screenshot.convert("RGB")
    draw_boxes_on_image(annotated, boxes, output_path)
    yield history, output_path
def user(task, history):
    """Start a new round for ``task`` and clear the input textbox.

    Returns ``("", new_history)``, where ``new_history`` is a new list ending with
    ``[task, ""]``; the passed-in history is not mutated. A ``None`` history is
    treated as empty, since clear_all_history() sets the chatbot value to None.
    """
    return "", (history or []) + [[task, ""]]
def undo_last_round(history, output_img):
    """Remove the most recent round from the chat and clear the annotated image.

    ``history`` is shortened in place when it is non-empty; an empty or ``None``
    history is returned as-is. ``output_img`` is ignored.
    """
    if not history:
        return history, None
    history.pop()
    return history, None
def clear_all_history():
    """Reset the UI: empty the chatbot and remove the annotated image."""
    cleared_chat, cleared_image = None, None
    return cleared_chat, cleared_image
def stop_now():
    """Request that the in-flight generation stop.

    Sets the module-level ``stop_event`` that predict() checks while streaming.
    Both Gradio outputs (chatbot, annotated image) are left unchanged.
    """
    stop_event.set()
    unchanged_chat = gr.update()
    unchanged_image = gr.update()
    return unchanged_chat, unchanged_image
def main():
    """Parse CLI options, load CogAgent, and serve the Gradio demo.

    Sets the module-level ``tokenizer`` and ``model`` used by predict().

    Raises:
        ValueError: If ``--format_key`` is not a supported prompt format.
    """
    global tokenizer, model

    parser = argparse.ArgumentParser(description="CogAgent Gradio Demo")
    cli_options = (
        ("--model_dir", "THUDM/cogagent-9b-20241220", "Path or identifier of the model."),
        ("--format_key", "action_op_sensitive", "Key to select the prompt format."),
        ("--platform", "Mac", "Platform information string."),
        ("--output_dir", "outputs", "Directory to save annotated images."),
    )
    for flag, default_value, help_text in cli_options:
        parser.add_argument(flag, default=default_value, help=help_text)
    args = parser.parse_args()

    # --format_key -> answer-format instruction appended to every prompt.
    format_prompts = {
        "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)",
        "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)",
        "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)",
        "status_action_op": "(Answer in Status-Action-Operation format.)",
        "action_op": "(Answer in Action-Operation format.)",
    }
    # Validate before the (slow) model load.
    if args.format_key not in format_prompts:
        raise ValueError(f"Invalid format_key. Available keys: {list(format_prompts.keys())}")
    format_str = format_prompts[args.format_key]
    platform_str = f"(Platform: {args.platform})\n"

    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    ).eval()

    intro_html = """
        <p align='center' style='color:red;'>This demo is for learning and communication purposes only. Users must assume responsibility for the risks associated with AI-generated planning and operations.</p>
        <p align='center' style='color:red;'>In this demo, the model assumes that the user is using a Mac operating system. Therefore, it is recommended to upload screenshots taken on a Mac.</p>
        <p align='left' style='color:black;'>1. Upload a screenshot from your computer (must be from a Mac, and a full-screen screenshot).<br>
        2. Provide your instructions to CogAgent (e.g., send a message to XXX).<br>
        3. Wait for CogAgent to return specific operations. If bounding boxes (Bbox) are detected, they will be displayed in the image area on the right.</p>
        <p align='left' style='color:black;'>The model will only return the next step's instructions. The online demo cannot control your computer. Please visit the <a href="https://github.com/THUDM/CogAgent">GitHub repository</a> for the full version of the demo.</p>
        """

    with gr.Blocks(analytics_enabled=False) as demo:
        gr.HTML("<h1 align='center'>CogAgent-9B-20241220 Demo</h1>")
        gr.HTML(intro_html)

        with gr.Row():
            screenshot_input = gr.Image(label="Upload a Screenshot", type="filepath", height=400)
            annotated_output = gr.Image(
                type="filepath",
                label="Annotated Image(If Bbox Return)",
                height=400,
                interactive=False,
            )

        with gr.Row():
            with gr.Column(scale=2):
                chat_view = gr.Chatbot(height=300)
                task_box = gr.Textbox(show_label=True, placeholder="Input...", label="Task")
                submit_btn = gr.Button("Submit")
            with gr.Column(scale=1):
                max_length_slider = gr.Slider(
                    0, 8192, value=1024, step=1.0, label="Maximum length", interactive=True
                )
                undo_btn = gr.Button("Back to Last Round")
                clear_btn = gr.Button("Clear All History")
                # Red "Stop Now" button: interrupts generation and rolls back the
                # current round's history.
                stop_btn = gr.Button("Stop Now", variant="stop")

        # Submitting first records the task (unqueued), then streams the answer.
        submit_btn.click(
            user, [task_box, chat_view], [task_box, chat_view], queue=False
        ).then(
            predict,
            [
                chat_view,
                max_length_slider,
                screenshot_input,
                gr.State(platform_str),
                gr.State(format_str),
                gr.State(args.output_dir),
            ],
            [chat_view, annotated_output],
            queue=True,
        )
        undo_btn.click(
            undo_last_round, [chat_view, annotated_output], [chat_view, annotated_output], queue=False
        )
        clear_btn.click(clear_all_history, None, [chat_view, annotated_output], queue=False)
        stop_btn.click(stop_now, None, [chat_view, annotated_output], queue=False)

    demo.queue()
    demo.launch()
if __name__ == "__main__":
    # Launch the demo only when executed as a script, not when imported.
    main()