1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
| import torch
def torch_equalize(image):
    """Histogram-equalize channels 0, 1 and 2 of *image* independently.

    Each channel gets its own lookup table built from its 256-bin histogram
    over [0, 255]. Channels with too few pixels outside their last occupied
    bin are passed through unchanged.

    Args:
        image: tensor indexed as ``image[:, :, c]``. Values are assumed to be
            integer levels in [0, 255] (e.g. decoded 8-bit pixels) -- TODO
            confirm for all callers.

    Returns:
        ``torch.uint8`` tensor stacking the three equalized channels along
        dim 2, on the same device as *image*.
    """

    def build_lut(histo, step):
        """Build the 256-entry lookup table from a channel histogram."""
        lut = (torch.cumsum(histo, 0) + (step // 2)) // step
        # Shift right by one so level 0 maps to 0. new_zeros matches the
        # dtype/device of `lut`; a bare torch.zeros(1) always lives on the CPU
        # and makes torch.cat fail for CUDA inputs.
        lut = torch.cat([lut.new_zeros(1), lut[:-1]])
        return torch.clamp(lut, 0, 255)

    def scale_channel(im, c):
        """Equalize channel *c* of *im* and return it as ``torch.uint8``."""
        im = im[:, :, c]
        # With 256 bins over [0, 255], integer level v falls into bin v.
        histo = torch.histc(im, bins=256, min=0, max=255)
        nonzero_histo = histo[histo != 0]
        if nonzero_histo.numel() == 0:
            # No pixel fell into the histogram (e.g. an empty image).
            return im.type(torch.uint8)
        step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255
        if step == 0:
            # Fewer than 255 pixels outside the last occupied bin; building a
            # LUT would divide by zero, so keep the channel as-is.
            result = im
        else:
            lut = build_lut(histo, step)
            result = torch.gather(lut, 0, im.flatten().long()).reshape_as(im)
        return result.type(torch.uint8)

    image = image.type(torch.float)
    return torch.stack([scale_channel(image, c) for c in range(3)], 2)
def find_nearest_above(my_array, target):
    """Return the flat index of the nearest element of *my_array* at/above *target*.

    An element qualifies when ``element - target > -1`` (for integer inputs:
    ``element >= target``). Among qualifying elements the one with the
    smallest difference wins, with ties going to the first index. If no
    element qualifies, the index of the element closest to *target* in
    absolute value is returned instead.

    Args:
        my_array: tensor of candidate values (flattened for the search).
        target: scalar or 0-dim tensor to compare against.

    Returns:
        0-dim ``torch.long`` tensor holding a flat index into *my_array*.
    """
    diff = (my_array - target).flatten()
    below = diff <= -1
    # Indices (ascending) of the qualifying elements. Selecting among them
    # directly avoids a magic "masked" sentinel, which silently loses to real
    # differences that are at least as large as the sentinel.
    candidates = torch.nonzero(~below).flatten()
    if candidates.numel() == 0:
        return torch.abs(diff).argmin()
    return candidates[diff[candidates].argmin()]
def hist_match(source, template):
    """Remap *source* so that its value histogram approximates *template*'s.

    Both tensors are flattened. Their cumulative counts over unique values are
    normalized and scaled to integers in [0, 255] by truncation. Each distinct
    source value is mapped to the template value whose scaled cumulative count
    is the smallest one that is >= the source value's own.

    Args:
        source: tensor of any shape whose values are replaced.
        template: non-empty tensor whose value distribution is the target.

    Returns:
        Tensor shaped like *source*, in the default floating dtype, whose
        entries are values that occur in *template*.
    """
    if source.numel() == 0:
        return torch.zeros(source.shape, device=source.device)

    # reshape (not view) so non-contiguous inputs are accepted too.
    _, inverse, s_counts = torch.unique(
        source.reshape(-1), return_inverse=True, return_counts=True
    )
    t_values, t_counts = torch.unique(template.reshape(-1), return_counts=True)

    s_cdf = torch.cumsum(s_counts, 0).type(torch.float)
    t_cdf = torch.cumsum(t_counts, 0).type(torch.float)
    sour = (s_cdf / s_cdf[-1] * 255).type(torch.long)
    temp = (t_cdf / t_cdf[-1] * 255).type(torch.long)

    # temp is non-decreasing, so the first index with temp >= sour is the
    # smallest scaled template quantile at or above each source quantile.
    idx = torch.searchsorted(temp, sour)
    # temp[-1] == 255 >= every entry of sour, so idx is already in range; the
    # clamp is only a guard against float-rounding surprises.
    idx = idx.clamp(max=temp.shape[0] - 1)
    # Map to template *values*, not positions in the unique-value array; the
    # two only coincide when the template contains every level 0..N-1.
    mapped = t_values[idx].to(torch.get_default_dtype())
    return mapped[inverse].reshape(source.shape)
def hist_match_dark_prior(img):
    """Equalize *img*, then match each channel's histogram to its dark prior.

    The dark prior is the per-pixel minimum across dim 2 of the equalized
    image. Each of the three channels is then remapped with ``hist_match``
    against it; the matched values are cast to the image dtype on assignment.

    Args:
        img: image tensor indexed as ``img[:, :, c]``, passed through
            ``torch_equalize``.

    Returns:
        The equalized ``torch.uint8`` image with every channel histogram-
        matched to the dark prior.
    """
    equalized = torch_equalize(img.clone())
    dark_prior = equalized.min(dim=2).values
    for channel in range(3):
        equalized[:, :, channel] = hist_match(equalized[:, :, channel], dark_prior)
    return equalized
if __name__ == '__main__':
    # Batch-process every JPEG in haze/ and write results into result/ under
    # the same file names. The original passed the glob pattern itself to
    # Image.open / save, which treats "*.jpg" as a literal file name.
    import glob
    import os

    import numpy as np
    from PIL import Image

    os.makedirs('result', exist_ok=True)
    for path in sorted(glob.glob('haze/*.jpg')):
        with Image.open(path) as im:
            # Force 3 channels: grayscale/CMYK JPEGs would otherwise not be
            # indexable as image[:, :, c] for c in 0..2.
            img = torch.from_numpy(np.array(im.convert('RGB')))
        img1 = hist_match_dark_prior(img).numpy()
        Image.fromarray(img1).save(os.path.join('result', os.path.basename(path)))
|