Spaces:
Running
on
Zero
Running
on
Zero
| import os, sys | |
| sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import numpy as np | |
| import networkx as nx | |
| from io import BytesIO | |
| from PIL import Image, ImageDraw | |
| from matplotlib import pyplot as plt | |
| from sklearn.decomposition import PCA | |
| from singapo_utils.refs import graph_color_ref | |
def add_text(text, imgarr):
    '''
    Draw a text label onto an image.
    Args:
        - text (str): label to draw
        - imgarr (np.array): source image array, converted with PIL's Image.fromarray
    Returns:
        - img (np.array): array view of the image with the label drawn on it
    '''
    canvas = Image.fromarray(imgarr)
    # The label is anchored at pixel (10, 10) and filled in black.
    ImageDraw.Draw(canvas).text((10, 10), text, fill='black')
    return np.asarray(canvas)
def get_color(ref, n_nodes):
    '''
    Build one normalized color per node from a reference palette.
    Args:
        - ref (list): color strings; each entry is parsed by dropping its first
          four characters and its last one, then splitting the rest on commas
          (presumably 'rgb(r,g,b)' strings -- confirm against the palette source)
        - n_nodes (int): number of colors to produce
    Returns:
        - colors (list): n_nodes arrays, each holding one row of channel values
          divided by 255; the palette is reused cyclically when n_nodes > len(ref)
    '''
    palette_size = len(ref)

    def parse(entry):
        # Strip the 4-character prefix and the trailing character, keep the integers.
        channels = [int(part) for part in entry[4:-1].split(',')]
        return np.array([channels]) / 255.

    return [parse(ref[node % palette_size]) for node in range(n_nodes)]
def make_grid(images, cols=5):
    """
    Arrange a list of images into a grid with `cols` columns.

    The number of rows is rounded up so that every image has a cell; cells of
    the last row that have no image stay zero-filled. (Flooring the row count
    made the grid too short whenever len(images) was not a multiple of cols,
    and the assignment for the overflowing images failed.)

    Args:
    - images (list): Numpy arrays of shape (H, W, 3), all with the same size
      as the first image.
    - cols (int): Number of columns for the grid.
    Returns:
    - grid (numpy array): Array of shape (rows * H, cols * W, 3) with the dtype
      of the first image, filled row by row, left to right.
    Raises:
    - IndexError: if `images` is empty.
    """
    # Cell size is taken from the first image
    img_h, img_w, _ = images[0].shape
    # Ceiling division: a partially filled last row still needs a full row of cells
    rows = (len(images) + cols - 1) // cols
    # Blank (zero) canvas; unused cells keep this value
    grid = np.zeros((rows * img_h, cols * img_w, 3), dtype=images[0].dtype)
    for idx, img in enumerate(images):
        row, col = divmod(idx, cols)
        top = row * img_h
        left = col * img_w
        grid[top: top + img_h, left: left + img_w] = img
    return grid
def viz_graph(info_dict, res=256):
    '''
    Function to plot the directed graph
    Args:
        - info_dict (dict): output json containing the graph information; its
          'diffuse_tree' entry is a list of nodes, each with an 'id' and a
          list of 'children' ids
        - res (int): resolution of the image (the figure is res/100 inches per
          side and is saved at 100 dpi)
    Returns:
        - img_arr (np.array): rendered image array, first three channels only
    '''
    # build tree
    tree = info_dict['diffuse_tree']
    edges = [(node['id'], child) for node in tree for child in node['children']]
    G = nx.DiGraph()
    G.add_edges_from(edges)
    # add_edges_from only creates nodes that appear in an edge. Register every
    # tree node as well so a childless root (single-node tree) is still drawn.
    # Done after the edges so the order of already-connected nodes is unchanged.
    G.add_nodes_from(node['id'] for node in tree)
    node_order = sorted(G.nodes())
    # One color per drawn node, so node_color always matches nodelist in length
    colors = get_color(graph_color_ref, len(node_order))
    # plot tree
    fig = plt.figure(figsize=(res / 100, res / 100))
    try:
        pos = nx.nx_agraph.graphviz_layout(G, prog="twopi", args="")
        nx.draw(G, pos, node_color=colors, nodelist=node_order, edge_color='k', with_labels=False)
        with BytesIO() as buf:
            plt.savefig(buf, format="png", dpi=100)
            buf.seek(0)
            # Decode while the buffer is still open; PIL loads lazily
            img_arr = np.asarray(Image.open(buf))
    finally:
        # Always release the figure, even if layout/drawing/saving failed
        plt.close(fig)
    return img_arr[:, :, :3]
def viz_patch_feat_pca(feat):
    '''
    Function to visualize patch features as an RGB image via PCA
    Args:
        - feat (array-like): one feature vector per patch, shape
          (n_patches, n_features); n_patches must be a perfect square
          (e.g. 256 for a 16 x 16 grid)
    Returns:
        - img_array (np.array): uint8 array of shape (side, side, 3) where
          side * side == n_patches; each channel is one of the first three
          principal components, min-max scaled to [0, 255]
    Raises:
        - ValueError: if the number of patches is not a perfect square
    '''
    pca = PCA(n_components=3)
    pca.fit(feat)
    projected = np.array(pca.transform(feat))
    low = projected.min(axis=0, keepdims=True)
    high = projected.max(axis=0, keepdims=True)
    span = high - low
    # A constant component has zero range; dividing by it would give NaN.
    # Use 1 instead so that channel becomes all zeros.
    span[span == 0] = 1
    normalized = (projected - low) / span
    pixels = (normalized * 255).astype(np.uint8)
    # Infer the square grid side from the patch count instead of assuming 16 x 16
    n_patches = pixels.shape[0]
    side = int(round(np.sqrt(n_patches)))
    if side * side != n_patches:
        raise ValueError(
            f"cannot arrange {n_patches} patches into a square grid"
        )
    return pixels.reshape(side, side, 3)