import os

import cv2
import numpy as np
from PIL import Image
from skimage.color import rgb2lab, lab2rgb, rgb2hsv, hsv2rgb

from WB_sRGB.classes import WBsRGB as wb_srgb
from extract_palette import histogram, palette_extraction
from saliency.LDF.infer import Saliency_LDF
from saliency.fast_saliency import get_saliency_ft, get_saliency_mbd
from utils import color_difference


class BaseImage:
    """Image wrapper that extracts a colour palette, optionally split by saliency.

    The image is loaded once in ``__init__``; ``inital_info`` then runs the
    palette extraction and stores its results on the instance.
    """

    def __init__(self, filepath):
        """Load the image and prepare its RGB / Lab arrays.

        Args:
            filepath: Either an object exposing a ``name`` attribute that
                ``PIL.Image.open`` can read (e.g. an opened or uploaded file),
                or a plain path.
        """
        # Accept both file-like objects (which carry ``.name``) and raw paths.
        self.filename = os.path.basename(getattr(filepath, 'name', filepath))
        self.image = Image.open(filepath)

        # rgb2lab needs exactly three channels, so grayscale / RGBA / palette
        # images are converted for the working array; ``self.image`` is left
        # as it was opened.
        if self.image.mode == 'RGB':
            rgb_image = self.image
        else:
            rgb_image = self.image.convert('RGB')
        self.img_rgb = np.asarray(rgb_image).astype(dtype=np.uint8)

        anchor = 256
        height, width = self.img_rgb.shape[:2]

        # Always defined, so callers can read the flag for small images too.
        self.if_downsample = False
        if width > 512 or height > 512:
            self.if_downsample = True
            # The shorter side becomes ``anchor``; the longer side keeps the
            # aspect ratio. cv2.resize takes (width, height) as plain ints.
            if width >= height:
                dim = (int(np.floor(width / height * anchor)), anchor)
            else:
                dim = (anchor, int(np.floor(height / width * anchor)))
            self.img_rgb = cv2.resize(self.img_rgb, dim,
                                      interpolation=cv2.INTER_LINEAR)

        self.img_lab = rgb2lab(self.img_rgb)

        # Settings forwarded to histogram() / palette_extraction().
        self.bin_size = 16
        self.mode = 2
        self.hist_harmonization = False
        self.template = 'L'
        self.distortion_threshold = 0.93
        self.num_center_ind = 7
        # Constant inserted as the first palette channel when ``mode == 2``.
        self.lightness = 70
        # Set to True once the white-balanced image replaces the original.
        self.applied_wb = False

    def inital_info(self, if_correct_wb, if_saliency, wb_thres, sal_method,
                    sal_thres, valid_class):
        """Run ``extract_salient_palette`` and store its six results.

        The results are kept as ``hist_value``, ``hist_count``, ``c_center``,
        ``c_density``, ``c_img_label`` and ``sal_links``.
        """
        (self.hist_value, self.hist_count,
         self.c_center, self.c_density,
         self.c_img_label, self.sal_links) = self.extract_salient_palette(
            if_wb=if_correct_wb,
            if_saliency=if_saliency,
            wb_thres=wb_thres,
            sal_method=sal_method,
            sal_thres=sal_thres,
            valid_class=valid_class)

    def get_rgb_image(self):
        """Return the (possibly downsampled) RGB array."""
        return self.img_rgb

    def get_lab_image(self):
        """Return the Lab conversion of ``img_rgb`` computed in ``__init__``."""
        return self.img_lab

    def get_wb_image(self):
        """Recompute, store and return the white-balanced image."""
        self.img_wb = self.white_balance_correction()
        return self.img_wb

    def get_saliency(self):
        """Recompute, store and return the saliency map (default method)."""
        self.sal_map = self.saliency_detection(self.img_rgb)
        return self.sal_map

    def get_color_segment(self):
        """Recompute, store and return the palette-coloured segmentation.

        Requires ``extract_salient_palette`` (or ``inital_info``) to have run,
        because ``cal_color_segment`` reads ``self.center`` and
        ``self.colorlabel``.
        """
        self.label_colored = self.cal_color_segment()
        return self.label_colored

    def get_label(self):
        """Return the per-pixel palette index map."""
        return self.colorlabel

    def cal_color_segment(self):
        """Paint every pixel with its palette colour and return uint8 RGB."""
        label_colored = np.zeros_like(self.img_rgb, dtype=np.float64)
        for id_color in range(np.size(self.center, 0)):
            label_colored[self.colorlabel == id_color] = self.center[id_color, :]
        # Palette centres are passed through lab2rgb, i.e. treated as Lab.
        label_colored = lab2rgb(label_colored)
        return np.round(label_colored * 255).astype(np.uint8)

    def white_balance_correction(self):
        """Return ``img_rgb`` corrected by the WB_sRGB model, as uint8."""
        # NOTE(review): the upstream comment only describes upgraded_model = 1
        # ("load the new model upgraded with new training examples"); confirm
        # what the value 2 selects in this project's WBsRGB.
        upgraded_model = 2
        # gamut_mapping: 1 = scaling, 2 = clipping. Scaling is recommended for
        # over-saturated images.
        gamut_mapping = 2

        wb_model = wb_srgb.WBsRGB(gamut_mapping=gamut_mapping,
                                  upgraded=upgraded_model)
        img_wb = wb_model.correctImage(self.img_rgb)
        # presumably correctImage returns floats in [0, 1] — TODO confirm
        return (img_wb * 255).astype(np.uint8)

    def saliency_detection(self, img_rgb, method='LDF'):
        """Compute a saliency map for ``img_rgb``.

        Args:
            img_rgb: Image array handed to the selected saliency backend.
            method: ``'LDF'``, ``'ft'`` or ``'rbd'`` (``'mbd'`` is accepted as
                an alias of ``'rbd'``).

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        if method == 'LDF':
            return Saliency_LDF().inference(img_rgb)
        if method == 'ft':
            return get_saliency_ft(img_rgb)
        if method in ('rbd', 'mbd'):
            return get_saliency_mbd(img_rgb)
        raise ValueError(
            "unknown saliency method {!r}; expected 'LDF', 'ft' or 'rbd'".format(method))

    def solve_ind_palette(self, img_rgb, mask_binary=None):
        """Build the colour histogram of ``img_rgb`` and extract its palette.

        Returns:
            ``(hist_value, hist_count, c_center, c_density, c_img_label,
            histlabel)`` where ``c_img_label`` is reshaped to the image's
            first two dimensions.
        """
        rows, cols = img_rgb.shape[:2]
        img_lab = rgb2lab(img_rgb)

        hist_value, hist_count = histogram(img_lab, self.bin_size,
                                           mode=self.mode, mask=mask_binary)
        c_center, c_density, c_img_label, histlabel = palette_extraction(
            img_lab, hist_value, hist_count,
            threshold=self.distortion_threshold,
            num_clusters=self.num_center_ind,
            mode=self.mode, mask=mask_binary)

        if self.mode == 2:
            # Prepend the fixed lightness as channel 0 of every centre so that
            # lab2rgb can later be applied to them.
            c_center = np.insert(c_center, 0, values=self.lightness, axis=1)

        c_img_label = np.reshape(c_img_label, (rows, cols))
        return hist_value, hist_count, c_center, c_density, c_img_label, histlabel

    def extract_salient_palette(self, if_wb=False, if_saliency=False, wb_thres=5,
                                sal_method='LDF', sal_thres=0.9, valid_class=None):
        """Extract the palette, optionally split into saliency classes.

        Args:
            if_wb: Compute the white-balanced image and use it when its colour
                difference to the original exceeds ``wb_thres``.
            if_saliency: Split the palette between the classes in
                ``valid_class`` using a thresholded saliency map.
            wb_thres: Colour-difference threshold for adopting white balance.
            sal_method: Passed to ``saliency_detection``.
            sal_thres: Pixels whose saliency exceeds this value get class 1,
                all other pixels class 0.
            valid_class: Class ids to build palettes for; defaults to
                ``[0, 1]``.

        Returns:
            Without saliency: ``(hist_value, hist_count, center, density,
            colorlabel, sal_links)`` with ``sal_links`` a list of all palette
            indices. With saliency: the same six slots, where the first five
            are lists holding one entry per class and ``sal_links`` is an
            array of palette indices.
        """
        if valid_class is None:
            valid_class = [0, 1]

        img_rgb = self.img_rgb.copy()
        if if_wb:
            self.img_wb = self.white_balance_correction()
            img_wb = self.img_wb
            dE = color_difference(img_rgb, img_wb)
            print(dE)
            if dE > wb_thres:
                self.applied_wb = True
                img_rgb = img_wb
                print('use white balance correction on {}'.format(self.filename.split('/')[-1]))

        (hist_value, hist_count, center, density,
         colorlabel, histlabel) = self.solve_ind_palette(img_rgb, mask_binary=None)
        self.center = center
        self.colorlabel = colorlabel

        sal_links = list(range(np.size(center, axis=0)))
        if not if_saliency:
            return hist_value, hist_count, center, density, colorlabel, sal_links

        self.sal_map = self.saliency_detection(img_rgb, method=sal_method)
        label_sem = np.zeros_like(img_rgb[:, :, 0])
        label_sem[self.sal_map > sal_thres] = 1

        # p_feq[k, j]: share of the pixels of class k that carry palette
        # colour j. Each colour is then assigned to the class where that
        # share is largest.
        p_feq = np.zeros((len(valid_class), np.size(center, axis=0)))
        for id_cls, cls in enumerate(valid_class):
            colorlabel_cls = colorlabel[label_sem == cls]
            value, count = np.unique(colorlabel_cls, return_counts=True)
            p_feq[id_cls, value] = count / count.sum()
        palettelabel = np.argmax(p_feq, axis=0)

        c_center = []
        c_density = []
        c_img_label = []
        hist_samples = []
        hist_counts = []
        mapping = []
        for id_cls, cls in enumerate(valid_class):
            members = np.argwhere(palettelabel == id_cls).flatten()
            mapping.append(members)
            c_center.append(center[members, :])
            c_density.append(density[members])
            hist_samples.append(hist_value.copy())

            cls_counts = hist_count.copy()
            # NOTE(review): histlabel is compared with the class position, not
            # with the palette indices in ``members`` — confirm against what
            # palette_extraction returns as its fourth value.
            cls_counts[histlabel != id_cls] = 0
            hist_counts.append(cls_counts)

            # Re-index this class's colours to 0..len(members)-1. The array is
            # created once, before the loop, so every member keeps its label.
            # Pixels of other classes stay 0, the same value as local index 0.
            labels = np.zeros_like(colorlabel)
            for idx, label in enumerate(members):
                labels[colorlabel == label] = idx
            c_img_label.append(labels)

        # NOTE(review): assumes exactly two classes; class 1's palette indices
        # are listed before class 0's.
        sal_links = np.hstack((mapping[1], mapping[0]))
        return hist_samples, hist_counts, c_center, c_density, c_img_label, sal_links