dxue321's picture
initial upload
c3a7f7f
# import cv2
import numpy as np
# import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb, rgb2hsv, hsv2rgb
from PIL import Image
def stack_list(x):
    """Vertically stack every non-empty entry of ``x``.

    Args:
        x: Iterable of array-likes. Entries whose ``np.size`` is 0 are
            skipped.

    Returns:
        ``[]`` when no entry is non-empty; the entry itself (unchanged, not
        copied) when exactly one entry is non-empty; otherwise the
        ``np.vstack`` of all non-empty entries in their original order.
    """
    # Collect first and stack once. A running ``stack == []`` check stops
    # being an emptiness test as soon as ``stack`` is an ndarray (the
    # comparison becomes elementwise), and re-stacking inside the loop copies
    # the accumulated rows on every iteration.
    non_empty = [val for val in x if np.size(val) != 0]
    if not non_empty:
        return []
    if len(non_empty) == 1:
        return non_empty[0]
    return np.vstack(non_empty)
def rgb_to_hex(r, g, b):
    """Format three integer channel values as a ``#rrggbb`` string.

    Each channel is written as lowercase hexadecimal, zero-padded to at
    least two digits, in the order red, green, blue.
    """
    channels = (r, g, b)
    return '#' + ''.join(format(channel, '02x') for channel in channels)
def hex_to_rgb(hex):
    """Parse a ``#rrggbb`` string into an ``(r, g, b)`` tuple of ints.

    The character at index 0 is skipped without being inspected; the three
    two-character pairs that follow are each parsed as base-16 numbers.
    """
    return tuple(int(hex[start:start + 2], 16) for start in range(1, 7, 2))
def image_resize(img, c_w, c_h):
    """Resize an image to fit a ``c_w`` x ``c_h`` box, keeping aspect ratio.

    Args:
        img: A PIL image, or a NumPy array (including ndarray subclasses),
            which is first converted with ``Image.fromarray``.
        c_w: Target box width in pixels.
        c_h: Target box height in pixels.

    Returns:
        A PIL image scaled by the smaller of the width and height factors
        using bilinear resampling. Each output dimension is at least one
        pixel.
    """
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)

    # PIL reports size as (width, height).
    width, height = img.size
    factor = min(c_w / width, c_h / height)

    # Plain Python ints for PIL; clamp so that a very elongated image cannot
    # round one side down to zero pixels.
    new_width = max(1, int(np.round(width * factor)))
    new_height = max(1, int(np.round(height * factor)))
    return img.resize((new_width, new_height), Image.BILINEAR)
def get_palette(num_cls):
    """Build the flat colour map used to visualise a segmentation mask.

    Args:
        num_cls: Number of classes.

    Returns:
        A flat list of ``3 * num_cls`` ints laid out as
        ``[r0, g0, b0, r1, g1, b1, ...]``.
    """
    palette = []
    for label in range(num_cls):
        channels = [0, 0, 0]
        remaining = label
        bit_pos = 7
        # Consume the label three bits at a time: bit k of each triple goes to
        # channel k, filling each channel from its most significant bit down.
        while remaining:
            for ch in range(3):
                channels[ch] |= ((remaining >> ch) & 1) << bit_pos
            bit_pos -= 1
            remaining >>= 3
        palette.extend(channels)
    return palette
def visualize_palette(palette_lab, patch_size=20):
    """Render a Lab colour palette as a horizontal strip of square swatches.

    Args:
        palette_lab: ``(n, 3)`` array-like of Lab colours, a single ``(3,)``
            colour, or ``None``.
        patch_size: Side length in pixels of each swatch.

    Returns:
        A float array of shape ``(patch_size, n * patch_size, 3)`` in which
        swatch ``i`` is filled with the ``lab2rgb`` conversion of colour
        ``i``. For ``None`` or an empty palette, a single all-ones swatch of
        shape ``(patch_size, patch_size, 3)`` is returned.
    """
    blank = np.ones((patch_size, patch_size, 3))
    if palette_lab is None:
        return blank

    palette_lab = np.asarray(palette_lab)
    if palette_lab.size == 0:
        # No colours to draw: treat like a missing palette instead of failing
        # later with an unbound result.
        return blank
    if palette_lab.ndim == 1:
        # A single colour is a one-entry palette.
        palette_lab = palette_lab[np.newaxis, :]

    # lab2rgb is fed a 1 x n x 3 "image"; drop that leading axis afterwards.
    palette_rgb = np.squeeze(lab2rgb(np.expand_dims(palette_lab, axis=0)), axis=0)

    # Repeat each colour patch_size times along the strip, then broadcast the
    # strip over patch_size rows (multiplying by ones keeps a float result).
    strip = np.repeat(palette_rgb, patch_size, axis=0)
    return np.ones((patch_size, 1, 1)) * strip
def visualize_palette_rgb(palette_rgb, patch_size=20):
    """Render an RGB colour palette as a horizontal strip of square swatches.

    Args:
        palette_rgb: ``(n, 3)`` array-like of RGB colours, a single ``(3,)``
            colour, or one of the "no palette" markers ``0`` / ``None``.
        patch_size: Side length in pixels of each swatch.

    Returns:
        A float array of shape ``(patch_size, n * patch_size, 3)`` in which
        swatch ``i`` is filled with colour ``i`` (values are passed through
        unchanged). For ``0``, ``None`` or an empty palette, a single
        all-ones swatch of shape ``(patch_size, patch_size, 3)`` is returned.
    """
    blank = np.ones((patch_size, patch_size, 3))
    if palette_rgb is None:
        return blank

    palette_rgb = np.asarray(palette_rgb)
    # Only a *scalar* 0 is the "no palette" marker. Comparing a whole palette
    # array with 0 is elementwise and has no single truth value.
    if palette_rgb.size == 0 or (palette_rgb.ndim == 0 and palette_rgb == 0):
        return blank
    if palette_rgb.ndim == 1:
        # A single colour is a one-entry palette.
        palette_rgb = palette_rgb[np.newaxis, :]

    # Repeat each colour patch_size times along the strip, then broadcast the
    # strip over patch_size rows (multiplying by ones keeps a float result).
    strip = np.repeat(palette_rgb, patch_size, axis=0)
    return np.ones((patch_size, 1, 1)) * strip
# def vis_consistency(img_rgb_all, img_rgb_out_all, label_colored_all, c_center, L_idx):
def color_difference(img1, img2):
    """Return the mean per-pixel Euclidean Lab distance between two images.

    Both inputs are converted with ``rgb2lab``; the distance is taken over
    the three channels of the difference at every pixel and then averaged
    over the ``h * w`` pixels of ``img1``.
    """
    height, width, _ = img1.shape
    delta = rgb2lab(img1) - rgb2lab(img2)
    # Square each channel difference and add them up per pixel.
    squared = delta ** 2
    per_pixel = np.sqrt(squared[:, :, 0] + squared[:, :, 1] + squared[:, :, 2])
    return np.sum(per_pixel) / (height * width)