| """ | |
| HF Space: CLIP Latent Conformity & Likelihood | |
| ------------------------------------------------- | |
| This Gradio app computes (1) conformity-to-mean and (2) relative log-likelihood | |
| for CLIP image/text embeddings. It also supports pairwise comparison in terms of | |
| both metrics. | |
| IMPORTANT (data provenance): The matrices below that drive the likelihood are | |
| loaded from MS-COCO–based statistics. We use the precomputed means and W matrices | |
| provided in the repo: https://github.com/rbetser/W_CLIP/tree/main/w_mats. | |

Definitions used here (aligned with the referenced papers):
- Conformity (per modality) = cosine similarity between a unit-normalized sample
  feature and the corresponding modality mean feature (also unit-normalized).
- Log-likelihood (per modality) is modeled as a Gaussian in whitened space, using
  an MS-COCO-based whitening matrix W (features are treated as row vectors, so
  W W^T approximates the precision matrix):
      d^2(x) = ||(x - mu) W||^2
      loglike_rel(x) = -0.5 * (D * log(2*pi) + d^2(x)),  D = embedding dimension
  (the log-determinant term is omitted, so values are relative within a modality)

Notes:
- Conformity measure is based on the paper: "The Double-Ellipsoid Geometry of CLIP" (https://arxiv.org/abs/2411.14517)
- Likelihood measure is based on the paper: "Whitened CLIP as a Likelihood Surrogate of Images and Captions" (https://arxiv.org/abs/2505.06934)
- CLIP embedding dim is 768 for ViT-L/14.
- We keep modality-specific means (mu_img, mu_txt) and whitening matrices
  (W_img, W_txt). These are loaded at runtime from local `.pt` files shipped
  with the Space.
"""

from __future__ import annotations

import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import CLIPModel, AutoProcessor
import os
import base64
from io import BytesIO
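
# Assumed runtime dependencies (a sketch; exact pins belong in the Space's
# requirements.txt): gradio, torch, numpy, pillow, transformers.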

# ---------------------------
# Load internal statistics (from w_mats)
# ---------------------------
_device = "cuda" if torch.cuda.is_available() else "cpu"

# Paths (must be uploaded to the Space inside a folder named w_mats)
_mean_image_path = "w_mats/mean_image_L14.pt"
_mean_text_path = "w_mats/mean_text_L14.pt"
_w_image_path = "w_mats/w_mat_image_L14.pt"
_w_text_path = "w_mats/w_mat_text_L14.pt"

# Load tensors
_modality_mean_image = torch.load(_mean_image_path, map_location=_device, weights_only=False).cpu().numpy()
_modality_mean_text = torch.load(_mean_text_path, map_location=_device, weights_only=False).cpu().numpy()
_W_image = torch.load(_w_image_path, map_location=_device, weights_only=False).cpu().numpy()
_W_text = torch.load(_w_text_path, map_location=_device, weights_only=False).cpu().numpy()

# Sanity checks
EMB_DIM = 768  # ViT-L/14 feature dimension
assert _modality_mean_image.shape == (EMB_DIM,), f"mu_image must be {EMB_DIM}-D"
assert _modality_mean_text.shape == (EMB_DIM,), f"mu_text must be {EMB_DIM}-D"
assert _W_image.shape == (EMB_DIM, EMB_DIM), f"W_image must be {EMB_DIM}x{EMB_DIM}"
assert _W_text.shape == (EMB_DIM, EMB_DIM), f"W_text must be {EMB_DIM}x{EMB_DIM}"

# ---------------------------
# Model / Processor
# ---------------------------
MODEL_ID = "openai/clip-vit-large-patch14"
_model: CLIPModel | None = None
_processor: AutoProcessor | None = None


def _load_model():
    global _model, _processor
    if _model is None:
        _model = CLIPModel.from_pretrained(MODEL_ID).to(_device).eval()
    if _processor is None:
        _processor = AutoProcessor.from_pretrained(MODEL_ID)


def _l2_normalize(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    return x / (x.norm(dim=-1, keepdim=True) + eps)

# ---------------------------
# Embedding helpers
# ---------------------------
def embed_image(img: Image.Image) -> np.ndarray:
    _load_model()
    inputs = _processor(images=img, return_tensors="pt").to(_device)
    feats = _model.get_image_features(**inputs)  # [1, D]
    # feats = _l2_normalize(feats)
    return feats.squeeze(0).detach().cpu().numpy()


def embed_text(text: str) -> np.ndarray:
    _load_model()
    inputs = _processor(text=[text], return_tensors="pt", padding=True).to(_device)
    feats = _model.get_text_features(**inputs)  # [1, D]
    # feats = _l2_normalize(feats)
    return feats.squeeze(0).detach().cpu().numpy()
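
# Usage sketch (illustrative only; "example.jpg" is a hypothetical path):
#   z_img = embed_image(Image.open("example.jpg").convert("RGB"))
#   z_txt = embed_text("a photo of a dog")
# Both return 768-D NumPy vectors of raw (unnormalized) CLIP features.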

# ---------------------------
# Conformity & Likelihood
# ---------------------------
def _cosine(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> float:
    a = a / (np.linalg.norm(a) + eps)
    b = b / (np.linalg.norm(b) + eps)
    return float(np.dot(a, b))


def conformity_image(z: np.ndarray) -> float:
    return _cosine(z, _modality_mean_image)


def conformity_text(z: np.ndarray) -> float:
    return _cosine(z, _modality_mean_text)
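
# Both conformity scores are cosines in [-1, 1]; values closer to 1 indicate a sample
# embedding that points near the corresponding modality mean direction.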

def loglike_image_relative(z_in_i: np.ndarray) -> float:
    # Convert to torch tensor on the correct device
    z_i = torch.tensor(z_in_i, dtype=torch.float32, device=_device).reshape(1, -1)
    mu_i = torch.tensor(_modality_mean_image, dtype=torch.float32, device=_device).reshape(1, -1)
    W = torch.tensor(_W_image, dtype=torch.float32, device=_device)
    # Center and transform features using the whitening matrix
    cntr_features = z_i - mu_i
    w_features = torch.matmul(cntr_features, W)
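    # Note: torch.sum(w_features**2) below equals ||(z - mu) W||^2, i.e. the squared
    # Mahalanobis distance under the whitening interpretation (features as row vectors).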
    # quad = (cntr_features @ W @ cntr_features.T).squeeze()
    # Compute log-likelihood under the Gaussian assumption
    N = z_i.shape[-1]
    log_like = -0.5 * (N * torch.log(torch.tensor(2 * torch.pi, device=_device)) + torch.sum(w_features**2))
    # Return as a Python float
    return log_like.cpu().numpy().item()

def loglike_text_relative(z_in_t: np.ndarray) -> float:
    # Convert to torch tensor on the correct device
    z_t = torch.tensor(z_in_t, dtype=torch.float32, device=_device).reshape(1, -1)
    mu_t = torch.tensor(_modality_mean_text, dtype=torch.float32, device=_device).reshape(1, -1)
    W = torch.tensor(_W_text, dtype=torch.float32, device=_device)
    # Center and transform features using the whitening matrix
    cntr_features = z_t - mu_t
    w_features = torch.matmul(cntr_features, W)
    # quad = (cntr_features @ W @ cntr_features.T).squeeze()
    # Compute log-likelihood under the Gaussian assumption
    N = z_t.shape[-1]
    log_like = -0.5 * (N * torch.log(torch.tensor(2 * torch.pi, device=_device)) + torch.sum(w_features**2))
    # Return as a Python float
    return log_like.cpu().numpy().item()
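
# Usage sketch (illustrative only): for an embedding z of the matching modality,
#   conformity_image(z) / conformity_text(z)              -> cosine to the modality mean
#   loglike_image_relative(z) / loglike_text_relative(z)  -> relative log-likelihood
# The log-determinant term is omitted, so likelihood values are only comparable
# between samples of the same modality, not across modalities.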

# ---------------------------
# Gradio logic
# ---------------------------
DESC = """
This Space operates on the **CLIP ViT-L/14** latent space and computes two metrics per modality:
1. **Conformity** — how typical (common) the sample is (based on [The Double-Ellipsoid Geometry of CLIP](https://arxiv.org/abs/2411.14517))
2. **Log-Likelihood** — how likely the sample is under an MS-COCO-based Gaussian model (based on [Whitened CLIP as a Likelihood Surrogate of Images and Captions](https://arxiv.org/abs/2505.06934))

All modality means and W matrices are stored *internally* and loaded from `w_mats/*.pt`.
"""

PROVENANCE = """
**Data provenance**
Modality means and whitening matrices (W) are computed from **MS-COCO** features.
They are loaded from precomputed `.pt` files in the Space repo.
"""


def analyze_single(modality: str, text: str, image: Image.Image):
    if modality == "Image":
        if image is None:
            return {"Error": "Please upload an image."}, None
        z = embed_image(image)
        conf = conformity_image(z)
        ll = loglike_image_relative(z)
    else:
        if not text:
            return {"Error": "Please enter text."}, None
        z = embed_text(text)
        conf = conformity_text(z)
        ll = loglike_text_relative(z)
    report = {
        "Modality": modality,
        "Conformity (cosine to mu)": round(conf, 6),
        "Rel. Log-Likelihood (MS-COCO W)": round(ll, 6),
    }
    summary = f"Conformity: {conf:.6f} | Log-likelihood: {ll:.6f}"
    return report, summary


def compare_pair_gui(modality: str, text1: str, image1: Image.Image, text2: str, image2: Image.Image):
    from io import BytesIO
    import base64

    # Prepare images if modality is Image
    img1_html = ""
    img2_html = ""
    if modality == "Image":
        if image1 is None or image2 is None:
            return "<p style='color:red'>Please upload both images.</p>"
        # Convert first image to base64
        buf1 = BytesIO()
        image1.save(buf1, format="PNG")
        img1_b64 = base64.b64encode(buf1.getvalue()).decode()
        img1_html = f"<img src='data:image/png;base64,{img1_b64}' width='150px' style='border:1px solid #ccc; border-radius:8px;'/>"
        # Convert second image to base64
        buf2 = BytesIO()
        image2.save(buf2, format="PNG")
        img2_b64 = base64.b64encode(buf2.getvalue()).decode()
        img2_html = f"<img src='data:image/png;base64,{img2_b64}' width='150px' style='border:1px solid #ccc; border-radius:8px;'/>"
        z1 = embed_image(image1)
        z2 = embed_image(image2)
        c1, c2 = conformity_image(z1), conformity_image(z2)
        l1, l2 = loglike_image_relative(z1), loglike_image_relative(z2)
    else:
        if not text1 or not text2:
            return "<p style='color:red'>Please enter both texts.</p>"
        z1 = embed_text(text1)
        z2 = embed_text(text2)
        c1, c2 = conformity_text(z1), conformity_text(z2)
        l1, l2 = loglike_text_relative(z1), loglike_text_relative(z2)
    # Build HTML output
    html = f"""
    <div style='display:flex; gap:20px; min-height:150px; padding:10px; border:1px solid #eee; border-radius:8px;'>
      <div style='text-align:center;'>
        {img1_html if modality=="Image" else "<div style='min-height:50px'></div>"}
        <p><b>#1 {modality}:</b></p>
        <p>Conformity: {c1:.6f}</p>
        <p>Log-Likelihood: {l1:.6f}</p>
      </div>
      <div style='text-align:center;'>
        {img2_html if modality=="Image" else "<div style='min-height:50px'></div>"}
        <p><b>#2 {modality}:</b></p>
        <p>Conformity: {c2:.6f}</p>
        <p>Log-Likelihood: {l2:.6f}</p>
      </div>
      <div style='text-align:center;'>
        <p><b>Δ (2-1)</b></p>
        <p>Δ Conformity: {c2-c1:.6f}</p>
        <p>Δ Log-Likelihood: {l2-l1:.6f}</p>
      </div>
    </div>
    """
    return html


with gr.Blocks(
    title="CLIP Latent: Conformity & Likelihood (ViT-L/14)",
    css="""
    #result-box, #result-cmp {
        min-height: 200px;
        padding: 10px;
        border: 1px solid #eee;
        border-radius: 8px;
    }
    """
) as demo:
    gr.Markdown(f"# CLIP Latent Space — Conformity & Likelihood (ViT-L/14)\n\n{DESC}\n\n{PROVENANCE}")

    with gr.Tab("Single Input"):
        modality = gr.Radio(["Image", "Text"], value="Image", label="Modality")
        img_in = gr.Image(type="pil", label="Image", visible=True)
        txt_in = gr.Textbox(label="Text", visible=False)
        btn = gr.Button("Analyze")
        result_out = gr.HTML("<p>Result will appear here</p>", elem_id="result-box")

        # Update function must be inside the Blocks context
        def update_inputs(mod):
            return gr.update(visible=(mod=="Image")), gr.update(visible=(mod=="Text"))

        modality.change(fn=update_inputs, inputs=[modality], outputs=[img_in, txt_in])

        # Analysis function inside the same context
        def analyze_single_gui(modality: str, text: str, image: Image.Image):
            # -------------------
            # Embed input
            # -------------------
            if modality == "Image":
                if image is None:
                    return "<p style='color:red'>Please upload an image.</p>"
                z = embed_image(image)
                conf = conformity_image(z)
                ll = loglike_image_relative(z)
                buf_img = BytesIO()
                image.save(buf_img, format="PNG")
                img_b64 = base64.b64encode(buf_img.getvalue()).decode()
                img_html = f"<img src='data:image/png;base64,{img_b64}' width='200px' style='border:1px solid #ccc; border-radius:8px;'/>"
            else:
                if not text.strip():
                    return "<p style='color:red'>Please enter text.</p>"
                z = embed_text(text)
                conf = conformity_text(z)
                ll = loglike_text_relative(z)
                img_html = "<div style='min-height:50px; width:200px;'></div>"
            # -------------------
            # Load conformity distribution
            # -------------------
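            # Note: these .npy files hold precomputed conformity values over a reference
            # set (assumed MS-COCO, matching the other statistics); they are optional and,
            # if missing, a flat placeholder histogram is rendered instead.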
            conf_file = "w_mats/conformity_image.npy" if modality == "Image" else "w_mats/conformity_text.npy"
            if os.path.exists(conf_file):
                all_confs = np.load(conf_file, allow_pickle=False)
                bins = np.linspace(np.min(all_confs), np.max(all_confs), 51)  # 50 bins
                hist, _ = np.histogram(all_confs, bins=bins)
                max_count = hist.max()
            else:
                hist = np.ones(50)
                bins = np.linspace(0, 1, 51)
                max_count = 1
            # Determine which bin the input falls into
            input_bin = np.digitize(conf, bins) - 1
            input_bin = np.clip(input_bin, 0, 49)
            # Build HTML mini histogram
            bar_width = 100 / 50  # percentage
            bars_html = ""
            for i in range(50):
                height = (hist[i] / max_count) * 100  # scale to 100%
                color = "#FF4136" if i == input_bin else "#4C72B0"
                bars_html += f"<div style='display:inline-block; width:{bar_width}%; height:{height}%; background:{color}; margin:0 0.2px; vertical-align:bottom;'></div>"
            hist_html = f"""
            <p><b>Conformity: {conf:.6f}</b></p>
            <div style='width:300px; height:100px; border:1px solid #ddd; border-radius:5px; background:#f9f9f9; display:flex; align-items:flex-end;'>
                {bars_html}
            </div>
            """
            # -------------------
            # Combine HTML
            # -------------------
            html = f"""
            <div style='display:flex; gap:20px; align-items:flex-start; min-height:200px; padding:10px; border:1px solid #eee; border-radius:8px;'>
                {img_html}
                <div style='flex:1;'>
                    <p><b>Modality:</b> {modality}</p>
                    <p><b>Log-Likelihood:</b> {ll:.6f}</p>
                    {hist_html}
                </div>
            </div>
            """
            return html

        btn.click(analyze_single_gui, inputs=[modality, txt_in, img_in], outputs=[result_out])
| with gr.Tab("Compare Two"): | |
| modality_c = gr.Radio(["Image", "Text"], value="Image", label="Modality") | |
| img1 = gr.Image(type="pil", label="#1 Image", visible=True) | |
| txt1 = gr.Textbox(label="#1 Text", visible=False) | |
| img2 = gr.Image(type="pil", label="#2 Image", visible=True) | |
| txt2 = gr.Textbox(label="#2 Text", visible=False) | |
| result_cmp = gr.HTML("<p>Comparison result will appear here</p>", elem_id="result-cmp") | |
| def update_compare_inputs(mod): | |
| return (gr.update(visible=(mod=="Image")), # img1 | |
| gr.update(visible=(mod=="Text")), # txt1 | |
| gr.update(visible=(mod=="Image")), # img2 | |
| gr.update(visible=(mod=="Text"))) # txt2 | |
| modality_c.change(fn=update_compare_inputs, | |
| inputs=[modality_c], | |
| outputs=[img1, txt1, img2, txt2]) | |

        def compare_pair_gui(modality: str, text1: str, image1: Image.Image, text2: str, image2: Image.Image):
            from io import BytesIO
            import base64

            if modality == "Image":
                if image1 is None or image2 is None:
                    return "<p style='color:red'>Please upload both images.</p>"

                def img_to_html(img):
                    buf = BytesIO()
                    img.save(buf, format="PNG")
                    img_b64 = base64.b64encode(buf.getvalue()).decode()
                    return f"<img src='data:image/png;base64,{img_b64}' width='150px' style='border:1px solid #ccc; border-radius:8px;'/>"

                img1_html = img_to_html(image1)
                img2_html = img_to_html(image2)
                z1 = embed_image(image1)
                z2 = embed_image(image2)
                c1, c2 = conformity_image(z1), conformity_image(z2)
                l1, l2 = loglike_image_relative(z1), loglike_image_relative(z2)
            else:
                if not text1 or not text2:
                    return "<p style='color:red'>Please enter both texts.</p>"
                z1 = embed_text(text1)
                z2 = embed_text(text2)
                c1, c2 = conformity_text(z1), conformity_text(z2)
                l1, l2 = loglike_text_relative(z1), loglike_text_relative(z2)
                img1_html = img2_html = "<div style='min-height:50px'></div>"
            html = f"""
            <div style='display:flex; gap:20px; min-height:150px; padding:10px; border:1px solid #eee; border-radius:8px;'>
              <div style='text-align:center;'>{img1_html}<p><b>#1 {modality}</b></p>
                <p>Conformity: {c1:.6f}</p><p>Log-Likelihood: {l1:.6f}</p>
              </div>
              <div style='text-align:center;'>{img2_html}<p><b>#2 {modality}</b></p>
                <p>Conformity: {c2:.6f}</p><p>Log-Likelihood: {l2:.6f}</p>
              </div>
              <div style='text-align:center;'>
                <p><b>Δ (2-1)</b></p>
                <p>Δ Conformity: {c2-c1:.6f}</p>
                <p>Δ Log-Likelihood: {l2-l1:.6f}</p>
              </div>
            </div>
            """
            return html

        btn_c = gr.Button("Compare")
        btn_c.click(compare_pair_gui, inputs=[modality_c, txt1, img1, txt2, img2], outputs=[result_cmp])

    gr.Markdown(
        """
        **Implementation details:**
        - Embeddings: `openai/clip-vit-large-patch14` via 🤗 Transformers; raw (unnormalized) features are used.
        - Conformity: cosine similarity to the stored modality means `mu_image`, `mu_text` (unit-normalization happens inside the cosine).
        - Log-likelihood: `-0.5 * (D*log(2*pi) + ||(x - mu) W||^2)` using MS-COCO-based whitening matrices `W`.
        """
    )

if __name__ == "__main__":
    demo.launch(share=True)