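A not-very-useful preface

A reproduction of Dr. Kaiming He's dark channel dehazing paper. This dummy can't write code, but can copy it.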

Dark channel

$$J^{dark}(x) = \min_{y\in\Omega(x)}\Big(\min_{c\in\{r,g,b\}} J^c(y)\Big), \qquad J^{dark}(x) \to 0$$

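  • The prior says that in most non-sky patches of a haze-free outdoor image, at least one color channel has very low intensity at some pixels, so $J^{dark} \to 0$. To compute the dark channel:
    1. Take the minimum of the R, G, B components at each pixel.
    2. Store these minima in a grayscale image the same size as the original.
    3. Apply a minimum filter to that image. The filter radius is set by the window size, usually $window = 2\times Radius + 1$.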
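A minimal NumPy sketch of these steps follows. It assumes the image is a float H×W×3 array. The function name `dark_channel` is my choice, and so is the default `radius=7`, which gives a 15×15 window. The minimum filter here is the simple brute-force version, not a fast one.

```python
import numpy as np


def dark_channel(img: np.ndarray, radius: int = 7) -> np.ndarray:
    """Dark channel of an H x W x 3 float image.

    1. Per-pixel minimum over the R, G, B channels.
    2. Minimum filter over a (2 * radius + 1) x (2 * radius + 1) window.
    """
    min_rgb = img.min(axis=2).astype(np.float64)  # H x W grayscale of channel minima
    win = 2 * radius + 1
    h, w = min_rgb.shape
    # Pad with +inf so windows at the border only see real pixels.
    padded = np.pad(min_rgb, radius, mode="constant", constant_values=np.inf)
    out = np.full((h, w), np.inf)
    # Brute-force minimum filter: compare every offset in the window, O(win^2) per pixel.
    for dy in range(win):
        for dx in range(win):
            out = np.minimum(out, padded[dy:dy + h, dx:dx + w])
    return out


# Usage: a gray 5 x 5 image with one dark green pixel in the middle.
img = np.full((5, 5, 3), 0.8)
img[2, 2, 1] = 0.1
print(dark_channel(img, radius=1))  # 0.1 in the 3 x 3 block around (2, 2), 0.8 elsewhere
```

A single dark pixel spreads over its whole window. This is why a larger window yields a darker dark channel.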

Haze imaging model

$$I(x) = J(x)t(x) + A(1 - t(x))$$

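  • The symbols are:
    • $I(x)$ is the observed (hazy) image.
    • $t(x)$ is the transmission.
    • $J(x)$ is the haze-free image we want to recover.
    • $A$ is the global atmospheric light.
    Assuming $A$ is constant, pick the brightest 0.1% of pixels in the dark channel. Among these, the pixel with the highest intensity in the input image $I$ is taken as the atmospheric light.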
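Below is a sketch of this selection rule. It assumes `img` is a float H×W×3 array and `dark` is its dark channel. The text above does not say how "intensity" is measured, so the sketch assumes the per-pixel mean of R, G, B. The name `atmospheric_light` is hypothetical.

```python
import numpy as np


def atmospheric_light(img: np.ndarray, dark: np.ndarray, top: float = 0.001) -> np.ndarray:
    """Estimate the global atmospheric light A (one value per channel).

    img:  H x W x 3 float hazy image I
    dark: H x W dark channel of img
    top:  fraction of the brightest dark-channel pixels to keep (0.1% -> 0.001)
    """
    h, w = dark.shape
    n = max(int(h * w * top), 1)               # always keep at least one candidate
    flat_dark = dark.reshape(-1)
    flat_img = img.reshape(-1, 3)
    # Indices of the n largest dark-channel values (in no particular order).
    candidates = np.argpartition(flat_dark, -n)[-n:]
    # Among those, pick the pixel that is brightest in I.
    # Assumption: "intensity" is the mean of the three channels.
    intensity = flat_img[candidates].mean(axis=1)
    return flat_img[candidates[np.argmax(intensity)]]
```

With the 0.1% rule, a 600 × 400 image keeps 240 candidates. `np.argpartition` finds them without sorting the whole dark channel.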

Simplified transmission estimate

$$\tilde{t}(x) = 1 - \min_{y\in\Omega(x)}\Big(\min_c \frac{I^c(y)}{A^c}\Big)$$
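This estimate follows from the haze model in three steps. It assumes, as the paper does, that the transmission is constant within a patch $\Omega(x)$ (written $\tilde{t}(x)$) and that $A^c > 0$:
1. Divide the model by $A^c$.
2. Take both minima.
3. Apply the dark channel prior to the $J$ term.

```latex
\begin{aligned}
\frac{I^c(x)}{A^c} &= t(x)\,\frac{J^c(x)}{A^c} + 1 - t(x) \\
\min_{y\in\Omega(x)}\Big(\min_c \frac{I^c(y)}{A^c}\Big)
  &= \tilde{t}(x)\,\min_{y\in\Omega(x)}\Big(\min_c \frac{J^c(y)}{A^c}\Big) + 1 - \tilde{t}(x) \\
\min_{y\in\Omega(x)}\Big(\min_c \frac{J^c(y)}{A^c}\Big) \to 0
  &\;\Longrightarrow\; \tilde{t}(x) = 1 - \min_{y\in\Omega(x)}\Big(\min_c \frac{I^c(y)}{A^c}\Big)
\end{aligned}
```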

Correction

$$\tilde{t}(x) = 1 - \omega \min_{y\in\Omega(x)}\Big(\min_c \frac{I^c(y)}{A^c}\Big)$$

  • $\omega$ ($0 < \omega \le 1$) deliberately keeps a little haze so the result looks natural. The paper uses $\omega = 0.95$.
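The corrected formula can be sketched in code as follows. The sketch assumes `img` is a float H×W×3 array and `A` is a length-3 array. `estimate_transmission` is a hypothetical name. The brute-force minimum filter is repeated so the block runs on its own.

```python
import numpy as np


def estimate_transmission(img: np.ndarray, A: np.ndarray,
                          omega: float = 0.95, radius: int = 7) -> np.ndarray:
    """Coarse transmission: t~(x) = 1 - omega * min_{y in Omega(x)} min_c I^c(y) / A^c."""
    normalized = img / A.reshape(1, 1, 3)       # I^c(y) / A^c for every channel
    min_c = normalized.min(axis=2)              # inner min over c
    win = 2 * radius + 1
    h, w = min_c.shape
    padded = np.pad(min_c, radius, mode="constant", constant_values=np.inf)
    dark = np.full((h, w), np.inf)
    for dy in range(win):                       # outer min over y in Omega(x), brute force
        for dx in range(win):
            dark = np.minimum(dark, padded[dy:dy + h, dx:dx + w])
    return 1.0 - omega * dark
```

Because $\omega < 1$ scales down the haze term, $\tilde t$ comes out slightly larger than the $\omega = 1$ estimate. That is exactly the "keep a little haze" effect.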

Recovery formula

$$J(x) = \frac{I(x) - A}{\max(t(x), t_0)} + A$$

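  • When the transmission $t$ is very small, dividing by it makes $J$ blow up (and amplifies noise), pushing the image toward white. To prevent this, set a lower bound $t_0$: wherever $t < t_0$, use $t = t_0$. The standard value here is $t_0 = 0.1$.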
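Here is a sketch of the recovery step with the $t_0$ lower bound. It assumes float images in [0, 1], so the output is clipped to that range; the clipping is my addition, not part of the formula. The usage builds a synthetic hazy image with the haze model and then recovers it. The name `recover` is hypothetical.

```python
import numpy as np


def recover(img: np.ndarray, t: np.ndarray, A: np.ndarray, t0: float = 0.1) -> np.ndarray:
    """Scene radiance J(x) = (I(x) - A) / max(t(x), t0) + A."""
    A = A.reshape(1, 1, 3)
    t_clamped = np.maximum(t, t0)[..., np.newaxis]  # H x W x 1, broadcasts over channels
    J = (img - A) / t_clamped + A
    # Assumption: float images in [0, 1]; clip values pushed out of range.
    return np.clip(J, 0.0, 1.0)


# Usage: make a hazy image with the haze model, then undo it.
rng = np.random.default_rng(0)
J_true = rng.random((4, 4, 3))
t = np.full((4, 4), 0.6)
A = np.array([0.9, 0.9, 0.9])
I = J_true * t[..., np.newaxis] + A * (1 - t[..., np.newaxis])
J_rec = recover(I, t, A)
print(np.allclose(J_rec, J_true))  # True
```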

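Steps

  1. Compute the dark channel image
  2. Estimate the global atmospheric light
  3. Estimate the (coarse) transmission map
  4. Compute the guided-filter output
  5. Obtain the refined transmission map
  6. Recover the image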
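Steps 4 and 5 refine the coarse transmission with a guided filter. Below is a sketch of the standard guided filter for a single-channel guide, with the box mean computed from an integral image.

Assumptions:
  • The guide is a grayscale version of the hazy image.
  • The defaults for `r` and `eps` are illustrative, not values from this post.
  • The names `box_mean` and `guided_filter` are hypothetical.

```python
import numpy as np


def box_mean(x: np.ndarray, r: int) -> np.ndarray:
    """Mean over a (2r+1) x (2r+1) window via an integral image; O(1) per pixel.
    Windows are clipped at the border and divided by the number of pixels they cover."""
    h, w = x.shape
    s = np.zeros((h + 1, w + 1))
    s[1:, 1:] = x.cumsum(axis=0).cumsum(axis=1)  # s[i, j] = sum of x[:i, :j]
    y0 = np.clip(np.arange(h) - r, 0, h)
    y1 = np.clip(np.arange(h) + r + 1, 0, h)
    x0 = np.clip(np.arange(w) - r, 0, w)
    x1 = np.clip(np.arange(w) + r + 1, 0, w)
    total = s[y1][:, x1] - s[y0][:, x1] - s[y1][:, x0] + s[y0][:, x0]
    count = (y1 - y0)[:, None] * (x1 - x0)[None, :]
    return total / count


def guided_filter(guide: np.ndarray, p: np.ndarray, r: int = 30, eps: float = 1e-3) -> np.ndarray:
    """Guided filter with a single-channel guide: q = mean(a) * guide + mean(b)."""
    mean_I = box_mean(guide, r)
    mean_p = box_mean(p, r)
    var_I = box_mean(guide * guide, r) - mean_I * mean_I
    cov_Ip = box_mean(guide * p, r) - mean_I * mean_p
    a = cov_Ip / (var_I + eps)          # local linear coefficient
    b = mean_p - a * mean_I
    return box_mean(a, r) * guide + box_mean(b, r)


# Usage sketch (hypothetical variables hazy_rgb and t_coarse):
# t_refined = guided_filter(hazy_rgb.mean(axis=2), t_coarse, r=30, eps=1e-3)
```

The whole filter is a handful of box means. Since each box mean costs the same regardless of `r`, the refinement stays fast even with large windows.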

Results

[Result images: 20240810171840, 20240810171816]

Other notes

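  • Guided filtering gives a good transmission map.
  • Soft matting gives a refined transmission map. The original paper used it, but it is much slower than guided filtering.
  • The MinFilter (minimum filter) step can be implemented quickly.
  • Mean (box) blur is a very fast algorithm.
  • Window size matters: the larger the window, the weaker the dehazing effect. A larger window makes the dark channel darker, so $\tilde t$ grows.
  • ω has a clear meaning: the smaller it is, the weaker the dehazing effect. A smaller ω gives a larger $\tilde t$, so less haze is removed.
  • For optimizing the mean blur, see the article 彩色图像高速模糊之懒惰算法 ("High-speed blurring of color images: the lazy algorithm").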
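One known way to make the minimum filter fast is the van Herk / Gil-Werman algorithm. It works in three steps:
  1. Split each row into blocks of the window length.
  2. Take running minima forward and backward within each block.
  3. Combine two lookups per pixel to get each window's minimum.

A square-window minimum is separable, so running this over rows and then columns gives the 2-D filter. The cost is a constant number of comparisons per pixel, whatever the radius.

This is a sketch and not necessarily the MinFilter implementation the note above refers to. The names `min_filter_1d` and `min_filter_2d` are hypothetical.

```python
import numpy as np


def min_filter_1d(x: np.ndarray, r: int) -> np.ndarray:
    """Sliding minimum of width k = 2r+1 along the last axis (van Herk / Gil-Werman).

    Border windows are clipped (the input is padded with +inf).
    """
    x = np.asarray(x, dtype=np.float64)
    k = 2 * r + 1
    n = x.shape[-1]
    total = -(-(n + 2 * r) // k) * k            # padded length rounded up to a multiple of k
    pad = [(0, 0)] * (x.ndim - 1) + [(r, total - n - r)]
    padded = np.pad(x, pad, mode="constant", constant_values=np.inf)
    blocks = padded.reshape(*x.shape[:-1], total // k, k)
    # g: running min from the start of each block; h: running min to the end of each block.
    g = np.minimum.accumulate(blocks, axis=-1).reshape(*x.shape[:-1], total)
    h = np.minimum.accumulate(blocks[..., ::-1], axis=-1)[..., ::-1].reshape(*x.shape[:-1], total)
    start = np.arange(n)                         # window i covers padded[i : i + k]
    return np.minimum(h[..., start], g[..., start + k - 1])


def min_filter_2d(img: np.ndarray, r: int) -> np.ndarray:
    """Square-window minimum filter: rows first, then columns (the min is separable)."""
    rows = min_filter_1d(img, r)
    return min_filter_1d(rows.T, r).T
```

For example, `min_filter_1d(np.array([3., 1., 4., 1., 5., 9., 2., 6.]), 1)` gives `[1, 1, 1, 1, 1, 2, 2, 2]`. Replacing the brute-force loop in the dark channel with `min_filter_2d` makes that step's cost independent of the window size.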

Python code (not that important)

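  • cv2.imread() reads an image.
    imread takes two arguments: the image path and the read mode. The read mode is one of three flags:

    • cv2.IMREAD_COLOR: loads a 3-channel color image and drops any alpha channel. This is the default and can be written as 1.
    • cv2.IMREAD_GRAYSCALE: loads the image in grayscale; can be written as 0.
    • cv2.IMREAD_UNCHANGED: loads the image as-is, including the alpha channel; can be written as -1.
  • cv2 loads images in BGR channel order by default, while most other libraries expect RGB, so the image has to be converted: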

import cv2

img = cv2.imread(picture)  # picture: path to the image file; cv2 loads channels in BGR order
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # reorder channels to RGB
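  • Reference code (borrowed from elsewhere). It does two things:
    • It histogram-equalizes each channel.
    • It matches each channel's histogram to that of the per-pixel RGB minimum.
    It does not compute the transmission or apply the recovery formula.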
import glob
import os

import torch


def torch_equalize(image):
    """Histogram-equalize each channel of an H x W x 3 tensor with values in [0, 255]."""

    def scale_channel(im, c):
        """Scale the data in the channel to implement equalize."""
        im = im[:, :, c]
        # Compute the histogram of the image channel.
        histo = torch.histc(im, bins=256, min=0, max=255)  # .type(torch.int32)
        # For the purposes of computing the step, filter out the nonzeros.
        nonzero_histo = torch.reshape(histo[histo != 0], [-1])
        step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255

        def build_lut(histo, step):
            # Compute the cumulative sum, shifting by step // 2
            # and then normalization by step.
            lut = (torch.cumsum(histo, 0) + (step // 2)) // step
            # Shift lut, prepending with 0.
            lut = torch.cat([torch.zeros(1), lut[:-1]])
            # Clip the counts to be in range. This is done
            # in the C code for image.point.
            return torch.clamp(lut, 0, 255)

        # If step is zero, return the original image. Otherwise, build
        # lut from the full histogram and step and then index from it.
        if step == 0:
            result = im
        else:
            # can't index using 2d index. Have to flatten and then reshape
            result = torch.gather(build_lut(histo, step), 0, im.flatten().long())
            result = result.reshape_as(im)

        return result.type(torch.uint8)

    # Assumes RGB for now. Scales each channel independently
    # and then stacks the result.
    image = image.type(torch.float)
    s1 = scale_channel(image, 0)
    s2 = scale_channel(image, 1)
    s3 = scale_channel(image, 2)
    image = torch.stack([s1, s2, s3], 2)
    return image


def find_nearest_above(my_array, target):
    diff = my_array - target
    mask = diff <= -1
    # We need to mask the negative differences
    # since we are looking for values above
    if torch.all(mask):
        c = torch.abs(diff).argmin()
        return c  # returns min index of the nearest if target is greater than any value
    masked_diff = diff.clone()
    masked_diff[mask] = 9999
    return masked_diff.argmin()


def hist_match(source, template):
    """Remap `source` so that its histogram matches the histogram of `template`."""
    # reshape (not view) so non-contiguous channel slices also work
    s = source.reshape(-1)
    t = template.reshape(-1)
    s_values, bin_idx, s_counts = torch.unique(s, return_inverse=True, return_counts=True)
    t_values, t_counts = torch.unique(t, return_counts=True)
    # Normalized cumulative histograms (empirical CDFs), scaled to 0..255.
    s_quantities = torch.cumsum(s_counts, 0).type(torch.float)
    t_quantities = torch.cumsum(t_counts, 0).type(torch.float)
    s_quantities = s_quantities / s_quantities[-1]
    t_quantities = t_quantities / t_quantities[-1]
    sour = (s_quantities * 255).type(torch.long)
    temp = (t_quantities * 255).type(torch.long)
    # For each source level, find the template level whose CDF is the nearest
    # one at or above it, and store that template value.
    b = torch.zeros(sour.shape)
    for i in range(sour.shape[0]):
        b[i] = t_values[find_nearest_above(temp, sour[i])]

    s = b[bin_idx]
    return s.reshape(source.shape)


def hist_match_dark_prior(img):
    # input: img[h, w, c], uint8 RGB
    # output: res[h, w, c], uint8
    result = img.clone()
    result = torch_equalize(result)
    # Per-pixel minimum over R, G, B (no window minimum filter is applied here).
    dark_prior, _ = torch.min(result, dim=2)
    for i in range(3):
        result[:, :, i] = hist_match(result[:, :, i], dark_prior)
    return result


if __name__ == '__main__':
    from PIL import Image
    import numpy as np

    os.makedirs("result", exist_ok=True)
    # Image.open does not expand wildcards, so iterate over the matches explicitly
    # and save each output under the same file name in result/.
    for path in glob.glob("haze/*.jpg"):
        im = Image.open(path).convert("RGB")  # force exactly 3 channels
        img = torch.from_numpy(np.array(im))
        img1 = hist_match_dark_prior(img).numpy()
        im1 = Image.fromarray(img1)
        im1.save(os.path.join("result", os.path.basename(path)))