| import os
|
| import numpy as np
|
| import torch
|
| import shutil
|
| from torch.autograd import Variable
|
| import matplotlib.pyplot as plt
|
| from PIL import Image
|
|
|
|
|
|
|
def pair_downsampler(img):
    """Split an image batch into two half-resolution sub-images.

    Each non-overlapping 2x2 block is reduced to two values:
    ``output1`` averages its anti-diagonal (top-right, bottom-left) pixels and
    ``output2`` averages its main diagonal (top-left, bottom-right) pixels.
    The filters are applied depthwise (``groups=c``), so channels never mix.

    Args:
        img: tensor laid out as (N, C, H, W).

    Returns:
        ``(output1, output2)``, each of shape (N, C, H // 2, W // 2). With an odd
        H or W, the last row or column is dropped by the stride-2, 2x2 convolution.
    """
    c = img.shape[1]
    # Build the kernels in the input's floating dtype so conv2d also works for
    # float16/float64 input (a fixed float32 kernel raises a dtype mismatch).
    # Non-float input keeps the original float32 kernel, so it still errors
    # instead of silently rounding 0.5 to 0.
    dtype = img.dtype if img.is_floating_point() else torch.float32

    filter1 = torch.tensor([[[[0, 0.5], [0.5, 0]]]], dtype=dtype, device=img.device)
    filter1 = filter1.repeat(c, 1, 1, 1)

    filter2 = torch.tensor([[[[0.5, 0], [0, 0.5]]]], dtype=dtype, device=img.device)
    filter2 = filter2.repeat(c, 1, 1, 1)

    output1 = torch.nn.functional.conv2d(img, filter1, stride=2, groups=c)
    output2 = torch.nn.functional.conv2d(img, filter2, stride=2, groups=c)
    return output1, output2
|
|
|
def gauss_cdf(x):
    """Evaluate the standard normal CDF element-wise.

    Phi(x) = (1 + erf(x / sqrt(2))) / 2.
    """
    root_two = torch.sqrt(torch.tensor(2.))
    return (torch.erf(x / root_two) + 1) * 0.5
