Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import io | |
| import time | |
| from streamlit_drawable_canvas import st_canvas | |
| # Helper functions | |
def np_to_pil(np_img_bgr):
    """Wrap an OpenCV-style array in a PIL image.

    A 2-D array is handed to PIL unchanged. Any other array has its last
    axis reversed first, which turns a 3-channel BGR image into RGB.
    """
    is_single_channel = np_img_bgr.ndim == 2
    pixels = np_img_bgr if is_single_channel else np_img_bgr[..., ::-1]
    return Image.fromarray(pixels)
def pil_to_np(pil_img):
    """Convert a PIL image (or any array-like) into an OpenCV-style array.

    Args:
        pil_img: Object accepted by ``np.array``; typically a PIL image.

    Returns:
        A NumPy array. Data with fewer than three dimensions (grayscale) is
        returned as-is. Two-channel data is treated as luminance + alpha and
        reduced to its first channel. Four-channel data has its last channel
        (alpha) dropped. Colour data then has its channel axis reversed
        (RGB -> BGR) and is returned as a contiguous array.
    """
    pixels = np.array(pil_img)

    # Without a channel axis, reversing the last axis would mirror the image
    # horizontally (and the "== 4" test would look at the width instead of
    # the channel count), so single-channel data is passed through untouched.
    if pixels.ndim < 3:
        return pixels

    channels = pixels.shape[-1]
    if channels == 2:
        # Presumably luminance + alpha (PIL mode "LA"): keep the gray plane.
        return np.ascontiguousarray(pixels[..., 0])
    if channels == 4:
        pixels = pixels[..., :3]

    # Materialise the reversed view so callers receive a normally laid out
    # array rather than a negative-stride view.
    return np.ascontiguousarray(pixels[..., ::-1])
def download_button_img(np_img_bgr, label, filename):
    """Show a Streamlit download button that serves the image as a PNG.

    The array is converted with ``np_to_pil``, encoded to PNG in memory and
    offered under ``filename`` with the button text ``label``.
    """
    with io.BytesIO() as png_buffer:
        np_to_pil(np_img_bgr).save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    st.download_button(label, data=png_bytes, file_name=filename, mime="image/png")
# Page configuration
st.set_page_config(page_title="Image Restoration App", layout="wide")
st.title("Image Restoration App")

# Upload section
st.sidebar.title("Upload Image")
uploaded_file = st.sidebar.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])

