# multiple_image_recoloring / extract_palette.py
# Uploaded by dxue321
# Commit message: initial upload
# Commit: c3a7f7f
import numpy as np
# import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances
from collections import Counter
def histogram(img_lab, bin, mode=2, mask=None):
    """Build a 2-D or 3-D histogram of the unmasked pixels of an image.

    Parameters
    ----------
    img_lab : array_like
        Either an image whose last axis holds 3 channels (it is reshaped to
        ``(-1, 3)``) or an already flattened 2-D array with one pixel per row.
        The name suggests CIELAB data; nothing here depends on that.
    bin : int
        Bin specification forwarded to NumPy as ``bins`` (number of bins per
        axis when it is an int).
    mode : {2, 3}
        ``3`` histograms all three columns with ``np.histogramdd``; ``2``
        histograms only columns 1 and 2 with ``np.histogram2d``.
    mask : array_like, optional
        One entry per pixel; only pixels whose mask value equals 1 are
        counted. Defaults to counting every pixel.

    Returns
    -------
    hist_samples : ndarray, shape (n_bins, mode)
        Lower edge of every bin along each axis, in the same (C, "ij") order
        as ``hist_counts``.
    hist_counts : ndarray, shape (n_bins,)
        Number of selected pixels falling into each bin (not normalised).

    Raises
    ------
    ValueError
        If ``mode`` is neither 2 nor 3.
    """
    if mode not in (2, 3):
        raise ValueError("mode must be 2 or 3, got {!r}".format(mode))

    img_lab = np.asarray(img_lab)
    if mask is None:
        # One mask entry per pixel; works for (H, W, 3) and for (N, 3) input.
        mask = np.ones(img_lab.shape[:-1])
    if img_lab.ndim != 2:
        img_lab = img_lab.reshape(-1, 3)
    mask = np.asarray(mask).reshape(-1)

    pixels = img_lab[mask == 1]

    if mode == 3:
        hist, edges = np.histogramdd(pixels, bins=bin)
        grids = np.meshgrid(edges[0][:-1], edges[1][:-1], edges[2][:-1], indexing="ij")
    else:
        hist, xedges, yedges = np.histogram2d(pixels[:, 1], pixels[:, 2], bins=bin)
        grids = np.meshgrid(xedges[:-1], yedges[:-1], indexing="ij")

    # "ij" grids flattened in C order line up element-for-element with
    # the flattened histogram.
    hist_samples = np.column_stack([grid.ravel() for grid in grids])
    hist_counts = hist.ravel()
    return hist_samples, hist_counts
