# Hugging Face Space app (Space status when captured: Sleeping).
| import os | |
| import shutil | |
| import cv2 | |
| import zipfile | |
| import uuid | |
| import numpy as np | |
| import gradio as gr | |
| from naming import im2c | |
| from collections import Counter | |
# Colour names indexed by the colour-name index used in get_top_names:
# position i is the label for index i. The order is presumably fixed by the
# w2c colour-naming model — confirm before reordering.
COLOR_NAME = "black brown blue gray green orange pink purple red white yellow".split()
def get_top_names(img, top_k=3, anchor=256, max_side=512):
    """Return the most frequent colour names in an image.

    If either side of ``img`` exceeds ``max_side``, the image is first resized
    (aspect ratio kept) so that its shorter side equals ``anchor`` — presumably
    to keep per-pixel colour naming cheap. The third output of ``im2c`` is
    then treated as a per-pixel colour-name index map, and the indices are
    ranked by how many pixels carry them.

    Args:
        img: image array handed to ``im2c``; the caller in this file passes a
            float32 RGB array of shape (H, W, 3).
        top_k: maximum number of colours to return (default 3).
        anchor: target length of the shorter side when resizing (default 256).
        max_side: resize only when width or height exceeds this (default 512).

    Returns:
        ``(values, shares, names)``: parallel lists of at most ``top_k``
        entries in decreasing frequency order — the colour index (int), the
        fraction of all pixels of the analysed image with that index, and the
        matching entry of ``COLOR_NAME``. Fewer than ``top_k`` entries are
        returned when the image has fewer distinct colour names.
    """
    height, width = img.shape[:2]
    if width > max_side or height > max_side:
        # Scale the shorter side to `anchor`. Integer floor division gives the
        # same values as floor(ratio * anchor) but as plain Python ints.
        if width >= height:
            dim = (width * anchor // height, anchor)
        else:
            dim = (anchor, height * anchor // width)
        img = cv2.resize(img, dim, interpolation=cv2.INTER_LINEAR)
    # Denominator for the shares: pixel count of the image actually analysed.
    # Taken from `img` itself so it is defined whether or not a resize happened.
    num_pixels = img.shape[0] * img.shape[1]

    w2c = np.load('w2c11_j.npy').astype(np.float16)
    _, _, name_idx_img, _ = im2c(img, w2c)

    # Only indices 0..10 have a label in COLOR_NAME; others are not counted
    # (but still count toward num_pixels, as before).
    counts = Counter(name_idx_img[name_idx_img <= 10].tolist())
    ranked = counts.most_common(top_k)

    top_values = [int(value) for value, _ in ranked]
    top_shares = [count / num_pixels for _, count in ranked]
    top_names = [COLOR_NAME[value] for value in top_values]
    return top_values, top_shares, top_names
def _zip_directory(src_dir, zip_path):
    """Write every file under ``src_dir`` into a new deflated zip at ``zip_path``.

    Member names are relative to ``src_dir``; directories containing no files
    are not stored in the archive.
    """
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, _, files in os.walk(src_dir):
            for file in files:
                file_path = os.path.join(root, file)
                zipf.write(file_path, os.path.relpath(file_path, start=src_dir))


def classify_and_log(images):
    """Sort uploaded images into per-colour folders and return a zip of them.

    For each image, ``get_top_names`` ranks its colour names. The original
    file is copied into the folder of every ranked colour whose pixel share
    exceeds 15%. One line per image listing its ranked colours and shares is
    written to ``top3colors.txt``. The folders and the log are zipped to
    ``temp_<uuid>.zip`` in the current working directory. The unzipped
    working directory is always removed afterwards, even on error.

    Args:
        images: value of the multi-file ``gr.File`` input — a list of path
            strings or of file wrappers exposing the path as ``.name``
            (the item type depends on the Gradio version).

    Returns:
        Path of the zip archive.

    Raises:
        gr.Error: if no images were uploaded.
    """
    if not images:
        raise gr.Error("Please upload at least one image.")

    min_share = 0.15  # a colour must cover >15% of pixels to get a copy
    output_dir = f"temp_{uuid.uuid4()}"
    os.makedirs(output_dir, exist_ok=True)
    try:
        category_folders = {
            index: os.path.join(output_dir, name)
            for index, name in enumerate(COLOR_NAME)
        }
        for folder in category_folders.values():
            os.makedirs(folder, exist_ok=True)

        log_file = os.path.join(output_dir, "top3colors.txt")
        with open(log_file, "w", encoding="utf-8") as log:
            for img in images:
                # Normalise to one filesystem path so that basename, imread
                # and the copy below all use the same file.
                if isinstance(img, (str, os.PathLike)):
                    src_path = os.fspath(img)
                else:
                    src_path = img.name
                filename = os.path.basename(src_path)

                bgr = cv2.imread(src_path)
                if bgr is None:  # imread returns None for unreadable files
                    log.write(f"{filename} -> could not be read as an image\n")
                    continue
                rgb = cv2.cvtColor(bgr.astype(np.float32), cv2.COLOR_BGR2RGB)

                values, shares, names = get_top_names(rgb)
                for value, share in zip(values, shares):
                    if share > min_share:
                        # Copy the original bytes instead of re-encoding the
                        # decoded float array (avoids a lossy JPEG round trip).
                        shutil.copyfile(
                            src_path, os.path.join(category_folders[value], filename)
                        )

                # zip() tolerates images with fewer than three colour names.
                ranked = ", ".join(
                    f"{rank} {name} {100 * share:.2f}%"
                    for rank, (name, share) in enumerate(zip(names, shares), start=1)
                )
                log.write(f"{filename} -> {ranked}\n")

        zip_path = f"{output_dir}.zip"
        _zip_directory(output_dir, zip_path)
    finally:
        shutil.rmtree(output_dir, ignore_errors=True)
    return zip_path
def swap_to_gallery(images):
    """Show the uploaded images in the gallery and hide the file picker.

    Returns three updates, in the order (gallery, clear-button column,
    file input).
    """
    gallery = gr.update(value=images, visible=True)
    clear_column = gr.update(visible=True)
    file_picker = gr.update(visible=False)
    return gallery, clear_column, file_picker
def upload_example_to_gallery(images, prompt, style, negative_prompt):
    """Show example images in the gallery; same updates as ``swap_to_gallery``.

    ``prompt``, ``style`` and ``negative_prompt`` are accepted for signature
    compatibility but ignored. The function delegates to ``swap_to_gallery``
    rather than duplicating its body, so the two cannot drift apart.
    NOTE(review): not referenced by any event handler visible in this file.
    """
    return swap_to_gallery(images)
def remove_back_to_files():
    """Hide the gallery and the clear-button column; show the file picker again.

    Returns three updates, in the order (gallery, clear-button column,
    file input) — the reverse toggle of ``swap_to_gallery``.
    """
    hidden = {"visible": False}
    return gr.update(**hidden), gr.update(**hidden), gr.update(visible=True)
# Build the Gradio UI at import time; only launch the server when run as a
# script, so importing this module (e.g. for tests) does not start a server.
with gr.Blocks() as demo:
    gr.Markdown("## Image color categorization")
    gr.Markdown("We categorize images into different classes based on the frequency of different colors appearing in each image.")
    gr.Markdown("The 11 color categories include: black, brown, blue, gray, green, orange, pink, purple, red, white and yellow.")
    gr.Markdown("The classification is based on the color naming model from paper _Van De Weijer, Joost, et al. Learning color names for real-world applications. IEEE Transactions on Image Processing 18.7 (2009): 1512-1523._")
    gr.Markdown("The output results are in a zip file with all the images in the corresponding folders.")
    gr.Markdown("Note that one image can be classified into multiple categories (top 3 categories and frequency > 15%), and the top 3 categories are listed in the log file.")
    with gr.Row():
        with gr.Column():
            image_input = gr.File(
                label="Drag/select one or more images",
                file_types=["image"],
                file_count="multiple",
            )
            # Hidden until files are uploaded; swap_to_gallery toggles these.
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=image_input, size="sm")
            classify_btn = gr.Button("submit")
        log_output = gr.File(label="download results")

    # Event wiring: the output order must match the tuples returned by the
    # toggle functions (gallery, clear-button column, file input).
    image_input.upload(fn=swap_to_gallery, inputs=image_input, outputs=[uploaded_files, clear_button, image_input])
    remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, image_input])
    classify_btn.click(
        classify_and_log,
        inputs=[image_input],
        outputs=[log_output],
    )

if __name__ == "__main__":
    demo.launch()