|
|
|
def gauss_kernel(kernlen=21, nsig=3, channels=1, device=None):
    """Build a normalised 2-D Gaussian kernel for depthwise convolution.

    The 1-D weights are differences of the standard normal CDF sampled at
    ``kernlen + 1`` evenly spaced points spanning
    ``[-nsig - interval/2, nsig + interval/2]``. The 2-D kernel is
    ``sqrt(outer(k1d, k1d))`` rescaled to sum to 1.

    Args:
        kernlen: side length of the square kernel.
        nsig: half-width of the sampled range, in standard deviations.
        channels: number of copies along dim 0 (for ``groups=channels`` convs).
        device: where to place the kernel. ``None`` keeps the historical
            behaviour (CUDA) when a GPU is available and falls back to CPU
            otherwise, instead of crashing on CPU-only machines.

    Returns:
        Tensor of shape (channels, 1, kernlen, kernlen).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    interval = (2 * nsig + 1.) / kernlen
    x = torch.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1).to(device)

    # Probability mass per bin; the sqrt of the outer product gives a 2-D kernel.
    kern1d = torch.diff(gauss_cdf(x))
    kernel_raw = torch.sqrt(torch.outer(kern1d, kern1d))
    kernel = kernel_raw / torch.sum(kernel_raw)

    out_filter = kernel.view(1, 1, kernlen, kernlen)
    out_filter = out_filter.repeat(channels, 1, 1, 1)
    return out_filter
|
|
|
class LocalMean(torch.nn.Module):
    """Box (mean) filter over ``patch_size`` x ``patch_size`` windows.

    The input is reflect-padded so the output keeps the spatial size of a
    (N, C, H, W) input. Each output pixel is the mean of the window anchored
    around it.
    """

    def __init__(self, patch_size=5):
        super(LocalMean, self).__init__()
        self.patch_size = patch_size
        # Kept for backward compatibility. forward() derives its own
        # (possibly asymmetric) padding from patch_size.
        self.padding = self.patch_size // 2

    def forward(self, image):
        """Return the local mean of ``image``, shaped like the input.

        The padding must be smaller than H and W, which reflect mode requires.
        """
        # Split the (patch_size - 1) extra rows/cols between the two sides. For
        # odd sizes this equals patch_size // 2 per side, as before. For even
        # sizes, symmetric patch_size // 2 padding gave an output one pixel
        # larger than the input.
        before = (self.patch_size - 1) // 2
        after = self.patch_size - 1 - before
        image = torch.nn.functional.pad(image, (before, after, before, after), mode='reflect')
        patches = image.unfold(2, self.patch_size, 1).unfold(3, self.patch_size, 1)
        return patches.mean(dim=(4, 5))
|
|
|
def blur(x):
    """Gaussian-blur each channel of ``x`` independently, keeping its spatial size.

    A 21x21 kernel from ``gauss_kernel(21, 1, C)`` is applied depthwise
    (``groups=C``) after reflect-padding each side by 10 pixels. H and W must
    therefore exceed 10 for reflect padding to be valid.

    Args:
        x: tensor laid out as (N, C, H, W).

    Returns:
        Blurred tensor with the same shape as ``x``.
    """
    device = x.device
    kernel_size = 21
    padding = kernel_size // 2
    channels = x.size(1)

    kernel_var = gauss_kernel(kernel_size, 1, channels).to(device)
    if x.is_floating_point():
        # conv2d requires input and weight of the same dtype. A float32 kernel
        # on float16/float64 input would raise.
        kernel_var = kernel_var.to(x.dtype)

    x_padded = torch.nn.functional.pad(x, (padding, padding, padding, padding), mode='reflect')
    return torch.nn.functional.conv2d(x_padded, kernel_var, padding=0, groups=channels)
|
|
|
def padr_tensor(img):
    """Zero-pad the last two dimensions of ``img`` by 2 on every side."""
    return torch.nn.functional.pad(img, (2, 2, 2, 2), mode='constant', value=0)
|
|
|
def calculate_local_variance(train_noisy):
    """Compute a per-pixel local variance statistic over 5x5 windows.

    Let ``m = AvgPool2d(5, stride=1, padding=2)(x)`` be the zero-padded 5x5 box
    mean, with padding counted in the divisor. For every pixel p the result is
    the mean, over the 5x5 window around p, of ``(x_q - m_q) ** 2``. Positions
    outside the image contribute 0. Each neighbour q is compared with its own
    local mean ``m_q``, not with ``m_p``.

    Args:
        train_noisy: tensor laid out as (N, C, H, W).

    Returns:
        Tensor with the same shape as ``train_noisy``.
    """
    avg_pool = torch.nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
    residual_sq = (train_noisy - avg_pool(train_noisy)) ** 2
    # A zero-padded 5x5 box average of the squared residual equals unfolding
    # every 5x5 window and averaging it. This form avoids materialising two
    # tensors 25x the input size.
    return avg_pool(residual_sq)
|
|
|
def count_parameters_in_MB(model):
    """Return the parameter count of ``model`` in millions.

    Despite the name, this is a count of parameters divided by 1e6, not a
    size in megabytes. Parameters whose qualified name contains "auxiliary"
    are excluded.
    """
    # Builtin sum over a generator. np.sum(generator) is deprecated in NumPy.
    total = sum(
        param.numel()
        for param_name, param in model.named_parameters()
        if "auxiliary" not in param_name
    )
    return total / 1e6
|
|
|
|
|
|
|
def save_checkpoint(state, is_best, save):
    """Write ``state`` to ``<save>/checkpoint.pth.tar`` with ``torch.save``.

    When ``is_best`` is truthy, the written file is also copied to
    ``<save>/model_best.pth.tar``.
    """
    checkpoint_path = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    shutil.copyfile(checkpoint_path, os.path.join(save, 'model_best.pth.tar'))
|
|
|
|
|
def save(model, model_path):
    """Serialize ``model.state_dict()`` to ``model_path`` with ``torch.save``."""
    weights = model.state_dict()
    torch.save(weights, model_path)
|
|
|
|
|
def load(model, model_path, map_location=None):
    """Load the state dict stored at ``model_path`` into ``model``.

    Args:
        model: module whose ``load_state_dict`` receives the loaded weights.
        model_path: file previously written by ``torch.save``.
        map_location: forwarded to ``torch.load``. Pass ``'cpu'`` to load a
            checkpoint saved from GPU tensors on a CPU-only machine. The
            default ``None`` keeps ``torch.load``'s default placement, matching
            the previous behaviour.

    NOTE(review): ``torch.load`` is pickle-based; only load checkpoints from
    trusted sources.
    """
    model.load_state_dict(torch.load(model_path, map_location=map_location))
|
|
|
def drop_path(x, drop_prob):
    """Randomly zero whole samples of ``x`` in place (drop-path regularisation).

    When ``drop_prob > 0``, each sample along dim 0 is kept with probability
    ``1 - drop_prob``. Kept samples are scaled by ``1 / (1 - drop_prob)``;
    dropped samples become 0. ``x`` is modified in place and also returned.

    Args:
        x: tensor whose first dimension indexes samples (any rank >= 1).
        drop_prob: probability of dropping each sample. ``0`` leaves ``x``
            untouched and ``1`` zeroes it.

    Returns:
        ``x`` (the same tensor object).
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        if keep_prob == 0.:
            # Every sample is dropped. Dividing by 0 and then masking would
            # produce inf * 0 = NaN.
            return x.zero_()
        # One Bernoulli(keep_prob) draw per sample, shaped to broadcast over
        # the remaining dims. It is built on x's device/dtype; the old
        # torch.cuda.FloatTensor mask worked only for CUDA float32 4-D input.
        mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
        mask = torch.empty(mask_shape, device=x.device, dtype=x.dtype).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x