def palette_extraction(img_lab, hist_samples, hist_counts, mode=2, threshold=0.93, num_clusters=5, mask=None):
    """Extract a colour palette by weighted k-means on a colour histogram.

    K-means is fitted on the non-empty histogram bins (weighted by their
    density) for k = 1 .. ``num_clusters``. The smallest k whose cumulative
    distortion reduction exceeds ``threshold`` of the total reduction is
    selected and refitted, and every pixel and histogram bin is then assigned
    to its nearest cluster centre.

    Parameters
    ----------
    img_lab : array_like
        Image whose last axis holds 3 channels (reshaped to ``(-1, 3)``) or an
        already flattened 2-D array with one pixel per row.
    hist_samples : ndarray, shape (n_bins, mode)
        Bin positions, e.g. as returned by ``histogram``.
    hist_counts : ndarray, shape (n_bins,)
        Count of each bin in ``hist_samples``.
    mode : {2, 3}
        ``3`` clusters on all three columns of ``img_lab``; ``2`` clusters on
        columns 1 and 2 only. Must match the width of ``hist_samples``.
    threshold : float
        Fraction of the total distortion reduction that the selected number
        of clusters must exceed.
    num_clusters : int
        Largest number of clusters to try (capped at the number of non-empty
        bins).
    mask : array_like, optional
        One entry per pixel; pixels whose mask value equals 0 are labelled
        255 and excluded from the cluster densities. Defaults to no masking.

    Returns
    -------
    cluster_centers : ndarray, shape (n_selected, mode)
        Cluster centres of the final k-means fit.
    c_densities : ndarray, shape (n_selected,)
        Fraction of unmasked pixels assigned to each cluster.
    img_labels : ndarray, shape (n_pixels,)
        Cluster index of every pixel, 255 for masked pixels.
    hist_labels : ndarray, shape (n_bins,)
        Cluster index of every histogram bin.

    Raises
    ------
    ValueError
        If ``mode`` is not 2 or 3, ``num_clusters`` is smaller than 1, or
        ``hist_counts`` sums to zero.
    """
    if mode not in (2, 3):
        raise ValueError("mode must be 2 or 3, got {!r}".format(mode))
    if num_clusters < 1:
        raise ValueError("num_clusters must be at least 1, got {!r}".format(num_clusters))

    img_lab = np.asarray(img_lab)
    if mask is None:
        # One mask entry per pixel; works for (H, W, 3) and for (N, 3) input.
        mask = np.ones(img_lab.shape[:-1])
    if img_lab.ndim != 2:
        img_lab = img_lab.reshape(-1, 3)
    mask = np.asarray(mask).reshape(-1)

    hist_counts = np.asarray(hist_counts)
    total_count = np.sum(hist_counts)
    if total_count == 0:
        raise ValueError("hist_counts must contain at least one non-zero count")
    hist_densities = hist_counts / total_count

    ########################### palette extraction ###########################
    # Only non-empty bins take part in the clustering.
    index = np.squeeze(np.argwhere(hist_densities != 0), axis=1)
    samples = hist_samples[index, :]
    weights = hist_densities[index]

    # k-means cannot produce more clusters than there are samples.
    max_k = min(num_clusters, np.size(index))

    ## clustering method from matlab code
    # k = 1: a single centre at the (unweighted) mean of the non-empty bins.
    centroid = np.expand_dims(np.mean(samples, 0), axis=0)
    inits_all = [centroid]
    distortion = np.zeros(max_k)
    dist = pairwise_distances(samples, centroid, metric='euclidean')
    distortion[0] = np.sum(weights * np.squeeze(dist**2, axis=1), 0)

    for k in range(2, max_k + 1):
        # Initial cluster centres: repeatedly pick the heaviest sample, then
        # down-weight every sample by its normalised squared distance to it
        # so that the next pick lands far from the centres chosen so far.
        cinits = np.zeros((k, mode))
        cw = weights
        for i in range(k):
            heaviest = np.argmax(cw)
            cinits[i, :] = samples[heaviest, :]
            d2 = np.sum(np.square(cinits[i, :] - samples), axis=1)
            d2_max = np.max(d2)
            if d2_max > 0:
                # Skip the normalisation when all samples coincide (0/0).
                d2 = d2 / d2_max
            cw = cw * (d2**2)
        inits_all.append(cinits)

        kmeans = KMeans(n_clusters=k, init=cinits, n_init=1).fit(
            samples, y=None, sample_weight=weights)
        dist_point = pairwise_distances(samples, kmeans.cluster_centers_, metric='euclidean')
        distortion[k - 1] = np.sum(weights * np.min(dist_point, axis=1)**2)

    # Pick the smallest k whose cumulative distortion reduction exceeds
    # `threshold` of the total reduction between k = 1 and k = max_k.
    gains = distortion[:-1] - distortion[1:]
    total_gain = distortion[0] - distortion[-1]
    if gains.size and total_gain != 0:
        passing = np.flatnonzero(np.cumsum(gains) / total_gain > threshold)
    else:
        passing = np.array([], dtype=int)
    # gains[j] is the step from j + 1 to j + 2 clusters, hence the offset.
    # Fall back to the largest k tried when no prefix exceeds the threshold.
    num_clusters_opt = int(passing[0]) + 2 if passing.size else max_k

    kmeans_f = KMeans(n_clusters=num_clusters_opt, init=inits_all[num_clusters_opt - 1], n_init=1).fit(
        samples, y=None, sample_weight=weights)
    cluster_centers = kmeans_f.cluster_centers_

    if mode == 3:
        img_labels = kmeans_f.predict(img_lab)
    else:
        img_labels = kmeans_f.predict(img_lab[:, 1:3])
    hist_labels = kmeans_f.predict(hist_samples)

    # 255 marks masked-out pixels in the returned label array.
    masked_out = mask == 0
    img_labels[masked_out] = 255

    # Share of the unmasked pixels per cluster (NaN if every pixel is masked).
    label_counts = np.bincount(img_labels[~masked_out], minlength=num_clusters_opt)
    c_densities = label_counts.astype(float)
    c_densities = c_densities / np.sum(c_densities)

    return cluster_centers, c_densities, img_labels, hist_labels