# Make sure every session key read further down exists.
for _state_key in (
    "orig_image",
    "current_image",
    "inpaint_result",
    "canvas_result",
    "upload_signature",
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

# Streamlit re-executes this script on every interaction while the uploader
# keeps returning the same file. Decoding and resetting unconditionally would
# wipe applied filters and inpainting results on the next rerun, so the
# working state is reset only when the upload itself changes.
if uploaded_file is None:
    # Forget the signature so re-uploading the same file starts fresh.
    st.session_state.upload_signature = None
else:
    raw_bytes = uploaded_file.getvalue()
    upload_signature = (uploaded_file.name, len(raw_bytes), hash(raw_bytes))
    if st.session_state.upload_signature != upload_signature:
        image = None
        if raw_bytes:  # cv2.imdecode raises on an empty buffer
            image = cv2.imdecode(
                np.frombuffer(raw_bytes, dtype=np.uint8), cv2.IMREAD_COLOR
            )
        if image is None:
            # imdecode returns None when the data is not a decodable image.
            st.sidebar.error("Could not decode the uploaded file as an image.")
        else:
            st.session_state.orig_image = image
            st.session_state.current_image = image.copy()
            st.session_state.inpaint_result = None
            st.session_state.upload_signature = upload_signature

if st.session_state.orig_image is None:
    st.info("Upload an image to get started.")
    st.stop()
# Tabs
tabs = st.tabs(["Filters", "Inpainting", "Compare"])

# FILTERS TAB
# Left column: filter selection, parameters and apply/reset buttons.
# Right column: preview of the current working image.
with tabs[0]:
    col1, col2 = st.columns([1, 2])
    with col1:
        st.subheader("Filters")
        filter_type = st.selectbox(
            "Choose filter:",
            ["None", "Gaussian", "Median", "Bilateral", "Brightness/Contrast", "Grayscale"],
            key="filter",
        )

        # Only the sliders of the selected filter are rendered; their values
        # are consumed below under the same filter_type branch.
        if filter_type == "Gaussian":
            ksize = st.slider("Kernel Size", 1, 31, 5, step=2, key="gauss_ksize")
            sigma = st.slider("Sigma X", 0.0, 10.0, 2.0, key="gauss_sigma")
        elif filter_type == "Median":
            ksize = st.slider("Kernel Size", 1, 31, 5, step=2, key="median_ksize")
        elif filter_type == "Bilateral":
            d = st.slider("Diameter", 1, 30, 9, key="bilateral_d")
            sigmaColor = st.slider("Sigma Color", 1, 150, 75, key="bilateral_color")
            sigmaSpace = st.slider("Sigma Space", 1, 150, 75, key="bilateral_space")
        elif filter_type == "Brightness/Contrast":
            brightness = st.slider("Brightness", -100, 100, 0, key="brightness")
            contrast = st.slider("Contrast", -100, 100, 0, key="contrast")

        if st.button("Apply Filter", key="apply_filter"):
            img = st.session_state.current_image.copy()
            if filter_type == "Gaussian":
                img = cv2.GaussianBlur(img, (ksize, ksize), sigma)
            elif filter_type == "Median":
                img = cv2.medianBlur(img, ksize)
            elif filter_type == "Bilateral":
                img = cv2.bilateralFilter(img, d, sigmaColor, sigmaSpace)
            elif filter_type == "Brightness/Contrast":
                # cv2.convertScaleAbs takes the absolute value of the scaled
                # result, so darkening would fold negative values back into
                # bright ones. Scale in float and clamp to [0, 255] instead.
                gain = 1 + contrast / 100.0
                adjusted = img.astype(np.float32) * gain + brightness
                img = np.clip(np.rint(adjusted), 0, 255).astype(np.uint8)
            elif filter_type == "Grayscale" and img.ndim == 3:
                # An already single-channel image is left as-is; BGR2GRAY
                # would raise on it.
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            st.session_state.current_image = img
            # Any earlier inpainting result was computed from the old image.
            st.session_state.inpaint_result = None

        if st.button("Reset Image", key="reset_filter"):
            st.session_state.current_image = st.session_state.orig_image.copy()
            st.session_state.inpaint_result = None

    with col2:
        st.subheader("Image Preview")
        img = st.session_state.current_image
        # Grayscale is shown directly; colour images are stored as BGR and
        # flipped to RGB for display.
        st.image(img if img.ndim == 2 else img[..., ::-1], use_container_width=True)
# INPAINTING TAB
def _mask_from_canvas(canvas_result, width, height):
    """Build a 0/255 uint8 mask of size (height, width) from a canvas result.

    The mask is derived from the alpha channel of ``canvas_result.image_data``:
    every pixel with non-zero alpha becomes 255. Returns ``None`` when there
    is no canvas result, no image data, or the data is not 4-channel.
    """
    image_data = getattr(canvas_result, "image_data", None)
    if image_data is None or image_data.ndim != 3 or image_data.shape[-1] != 4:
        return None
    alpha = np.ascontiguousarray(image_data[..., 3])
    # The canvas may be a downscaled view of the image; bring the mask back
    # to the image resolution before thresholding.
    alpha = cv2.resize(alpha, (width, height))
    return (alpha > 0).astype(np.uint8) * 255


# Left column: settings and actions. Middle: drawing canvas or mask preview.
# Right: inpainting result and download.
with tabs[1]:
    col1, col2, col3 = st.columns([1, 1.5, 1.5])
    with col1:
        st.subheader("Inpainting Settings")
        stroke_width = st.slider("Stroke Width", 1, 25, 5, key="stroke")
        method = st.selectbox("Inpainting Method", ["Telea", "NS"], key="inpaint_method")

        if st.button("Apply Inpaint", key="apply_inpaint"):
            source = st.session_state.current_image
            src_h, src_w = source.shape[:2]
            # The canvas is rendered further down, so the stored result is
            # the one captured during the previous script run.
            mask = _mask_from_canvas(st.session_state.get("canvas_result"), src_w, src_h)
            if mask is None or not mask.any():
                st.warning("Draw a mask on the image before applying inpainting.")
            else:
                flag = cv2.INPAINT_TELEA if method == "Telea" else cv2.INPAINT_NS
                st.session_state.inpaint_result = cv2.inpaint(source, mask, 3, flag)

        if st.button("Reset to Original", key="reset_inpaint"):
            st.session_state.current_image = st.session_state.orig_image.copy()
            st.session_state.inpaint_result = None

        st.markdown("---")
        if st.button("Reset Canvas"):
            # A new widget key makes Streamlit create a fresh, empty canvas.
            # Nanosecond resolution keeps two quick resets from reusing a key.
            st.session_state.canvas_key = f"canvas_{time.time_ns()}"
            # Drop the stored strokes too, so a stale mask cannot be
            # previewed or applied while the canvas is hidden.
            st.session_state.canvas_result = None

    with col2:
        st.subheader("Draw Mask")
        h, w = st.session_state.current_image.shape[:2]
        max_width = 500
        scale = min(1.0, max_width / w)
        # Never let a very wide image round a canvas dimension down to 0.
        canvas_w = max(1, int(w * scale))
        canvas_h = max(1, int(h * scale))
        show_mask = st.checkbox("Show Mask Preview", key="show_mask")

        if "canvas_key" not in st.session_state:
            st.session_state.canvas_key = "canvas"

        if not show_mask:
            pil_bg = np_to_pil(st.session_state.current_image).resize((canvas_w, canvas_h))
            canvas = st_canvas(
                fill_color="white",
                stroke_width=stroke_width,
                stroke_color="black",
                background_image=pil_bg,
                update_streamlit=True,
                height=canvas_h,
                width=canvas_w,
                drawing_mode="freedraw",
                key=st.session_state.canvas_key,
            )
            st.session_state.canvas_result = canvas
        else:
            mask = _mask_from_canvas(st.session_state.get("canvas_result"), w, h)
            if mask is not None:
                st.image(mask, caption="Inpainting Mask", use_container_width=True)

    with col3:
        st.subheader("Inpainting Result")
        result = st.session_state.inpaint_result
        if result is not None:
            # A grayscale result has no channel axis; reversing its last axis
            # would mirror it, so only colour results are flipped BGR -> RGB.
            st.image(
                result if result.ndim == 2 else result[..., ::-1],
                use_container_width=True,
            )
            download_button_img(result, "Download Inpainted Image", "inpainted_result.png")
        else:
            st.info("Draw a mask and apply inpainting to see result.")
# COMPARE TAB
# Side-by-side view: the uploaded image on the left, the latest processed
# image (inpainting result if present, otherwise the working image) on the
# right, each with its own download button.
with tabs[2]:
    original_col, processed_col = st.columns(2)

    with original_col:
        st.subheader("Original Image")
        original_bgr = st.session_state.orig_image
        st.image(original_bgr[..., ::-1], use_container_width=True)
        download_button_img(original_bgr, "Download Original", "original.png")

    with processed_col:
        st.subheader("Processed Image")
        processed = st.session_state.inpaint_result
        if processed is None:
            processed = st.session_state.current_image

        # Grayscale is displayed as-is; colour is stored BGR and shown as RGB.
        if processed.ndim == 2:
            shown = processed
        else:
            shown = processed[..., ::-1]
        st.image(shown, use_container_width=True)
        download_button_img(processed, "Download Current", "current.png")