def scale_channel(im, c):
    """Scale the data in one channel to implement histogram equalization.

    Builds a 256-entry lookup table from the channel's histogram (the same
    construction as PIL's ``ImageOps.equalize``) and maps every pixel through
    it.

    Args:
        im: Image tensor indexed as ``im[:, :, c]``. ``torch.histc`` needs a
            floating-point tensor. NOTE(review): pixel values are presumably
            integer-valued in ``[0, 255]`` -- confirm against callers.
        c: Index of the channel (last dimension) to equalize.

    Returns:
        A ``torch.uint8`` tensor with the shape of ``im[:, :, c]``.
    """
    im = im[:, :, c]
    # 256 bins over [0, 255]: each integer value 0..255 lands in its own bin.
    # histc ignores elements outside [min, max].
    histo = torch.histc(im, bins=256, min=0, max=255)
    # For the purposes of computing the step, keep only the nonzero bins.
    nonzero_histo = histo[histo != 0]
    if nonzero_histo.numel() == 0:
        # No pixel falls in [0, 255] (e.g. an empty image): nothing to
        # equalize, and nonzero_histo[-1] below would raise.
        return im.type(torch.uint8)
    step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255
    if step == 0:
        # Degenerate histogram: return the original image.
        return im.type(torch.uint8)

    # Cumulative sum, shifted by step // 2 and then normalised by step.
    lut = (torch.cumsum(histo, 0) + (step // 2)) // step
    # Shift lut right by one, prepending 0. The zero is created with lut's
    # own dtype/device so this also works for non-CPU tensors.
    lut = torch.cat([lut.new_zeros(1), lut[:-1]])
    # Clip the counts to be in range. This is done in the C code for
    # image.point.
    lut = torch.clamp(lut, 0, 255)

    # gather takes a 1-D index matching the 1-D lut, so flatten, look up and
    # restore the shape. Unlike plain indexing, gather errors on negative
    # indices instead of silently wrapping around.
    result = torch.gather(lut, 0, im.flatten().long()).reshape_as(im)
    return result.type(torch.uint8)
# Assumes RGB for now. Scales each channel independently # and then stacks the result. image = image.type(torch.float) s1 = scale_channel(image, 0) s2 = scale_channel(image, 1) s3 = scale_channel(image, 2) image = torch.stack([s1, s2, s3], 2) return image
def find_nearest_above(my_array, target):
    """Return the index of the element of ``my_array`` nearest ``target`` from above.

    Elements whose difference from ``target`` is ``<= -1`` are treated as
    lying below the target and are excluded. Among the rest, the one with
    the smallest difference wins (first occurrence on ties). If every element
    is excluded, the index of the element with the smallest absolute
    difference is returned instead.

    NOTE(review): for integer inputs the ``<= -1`` threshold excludes exactly
    the values strictly below ``target``. For float inputs, values in
    ``(target - 1, target)`` are still considered "above" -- confirm this is
    intended.

    Args:
        my_array: Tensor of candidate values. Multi-dimensional input is
            searched as if flattened.
        target: Value to compare against.

    Returns:
        A 0-dim ``torch.long`` tensor holding the flat index of the chosen
        element.
    """
    diff = (my_array - target).flatten()
    # Mask the negative differences, since we are looking for values above.
    below = diff <= -1
    if torch.all(below):
        # Target is greater than every value: return the overall nearest.
        return torch.abs(diff).argmin()
    # Search only the remaining candidates. Overwriting excluded entries with
    # a sentinel (e.g. 9999) would let a value below the target win whenever
    # every candidate's difference exceeds the sentinel.
    candidates = torch.nonzero(~below).flatten()
    return candidates[diff[candidates].argmin()]