# dxue321's picture
# add saliency options, add image downsampling
# 7598644
# import cv2
import os
import cv2
import numpy as np
from PIL import Image
from skimage.color import rgb2lab, lab2rgb, rgb2hsv, hsv2rgb
from WB_sRGB.classes import WBsRGB as wb_srgb
from extract_palette import histogram, palette_extraction
from saliency.LDF.infer import Saliency_LDF
from saliency.fast_saliency import get_saliency_ft, get_saliency_mbd
from utils import color_difference
class BaseImage:
    """An input image together with the palette / saliency data derived from it.

    Construction loads the image as an RGB ``uint8`` array (``img_rgb``),
    optionally shrinks it, and precomputes its CIELAB version (``img_lab``).
    The palette attributes (``hist_value``, ``hist_count``, ``c_center``,
    ``c_density``, ``c_img_label``, ``sal_links``) only exist after
    ``inital_info`` (alias ``initial_info``) has been called.
    """

    def __init__(self, filepath):
        """Load *filepath* and set the default palette-extraction parameters.

        Args:
            filepath: anything ``PIL.Image.open`` accepts. When it has a
                ``name`` attribute (e.g. an opened/uploaded file) the basename
                of that name is stored in ``filename``; otherwise *filepath*
                itself is treated as the path.
        """
        self.filename = os.path.basename(getattr(filepath, 'name', filepath))
        self.image = Image.open(filepath)
        # Force three RGB channels: grayscale, palette and RGBA images would
        # otherwise yield arrays that rgb2lab cannot process.
        self.img_rgb = np.asarray(self.image.convert('RGB')).astype(dtype=np.uint8)

        anchor = 256    # length of the shorter side after resizing
        max_side = 512  # resizing is only considered above this size
        height, width = self.img_rgb.shape[:2]
        self.if_downsample = False  # always defined, even when no resize happens
        # Scale so the shorter side becomes `anchor`, keeping the aspect ratio.
        # Skip when the shorter side is already <= anchor: resizing would then
        # enlarge the image instead of downsampling it.
        if (width > max_side or height > max_side) and min(width, height) > anchor:
            self.if_downsample = True
            # Exact integer form of floor(ratio * anchor); plain Python ints
            # are what cv2 expects for dsize.
            if width >= height:
                dim = (width * anchor // height, anchor)
            else:
                dim = (anchor, height * anchor // width)
            # cv2.resize takes dsize as (width, height).
            self.img_rgb = cv2.resize(self.img_rgb, dim, interpolation=cv2.INTER_LINEAR)
        self.img_lab = rgb2lab(self.img_rgb)

        # Parameters forwarded to histogram() / palette_extraction().
        self.bin_size = 16
        self.mode = 2                     # mode 2: a fixed lightness column is prepended to centres
        self.hist_harmonization = False   # not read inside this class
        self.template = 'L'               # not read inside this class
        self.distortion_threshold = 0.93  # passed as `threshold`
        self.num_center_ind = 7           # passed as `num_clusters`
        self.lightness = 70               # L value inserted into centres when mode == 2
        self.applied_wb = False           # set True once WB correction is actually used

    def inital_info(self, if_correct_wb, if_saliency, wb_thres, sal_method, sal_thres, valid_class):
        """Extract the palette (optionally split by saliency) and store it on self.

        All arguments are forwarded to :meth:`extract_salient_palette`
        (``if_correct_wb`` as ``if_wb``). The historical spelling of this name
        is kept for callers; ``initial_info`` is an alias.
        """
        (self.hist_value, self.hist_count,
         self.c_center, self.c_density,
         self.c_img_label, self.sal_links) = self.extract_salient_palette(
            if_wb=if_correct_wb,
            if_saliency=if_saliency,
            wb_thres=wb_thres,
            sal_method=sal_method,
            sal_thres=sal_thres,
            valid_class=valid_class)

    # Correctly spelled alias of inital_info.
    initial_info = inital_info

    def get_rgb_image(self):
        """Return the (possibly downsampled) RGB uint8 image."""
        return self.img_rgb

    def get_lab_image(self):
        """Return the CIELAB version of ``img_rgb`` computed at construction."""
        return self.img_lab

    def get_wb_image(self):
        """Compute, store (as ``img_wb``) and return the white-balanced image."""
        self.img_wb = self.white_balance_correction()
        return self.img_wb

    def get_saliency(self):
        """Compute, store (as ``sal_map``) and return the default ('LDF') saliency map."""
        self.sal_map = self.saliency_detection(self.img_rgb)
        return self.sal_map

    def get_color_segment(self):
        """Compute, store (as ``label_colored``) and return the palette-coloured image.

        Requires the palette to have been extracted first (``self.center`` and
        ``self.colorlabel`` are set by :meth:`extract_salient_palette`).
        """
        self.label_colored = self.cal_color_segment()
        return self.label_colored

    def get_label(self):
        """Return the per-pixel palette index map from the last palette extraction."""
        return self.colorlabel

    def cal_color_segment(self):
        """Paint every pixel with its palette centre and convert Lab -> RGB uint8.

        Pixels whose label matches no centre index stay at Lab (0, 0, 0).
        """
        label_colored = np.zeros_like(self.img_rgb, dtype=np.float64)
        for color_idx in range(np.size(self.center, 0)):
            label_colored[self.colorlabel == color_idx] = self.center[color_idx, :]
        label_colored = lab2rgb(label_colored)
        return np.round(label_colored * 255).astype(np.uint8)

    def white_balance_correction(self):
        """Return a white-balance-corrected version of ``self.img_rgb`` as uint8.

        Uses the WB_sRGB ``WBsRGB`` model. Its output (presumably floats in
        [0, 1]) is scaled by 255 and truncated to uint8.
        """
        # Per the WB_sRGB notes, upgraded=1 loads the model upgraded with new
        # training examples. NOTE(review): 2 is used here; confirm which model
        # that actually selects.
        upgraded_model = 2
        # gamut_mapping: 1 = scaling, 2 = clipping (used for the paper's
        # results); scaling is recommended for over-saturated images.
        gamut_mapping = 2
        wbModel = wb_srgb.WBsRGB(gamut_mapping=gamut_mapping,
                                 upgraded=upgraded_model)
        # NOTE(review): img_rgb is passed as-is; confirm the channel order
        # (RGB vs BGR) that correctImage expects and returns.
        img_wb = wbModel.correctImage(self.img_rgb)
        return (img_wb * 255).astype(np.uint8)

    def saliency_detection(self, img_rgb, method='LDF'):
        """Return a saliency map for *img_rgb*.

        Args:
            img_rgb: RGB image array.
            method: 'LDF' (``Saliency_LDF().inference``), 'ft'
                (``get_saliency_ft``), or 'rbd' / 'mbd' (both use
                ``get_saliency_mbd``).

        Raises:
            ValueError: if *method* is not one of the names above.
        """
        if method == 'LDF':
            return Saliency_LDF().inference(img_rgb)
        if method == 'ft':
            return get_saliency_ft(img_rgb)
        if method in ('rbd', 'mbd'):
            return get_saliency_mbd(img_rgb)
        raise ValueError(f"unknown saliency method: {method!r}")

    def solve_ind_palette(self, img_rgb, mask_binary=None):
        """Build the colour histogram of *img_rgb* and extract its palette.

        Args:
            img_rgb: RGB image array of shape (rows, cols, channels).
            mask_binary: optional mask forwarded to ``histogram`` and
                ``palette_extraction``.

        Returns:
            ``(hist_value, hist_count, c_center, c_density, c_img_label,
            histlabel)``. ``c_img_label`` is reshaped to (rows, cols); when
            ``self.mode == 2`` a column holding ``self.lightness`` is prepended
            to ``c_center``.
        """
        rows, cols = img_rgb.shape[:2]
        img_lab = rgb2lab(img_rgb)
        hist_value, hist_count = histogram(img_lab, self.bin_size, mode=self.mode, mask=mask_binary)
        c_center, c_density, c_img_label, histlabel = palette_extraction(
            img_lab, hist_value, hist_count,
            threshold=self.distortion_threshold,
            num_clusters=self.num_center_ind,
            mode=self.mode,
            mask=mask_binary)
        if self.mode == 2:
            c_center = np.insert(c_center, 0, values=self.lightness, axis=1)
        c_img_label = np.reshape(c_img_label, (rows, cols))
        return hist_value, hist_count, c_center, c_density, c_img_label, histlabel

    def extract_salient_palette(self, if_wb=False, if_saliency=False, wb_thres=5,
                                sal_method='LDF', sal_thres=0.9, valid_class=(0, 1)):
        """Extract the colour palette, optionally split into saliency classes.

        Args:
            if_wb: try white-balance correction first. The corrected image is
                used only when its colour difference to the original exceeds
                *wb_thres*.
            if_saliency: split the palette by saliency class.
            wb_thres: colour-difference threshold for applying WB correction.
            sal_method: saliency method, see :meth:`saliency_detection`.
            sal_thres: pixels whose saliency exceeds this get class 1, all
                others class 0.
            valid_class: class ids to build palettes for.

        Returns:
            Without saliency: ``(hist_value, hist_count, center, density,
            colorlabel, sal_links)`` for the whole image, with ``sal_links``
            equal to ``list(range(k))``.
            With saliency: the first five items are lists with one entry per
            class in *valid_class*. ``sal_links`` is an array of original
            palette indices grouped by class in reverse *valid_class* order
            (salient colours first for the default ``(0, 1)``).
        """
        img_rgb = self.img_rgb.copy()
        if if_wb:
            self.img_wb = self.white_balance_correction()
            img_wb = self.img_wb
            dE = color_difference(img_rgb, img_wb)
            print(dE)
            if dE > wb_thres:
                self.applied_wb = True
                img_rgb = img_wb
                print('use white balance correction on {}'.format(self.filename.split('/')[-1]))

        hist_value, hist_count, center, density, colorlabel, histlabel = \
            self.solve_ind_palette(img_rgb, mask_binary=None)
        self.center = center
        self.colorlabel = colorlabel
        num_colors = np.size(center, axis=0)

        if not if_saliency:
            sal_links = list(range(num_colors))
            return hist_value, hist_count, center, density, colorlabel, sal_links

        # Binary saliency labels: 1 where the saliency map exceeds the threshold.
        self.sal_map = self.saliency_detection(img_rgb, method=sal_method)
        label_sem = np.zeros_like(img_rgb[:, :, 0])
        label_sem[self.sal_map > sal_thres] = 1

        # p_feq[c, j]: share of class-c pixels that carry palette colour j.
        class_num = len(valid_class)
        p_feq = np.zeros((class_num, num_colors))
        for id_cls, cls in enumerate(valid_class):
            value, count = np.unique(colorlabel[label_sem == cls], return_counts=True)
            p_feq[id_cls, value] = count / count.sum()
        # Each palette colour goes to the class where it is most frequent
        # (the result is an index into valid_class).
        palettelabel = np.argmax(p_feq, axis=0)

        c_center, c_density, c_img_label = [], [], []
        hist_samples, hist_counts, mapping = [], [], []
        for id_cls in range(class_num):
            members = np.argwhere(palettelabel == id_cls).flatten()
            mapping.append(members)
            c_center.append(center[members, :])
            c_density.append(density[members])
            hist_samples.append(hist_value.copy())
            counts = hist_count.copy()
            # NOTE(review): histlabel is compared with the class index here,
            # not mapped through palettelabel; confirm histlabel holds class
            # ids rather than palette indices.
            counts[histlabel != id_cls] = 0
            hist_counts.append(counts)
            # Relabel this class's pixels to 0..len(members)-1. The array is
            # built once so every member colour's new index is kept; pixels of
            # other classes stay 0.
            if members.size:
                labels = np.zeros_like(colorlabel)
                for idx, label in enumerate(members):
                    labels[colorlabel == label] = idx
                c_img_label.append(labels)
            else:
                c_img_label.append(np.array([]))

        # Concatenate the per-class palette indices, last class first.
        sal_links = np.hstack(mapping[::-1])
        return hist_samples, hist_counts, c_center, c_density, c_img_label, sal_links