|
|
|
def create_exp_dir(path, scripts_to_save=None):
    """Create the experiment directory and optionally snapshot scripts into it.

    Args:
        path: directory to create (with parents) if nothing exists there yet.
            Its path is printed either way.
        scripts_to_save: optional iterable of file paths. Each is copied into
            ``<path>/scripts/`` under its base name.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is None:
        return

    scripts_dir = os.path.join(path, 'scripts')
    os.makedirs(scripts_dir, exist_ok=True)
    for src in scripts_to_save:
        shutil.copyfile(src, os.path.join(scripts_dir, os.path.basename(src)))
|
|
|
def show_pic(pic, name, path):
    """Draw captioned image tensors into a subplot grid and save the figure.

    Args:
        pic: indexable collection of image tensors. For each one, only
            ``img[0]`` is drawn, and it must be channel-first with 3 or 1
            channels. Items with any other channel count are skipped silently.
        name: captions indexed in parallel with ``pic``; ``name[i]`` labels ``pic[i]``.
        path: destination passed to ``plt.savefig``.

    Pixel values are multiplied by 255 and clipped to [0, 255] before the
    uint8 conversion, so inputs are presumably in [0, 1]. The grid is fixed at
    5x6, so at most 30 images fit.

    NOTE(review): no figure is created or closed here. Unless the caller
    manages figures, repeated calls draw onto the same pyplot figure.
    """
    for idx in range(len(pic)):
        array = pic[idx][0].cpu().float().numpy()
        channels = array.shape[0]
        if channels == 3:
            # (3, H, W) -> (H, W, 3), the layout PIL expects for RGB.
            pixels = np.transpose(array, (1, 2, 0))
            imshow_args = ()
        elif channels == 1:
            pixels = array[0]
            imshow_args = (plt.cm.gray,)
        else:
            continue
        im = Image.fromarray(np.clip(pixels * 255.0, 0, 255.0).astype('uint8'))
        plt.subplot(5, 6, idx + 1)
        plt.xlabel(str(name[idx]))
        plt.xticks([])
        plt.yticks([])
        plt.imshow(im, *imshow_args)
    plt.savefig(path)
|
|
|
|
|
|
|
|
|