|
|
|
|
|
import os
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
|
|
|
from PIL import Image
|
|
|
from skimage.color import rgb2lab, lab2rgb, rgb2hsv, hsv2rgb
|
|
|
from WB_sRGB.classes import WBsRGB as wb_srgb
|
|
|
from extract_palette import histogram, palette_extraction
|
|
|
from saliency.LDF.infer import Saliency_LDF
|
|
|
from saliency.fast_saliency import get_saliency_ft, get_saliency_mbd
|
|
|
from utils import color_difference
|
|
|
|
|
|
|
|
|
class BaseImage:
    """Image wrapper that extracts a (optionally saliency-split) Lab colour palette.

    On construction the image is loaded, converted to an RGB ``uint8`` array,
    downsampled when either side exceeds 512 px (the shorter side then becomes
    256 px, keeping the aspect ratio) and converted to CIE Lab.  Palette
    extraction happens in :meth:`inital_info` / :meth:`extract_salient_palette`.
    """

    def __init__(self, filepath):
        """Load the image and set the palette-extraction parameters.

        Args:
            filepath: Anything ``PIL.Image.open`` accepts: a path (``str`` /
                ``os.PathLike``) or a file-like object.  When the object has a
                ``.name`` attribute (an open file, an upload object, a
                ``pathlib.Path``), its base name becomes :attr:`filename`.
        """
        name = getattr(filepath, 'name', filepath)
        if isinstance(name, (str, os.PathLike)):
            self.filename = os.path.basename(os.fspath(name))
        else:
            # e.g. an in-memory buffer such as io.BytesIO carries no name.
            self.filename = ''

        self.image = Image.open(filepath)
        # Force 3-channel RGB: palette ('P'), greyscale ('L') and RGBA inputs
        # would otherwise produce arrays that rgb2lab cannot consume.
        # np.array (not np.asarray) copies, so the result is writable.
        self.img_rgb = np.array(self.image.convert('RGB'), dtype=np.uint8)

        # Always defined, so callers can test it whether or not we resized.
        self.if_downsample = False
        anchor = 256
        height, width = self.img_rgb.shape[:2]
        if width > 512 or height > 512:
            self.if_downsample = True
            # Keep the aspect ratio, shrinking the shorter side to `anchor`.
            if width >= height:
                dim = (int(np.floor(width / height * anchor)), anchor)
            else:
                dim = (anchor, int(np.floor(height / width * anchor)))
            # cv2.resize expects dsize as (width, height).
            self.img_rgb = cv2.resize(self.img_rgb, dim, interpolation=cv2.INTER_LINEAR)

        self.img_lab = rgb2lab(self.img_rgb)

        # Palette-extraction parameters (bin_size, mode, distortion_threshold and
        # num_center_ind are forwarded to histogram / palette_extraction).
        self.bin_size = 16
        self.mode = 2
        self.hist_harmonization = False
        self.template = 'L'
        self.distortion_threshold = 0.93
        self.num_center_ind = 7
        # Inserted as column 0 (L) of the palette centres when mode == 2.
        self.lightness = 70

        # Set to True by extract_salient_palette when the WB image is adopted.
        self.applied_wb = False

    def inital_info(self, if_correct_wb, if_saliency, wb_thres, sal_method, sal_thres, valid_class):
        """Run :meth:`extract_salient_palette` and store its six results on ``self``.

        Sets ``hist_value``, ``hist_count``, ``c_center``, ``c_density``,
        ``c_img_label`` and ``sal_links``.
        """
        (self.hist_value, self.hist_count,
         self.c_center, self.c_density,
         self.c_img_label, self.sal_links) = self.extract_salient_palette(
            if_wb=if_correct_wb,
            if_saliency=if_saliency,
            wb_thres=wb_thres,
            sal_method=sal_method,
            sal_thres=sal_thres,
            valid_class=valid_class)

    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    initial_info = inital_info

    def get_rgb_image(self):
        """Return the (possibly downsampled) RGB uint8 image."""
        return self.img_rgb

    def get_lab_image(self):
        """Return the Lab version of the RGB image computed in __init__."""
        return self.img_lab

    def get_wb_image(self):
        """Compute the white-balance-corrected image, store it on self and return it."""
        self.img_wb = self.white_balance_correction()
        return self.img_wb

    def get_saliency(self):
        """Compute the saliency map of img_rgb with the default method, store and return it."""
        self.sal_map = self.saliency_detection(self.img_rgb)
        return self.sal_map

    def get_color_segment(self):
        """Return the palette-coloured segmentation, computing it on first use.

        Requires ``self.center`` / ``self.colorlabel`` (set by
        :meth:`extract_salient_palette`) when it has not been computed yet.
        """
        if getattr(self, 'label_colored', None) is None:
            self.label_colored = self.cal_color_segment()
        return self.label_colored

    def get_label(self):
        """Return the per-pixel palette label map from the last extraction."""
        return self.colorlabel

    def cal_color_segment(self):
        """Paint each pixel with its palette centre and return an RGB uint8 image.

        ``self.center`` rows are treated as Lab colours (they go through
        lab2rgb); pixels whose label matches no centre row stay Lab (0, 0, 0).
        """
        label_colored = np.zeros_like(self.img_rgb, dtype=np.float64)
        for id_color in range(np.size(self.center, 0)):
            label_colored[self.colorlabel == id_color] = self.center[id_color, :]
        label_colored = lab2rgb(label_colored)
        return np.round(label_colored * 255).astype(np.uint8)

    def white_balance_correction(self):
        """Return a white-balance-corrected copy of ``self.img_rgb`` as uint8."""
        upgraded_model = 2
        gamut_mapping = 2
        wbModel = wb_srgb.WBsRGB(gamut_mapping=gamut_mapping,
                                 upgraded=upgraded_model)
        # NOTE(review): img_rgb is RGB-ordered; confirm that correctImage
        # expects RGB (and not OpenCV-style BGR) input.
        img_wb = wbModel.correctImage(self.img_rgb)
        # Clip before the cast so out-of-range values cannot wrap in uint8.
        return np.clip(img_wb * 255, 0, 255).astype(np.uint8)

    def saliency_detection(self, img_rgb, method='LDF'):
        """Return a saliency map for ``img_rgb``.

        Args:
            img_rgb: Image passed unchanged to the chosen saliency backend.
            method: ``'LDF'`` (Saliency_LDF model), ``'ft'``
                (get_saliency_ft), or ``'rbd'`` / ``'mbd'`` (both call
                get_saliency_mbd).

        Raises:
            ValueError: If ``method`` is not one of the names above.
        """
        if method == 'LDF':
            return Saliency_LDF().inference(img_rgb)
        if method == 'ft':
            return get_saliency_ft(img_rgb)
        if method in ('rbd', 'mbd'):
            return get_saliency_mbd(img_rgb)
        raise ValueError(
            "unknown saliency method {!r}; expected 'LDF', 'ft', 'rbd' or 'mbd'".format(method))

    def solve_ind_palette(self, img_rgb, mask_binary=None):
        """Extract a Lab palette from ``img_rgb``, optionally restricted by ``mask_binary``.

        Returns:
            ``(hist_value, hist_count, c_center, c_density, c_img_label,
            histlabel)`` as produced by histogram / palette_extraction, with
            ``c_img_label`` reshaped to the image's (rows, cols) grid and, when
            ``self.mode == 2``, ``self.lightness`` inserted as column 0 of
            ``c_center``.
        """
        height, width = img_rgb.shape[:2]
        img_lab = rgb2lab(img_rgb)

        hist_value, hist_count = histogram(img_lab, self.bin_size, mode=self.mode, mask=mask_binary)

        c_center, c_density, c_img_label, histlabel = palette_extraction(
            img_lab, hist_value, hist_count,
            threshold=self.distortion_threshold,
            num_clusters=self.num_center_ind,
            mode=self.mode,
            mask=mask_binary)

        if self.mode == 2:
            # Centres lack a lightness column in this mode; add a constant one.
            c_center = np.insert(c_center, 0, values=self.lightness, axis=1)

        c_img_label = np.reshape(c_img_label, (height, width))
        return hist_value, hist_count, c_center, c_density, c_img_label, histlabel

    def extract_salient_palette(self, if_wb=False, if_saliency=False, wb_thres=5, sal_method='LDF',
                                sal_thres=0.9, valid_class=(0, 1)):
        """Extract a palette and optionally split it between saliency classes.

        Args:
            if_wb: Try white-balance correction; the corrected image is used
                only when ``color_difference`` to the original exceeds
                ``wb_thres``.
            if_saliency: Split the palette by saliency class.
            wb_thres: Colour-difference threshold for adopting the WB image.
            sal_method: Method name forwarded to :meth:`saliency_detection`.
            sal_thres: Pixels with saliency above this get class 1, others 0.
            valid_class: Saliency class values to build palettes for.

        Returns:
            Without saliency: ``(hist_value, hist_count, center, density,
            colorlabel, sal_links)`` where ``sal_links`` is
            ``[0, 1, ..., n_colors - 1]``.
            With saliency: the same six items, but the first five are lists
            with one entry per ``valid_class`` element.  ``c_img_label[k]``
            maps each pixel of colour ``mapping[k][i]`` to ``i`` (all other
            pixels to 0).  ``sal_links`` concatenates the per-class colour
            indices from the last class to the first.
        """
        img_rgb = self.img_rgb.copy()
        if if_wb:
            self.img_wb = self.white_balance_correction()
            dE = color_difference(img_rgb, self.img_wb)
            # Only adopt the corrected image when the colour shift is significant.
            if dE > wb_thres:
                self.applied_wb = True
                img_rgb = self.img_wb
                print('use white balance correction on {}'.format(self.filename.split('/')[-1]))

        hist_value, hist_count, center, density, colorlabel, histlabel = self.solve_ind_palette(
            img_rgb, mask_binary=None)
        self.center = center
        self.colorlabel = colorlabel
        n_colors = np.size(center, axis=0)

        if not if_saliency:
            return hist_value, hist_count, center, density, colorlabel, list(range(n_colors))

        self.sal_map = self.saliency_detection(img_rgb, method=sal_method)
        label_sem = np.zeros_like(img_rgb[:, :, 0])
        label_sem[self.sal_map > sal_thres] = 1

        # p_feq[k, j]: fraction of class-k pixels labelled with palette colour j.
        p_feq = np.zeros((len(valid_class), n_colors))
        for id_cls, cls in enumerate(valid_class):
            value, count = np.unique(colorlabel[label_sem == cls], return_counts=True)
            p_feq[id_cls, value] = count / count.sum()

        # Each palette colour is assigned to the class where it is most frequent.
        palettelabel = np.argmax(p_feq, axis=0)

        c_center, c_density, c_img_label = [], [], []
        hist_samples, hist_counts, mapping = [], [], []
        for id_cls, _cls in enumerate(valid_class):
            members = np.argwhere(palettelabel == id_cls).flatten()
            mapping.append(members)
            c_center.append(center[members, :])
            c_density.append(density[members])

            hist_samples.append(hist_value.copy())
            counts = hist_count.copy()
            # NOTE(review): this compares histlabel with the *class index*.  If
            # histlabel holds per-bin palette indices (as c_img_label does per
            # pixel), the intended mask is probably ~np.isin(histlabel, members);
            # confirm against palette_extraction before relying on hist_counts.
            counts[histlabel != id_cls] = 0
            hist_counts.append(counts)

            # Re-index this class's colours to 0..len(members)-1.  The array is
            # created once per class so every colour's pixels are kept.
            labels = np.zeros_like(colorlabel)
            for idx, label in enumerate(members):
                labels[colorlabel == label] = idx
            c_img_label.append(labels)

        # Last class first (salient before background for the default (0, 1)).
        sal_links = np.hstack(mapping[::-1])
        return hist_samples, hist_counts, c_center, c_density, c_img_label, sal_links
|
|
|
|
|
|
|
|
|
|