File size: 10,148 Bytes
c3a7f7f
 
7598644
c3a7f7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7598644
 
 
 
 
 
 
 
 
 
 
 
c3a7f7f
 
 
 
 
 
 
 
 
 
 
 
7598644
 
c3a7f7f
 
 
 
7598644
c3a7f7f
 
 
 
 
7598644
c3a7f7f
 
 
7598644
c3a7f7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7598644
c3a7f7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7598644
c3a7f7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7598644
 
c3a7f7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
# import cv2
import os
import cv2
import numpy as np

from PIL import Image
from skimage.color import rgb2lab, lab2rgb, rgb2hsv, hsv2rgb
from WB_sRGB.classes import WBsRGB as wb_srgb
from extract_palette import histogram, palette_extraction
from saliency.LDF.infer import Saliency_LDF
from saliency.fast_saliency import get_saliency_ft, get_saliency_mbd
from utils import color_difference


class BaseImage:
    """Image wrapper used to extract a colour palette, optionally split by saliency.

    On construction the image is loaded into an RGB ``uint8`` array. When
    either side exceeds 512 px it is resized so the shorter side becomes
    256 px, preserving aspect ratio. It is then converted to CIELAB.

    Palette extraction is deferred to :meth:`inital_info` /
    :meth:`extract_salient_palette`.
    """

    def __init__(self, filepath):
        """Load the image at *filepath* and prepare its RGB / LAB arrays.

        Args:
            filepath: anything ``PIL.Image.open`` accepts. If it has a ``name``
                attribute (e.g. an open file object), that is used to derive
                ``self.filename``; otherwise *filepath* itself is treated as
                the path.
        """
        source_name = getattr(filepath, 'name', filepath)
        self.filename = os.path.basename(source_name)
        self.image = Image.open(filepath)
        # Force three channels: RGBA / palette / greyscale inputs would otherwise
        # give a non-(H, W, 3) array, which rgb2lab below rejects.
        self.img_rgb = np.asarray(self.image.convert('RGB')).astype(dtype=np.uint8)

        anchor = 256
        height, width = self.img_rgb.shape[:2]
        # Defined unconditionally so callers can read it for small images too.
        self.if_downsample = False
        if width > 512 or height > 512:
            self.if_downsample = True
            if width >= height:
                dim = (int(np.floor(width / height * anchor)), anchor)
            else:
                dim = (anchor, int(np.floor(height / width * anchor)))
            # cv2.resize takes dsize as (width, height).
            self.img_rgb = cv2.resize(self.img_rgb, dim, interpolation=cv2.INTER_LINEAR)

        self.img_lab = rgb2lab(self.img_rgb)

        self.bin_size = 16                  # passed to histogram() as the bin size
        self.mode = 2                       # passed to histogram()/palette_extraction(); see solve_ind_palette
        self.hist_harmonization = False     # not read within this class
        self.template = 'L'                 # not read within this class
        self.distortion_threshold = 0.93    # passed to palette_extraction() as `threshold`
        self.num_center_ind = 7             # passed to palette_extraction() as `num_clusters`
        self.lightness = 70                 # L value prepended to palette centres when mode == 2
        self.applied_wb = False             # set True once white-balance correction is adopted

    def inital_info(self, if_correct_wb, if_saliency, wb_thres, sal_method, sal_thres, valid_class):
        """Run palette extraction and store the results on the instance.

        (The method name is misspelt but kept for caller compatibility.)

        Stores ``hist_value``, ``hist_count``, ``c_center``, ``c_density``,
        ``c_img_label`` and ``sal_links``. These are the outputs of
        :meth:`extract_salient_palette` called with the given options.
        """
        self.hist_value, self.hist_count, \
        self.c_center, self.c_density, \
        self.c_img_label, self.sal_links = self.extract_salient_palette(if_wb=if_correct_wb,
                                                        if_saliency=if_saliency,
                                                        wb_thres=wb_thres,
                                                        sal_method=sal_method,
                                                        sal_thres=sal_thres,
                                                        valid_class=valid_class)

    def get_rgb_image(self):
        """Return the (possibly downsampled) RGB ``uint8`` image."""
        return self.img_rgb

    def get_lab_image(self):
        """Return the CIELAB version of :meth:`get_rgb_image`."""
        return self.img_lab

    def get_wb_image(self):
        """Recompute, cache in ``self.img_wb`` and return the white-balanced image."""
        self.img_wb = self.white_balance_correction()
        return self.img_wb

    def get_saliency(self):
        """Recompute, cache in ``self.sal_map`` and return the default (LDF) saliency map."""
        self.sal_map = self.saliency_detection(self.img_rgb)
        return self.sal_map

    def get_color_segment(self):
        """Return the image rendered with each pixel replaced by its palette colour.

        Computed on first use via :meth:`cal_color_segment` and cached in
        ``self.label_colored``. Requires :meth:`extract_salient_palette` to
        have run, because that sets ``self.center`` / ``self.colorlabel``.
        """
        if getattr(self, 'label_colored', None) is None:
            self.label_colored = self.cal_color_segment()
        return self.label_colored

    def get_label(self):
        """Return the per-pixel palette index map from the last palette extraction."""
        return self.colorlabel

    def cal_color_segment(self):
        """Paint every pixel with its palette centre and convert the result to RGB ``uint8``.

        Uses ``self.center`` (LAB centres, one row per palette colour) and
        ``self.colorlabel`` (per-pixel index into ``self.center``).
        """
        label_colored = np.zeros_like(self.img_rgb, dtype=np.float64)
        for id_color in range(np.size(self.center, 0)):
            label_colored[self.colorlabel == id_color] = self.center[id_color, :]
        label_colored = lab2rgb(label_colored)
        label_colored = np.round(label_colored * 255).astype(np.uint8)
        return label_colored

    def white_balance_correction(self):
        """Return a white-balance-corrected copy of ``self.img_rgb`` as ``uint8``."""
        # upgraded_model selects which WBsRGB model variant to load.
        upgraded_model = 2
        # gamut_mapping: 1 = scaling, 2 = clipping (clipping was used for the
        # WB_sRGB paper's results; scaling is recommended for over-saturated images).
        gamut_mapping = 2
        wbModel = wb_srgb.WBsRGB(gamut_mapping=gamut_mapping,
                                 upgraded=upgraded_model)
        img_wb = wbModel.correctImage(self.img_rgb)
        # Assumes correctImage returns floats in roughly [0, 1] — TODO confirm.
        # Clip so any out-of-range value saturates instead of wrapping in uint8.
        return np.clip(img_wb * 255, 0, 255).astype(np.uint8)

    def saliency_detection(self, img_rgb, method='LDF'):
        """Compute a saliency map for *img_rgb*.

        Args:
            img_rgb: RGB image array.
            method: ``'LDF'`` (Saliency_LDF model), ``'ft'`` (get_saliency_ft),
                or ``'rbd'`` / ``'mbd'`` (both dispatch to get_saliency_mbd).

        Raises:
            ValueError: if *method* is not one of the above.
        """
        if method == 'LDF':
            return Saliency_LDF().inference(img_rgb)
        if method == 'ft':
            return get_saliency_ft(img_rgb)
        if method in ('rbd', 'mbd'):
            return get_saliency_mbd(img_rgb)
        raise ValueError('unknown saliency method: {!r}'.format(method))

    def solve_ind_palette(self, img_rgb, mask_binary=None):
        """Extract a palette from *img_rgb* via a LAB colour histogram.

        Returns:
            ``(hist_value, hist_count, c_center, c_density, c_img_label, histlabel)``.
            ``c_img_label`` is reshaped to the image's (rows, cols). When
            ``self.mode == 2``, a constant ``self.lightness`` column is prepended
            to ``c_center``.
        """
        rows, cols = img_rgb.shape[:2]
        img_lab = rgb2lab(img_rgb)

        hist_value, hist_count = histogram(img_lab, self.bin_size, mode=self.mode, mask=mask_binary)

        c_center, c_density, c_img_label, histlabel = palette_extraction(img_lab, hist_value, hist_count,
                                                                            threshold=self.distortion_threshold,
                                                                            num_clusters=self.num_center_ind,
                                                                            mode=self.mode,
                                                                            mask=mask_binary)

        if self.mode == 2:
            # Centres come back without an L column in this mode, so a fixed
            # lightness is inserted as column 0 to make them full LAB triples.
            c_center = np.insert(c_center, 0, values=self.lightness, axis=1)

        c_img_label = np.reshape(c_img_label, (rows, cols))
        return hist_value, hist_count, c_center, c_density, c_img_label, histlabel

    @staticmethod
    def _relabel_members(colorlabel, members):
        """Map each palette index in *members* to its position within *members*.

        Pixels whose label is not in *members* stay at 0.
        """
        labels = np.zeros_like(colorlabel)
        for idx, label in enumerate(members):
            labels[colorlabel == label] = idx
        return labels

    def extract_salient_palette(self, if_wb=False, if_saliency=False, wb_thres=5, sal_method='LDF', sal_thres=0.9, valid_class=(0, 1)):
        """Extract the image palette, optionally white-balancing and splitting it by saliency.

        Args:
            if_wb: try white-balance correction first. The corrected image is
                used only if its ``color_difference`` to the original exceeds
                *wb_thres*.
            if_saliency: split the palette into per-class sub-palettes using
                a thresholded saliency map.
            wb_thres: colour-difference threshold for adopting the WB image.
            sal_method: forwarded to :meth:`saliency_detection`.
            sal_thres: pixels with saliency above this are labelled 1, else 0.
            valid_class: saliency labels to build sub-palettes for.

        Returns:
            Without saliency: ``(hist_value, hist_count, center, density,
            colorlabel, sal_links)``, where ``sal_links`` is ``[0, 1, ..., k-1]``.
            With saliency: the same six items, but the first five are lists
            with one entry per class in *valid_class*. ``sal_links``
            concatenates the per-class palette indices in reverse class order.
        """
        img_rgb = self.img_rgb.copy()
        if if_wb:
            self.img_wb = self.white_balance_correction()
            img_wb = self.img_wb
            dE = color_difference(img_rgb, img_wb)
            print(dE)
            if dE > wb_thres:
                self.applied_wb = True
                img_rgb = img_wb
                print('use white balance correction on {}'.format(self.filename.split('/')[-1]))

        hist_value, hist_count, center, density, colorlabel, histlabel = self.solve_ind_palette(img_rgb, mask_binary=None)
        self.center = center
        self.colorlabel = colorlabel
        num_colors = np.size(center, axis=0)

        if not if_saliency:
            return hist_value, hist_count, center, density, colorlabel, list(range(num_colors))

        self.sal_map = self.saliency_detection(img_rgb, method=sal_method)
        label_sem = np.zeros_like(img_rgb[:, :, 0])
        label_sem[self.sal_map > sal_thres] = 1

        # p_feq[c, k]: fraction of class-c pixels that carry palette colour k.
        class_num = len(valid_class)
        p_feq = np.zeros((class_num, num_colors))
        for id_cls, cls in enumerate(valid_class):
            value, count = np.unique(colorlabel[label_sem == cls], return_counts=True)
            if count.size:
                p_feq[id_cls, value] = count / count.sum()

        # Each palette colour is assigned to the class where it is most frequent.
        palettelabel = np.argmax(p_feq, axis=0)

        c_center, c_density, c_img_label = [], [], []
        hist_samples, hist_counts, mapping = [], [], []
        for id_cls in range(class_num):
            members = np.argwhere(palettelabel == id_cls).flatten()
            mapping.append(members)
            c_center.append(center[members, :])
            c_density.append(density[members])
            hist_samples.append(hist_value.copy())
            counts = hist_count.copy()
            # NOTE(review): this compares histlabel (presumably a per-bin palette
            # index) with the class index; `palettelabel[histlabel] != id_cls`
            # may have been intended — confirm against palette_extraction.
            counts[histlabel != id_cls] = 0
            hist_counts.append(counts)
            c_img_label.append(self._relabel_members(colorlabel, members))

        # With the default two classes this is (salient colours, background colours).
        sal_links = np.hstack(mapping[::-1])
        return hist_samples, hist_counts, c_center, c_density, c_img_label, sal_links