I am trying to implement a type of GAN (Generative Adversarial Networks) named RS-ESRGAN to my satellite imagery dataset. You can check the github page of it if you like(https://github.com/luissalgueiro/rs-esrgan/tree/master). Here is the complete script I run on google colab:
!git clone https:
!pip install lmdb
!pip install earthpy
!pip install tensorboardX
!cp /content/rs-esrgan/utils -r /usr/local/lib/python3.10/dist-package
import pandas
import matplotlib
import math
import torch
import json
import torch.nn as nn
import numpy as np
from skimage.metrics import peak_signal_noise_ratio
# Loading a JSON File to a Python Dictionary
with open('/content/rs-esrgan/options/train/train_ESRGAN_WV_5x_ALL.json', 'r') as file:
data = json.load(file)
print(data)
# Data to be written
dictionary = {'name': 'P41_ESRGAN_ALL_L1_x5_PREUP_FT_StanData', 'use_tb_logger': True, 'model': 'srragan', 'scale': 5, 'gpu_ids': [0], 'datasets': {'train': {'name': 'GDAL', 'mode': 'LRHR', 'data_IDs': '/content/dataset/bandas_all.csv', 'dataroot_HR': '/content/dataset/train_HR', 'dataroot_LR': '/content/dataset/train_LR', 'subset_file': None, 'use_shuffle': True, 'n_workers': 1, 'batch_size': 2, 'HR_size': 140, 'use_flip': True, 'use_rot': True, 'stand': True, 'LR_down': False, 'PreUP': True, 'norm': False, 'up_lr': False, 'HF': False, 'scale': 5}, 'val': {'name': 'val_GDAL', 'mode': 'LRHR', 'data_IDs': '/content/dataset/valid.csv', 'dataroot_HR': '/content/dataset/valid_HR', 'dataroot_LR': '/content/dataset/valid_LR', 'HR_size': 140, 'stand': True, 'norm': False, 'LR_down': False, 'PreUP': True, 'up_lr': False, 'HF': False, 'scale': 5}}, 'path': {'root': '/content/drive/MyDrive/BasicSR', 'work_root': '/content/drive/MyDrive/BasicSR_2020/', 'resume_state': None}, 'network_G': {'which_model_G': 'RRDB_net', 'norm_type': None, 'mode': 'CNA', 'nf': 64, 'nb': 23, 'in_nc': 3, 'out_nc': 3}, 'network_D': {'which_model_D': 'discriminator_vgg_128', 'norm_type': None, 'act_type': 'leakyrelu', 'mode': 'CNA', 'nf': 64, 'in_nc': 3}, 'train': {'lr_G': 0.0001, 'weight_decay_G': 0, 'beta1_G': 0.9, 'lr_D': 0.0001, 'weight_decay_D': 0, 'beta1_D': 0.9, 'lr_scheme': 'MultiStepLR', 'lr_steps': [20000, 40000, 60000, 80000], 'lr_gamma': 0.5, 'pixel_criterion': 'l1', 'pixel_weight': 0.01, 'feature_criterion': 'l1', 'feature_weight': 1, 'gan_type': 'vanilla', 'gan_weight': 0.005, 'manual_seed': 100, 'niter': 101000, 'val_freq': 10000}, 'logger': {'print_freq': 100, 'save_checkpoint_freq': 10000}}
json_object = json.dumps(dictionary, indent=4)
# Writing to sample.json
with open("sample.json", "w") as outfile:
outfile.write(json_object)
!python /content/rs-esrgan/train_esrgan_WV.py -opt /content/sample.json
First of all, I cloned the git repo.
Then, I created a folder inside google colab content part for uploading my dataset.
I copied all the important folders of the repository (data, models, options and utils folders) into usr/local/lib/Python3.10/dist-packages folder to import the custom packages of the repo.
I converted train file (.json format) to python dict. Then, I copied the output and pasted into the next line's dictionary part and created sample.json file which is the file I use for training.
Lastly, I started the training. It was working fine at first but after a couple of hours, I got this error message :
Traceback (most recent call last):
File "/content/rs-esrgan/train_esrgan_WV.py", line 299, in <module>
main()
File "/content/rs-esrgan/train_esrgan_WV.py", line 188, in main
sr_img = util.tensor2imgStand(visuals['SR'], MeanVal = val_data["LR_mean"], StdVal = val_data["LR_std"]) # uint16
KeyError: 'LR_mean'
What I have tried:
I thought this error might result from some missing lines in LRHR_dataset.py file which is in data folder in the repo so I tried to add some lines to define LR_mean, LR_std in the python file but I could not solve the problem. Here is the content of LRHR_dataset.py:
import os.path
import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
import data.util as util
class LRHRDataset(data.Dataset):
'''
Read LR and HR image pairs.
If only HR image is provided, generate LR image on-the-fly.
The pair is ensured by 'sorted' function, so please check the name convention.
'''
def __init__(self, opt):
super(LRHRDataset, self).__init__()
self.opt = opt
self.paths_LR = None
self.paths_HR = None
self.LR_env = None # environment for lmdb
self.HR_env = None
# read image list from subset list txt
if opt['subset_file'] is not None and opt['phase'] == 'train':
with open(opt['subset_file']) as f:
self.paths_HR = sorted([os.path.join(opt['dataroot_HR'], line.rstrip('\n')) \
for line in f])
if opt['dataroot_LR'] is not None:
raise NotImplementedError('Now subset only supports generating LR on-the-fly.')
else: # read image list from lmdb or image files
self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_HR'])
self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])
assert self.paths_HR, 'Error: HR path is empty.'
if self.paths_LR and self.paths_HR:
assert len(self.paths_LR) == len(self.paths_HR), \
'HR and LR datasets have different number of images - {}, {}.'.format(\
len(self.paths_LR), len(self.paths_HR))
self.random_scale_list = [1]
def __getitem__(self, index):
# print("EJecuta GeT ITEM")
HR_path, LR_path = None, None
scale = 0 #self.opt['scale']
HR_size = self.opt['HR_size']
# get HR image
HR_path = self.paths_HR[index]
img_HR = util.read_img(self.HR_env, HR_path)
# modcrop in the validation / test phase
if self.opt['phase'] != 'train':
# print("Opcion: ", self.opt["phase"])
img_HR = util.modcrop(img_HR, scale)
# change color space if necessary
if self.opt['color']:
img_HR = util.channel_convert(img_HR.shape[2], self.opt['color'], [img_HR])[0]
# get LR image
if self.paths_LR:
LR_path = self.paths_LR[index]
img_LR = util.read_img(self.LR_env, LR_path)
else: # down-sampling on-the-fly
# randomly scale during training
if self.opt['phase'] == 'train':
random_scale = random.choice(self.random_scale_list)
H_s, W_s, _ = img_HR.shape
def _mod(n, random_scale, scale, thres):
rlt = int(n * random_scale)
rlt = (rlt
return thres if rlt < thres else rlt
H_s = _mod(H_s, random_scale, scale, HR_size)
W_s = _mod(W_s, random_scale, scale, HR_size)
img_HR = cv2.resize(np.copy(img_HR), (W_s, H_s), interpolation=cv2.INTER_LINEAR)
# force to 3 channels
if img_HR.ndim == 2:
img_HR = cv2.cvtColor(img_HR, cv2.COLOR_GRAY2BGR)
H, W, _ = img_HR.shape
# using matlab imresize
img_LR = util.imresize_np(img_HR, 1 / scale, True)
if img_LR.ndim == 2:
img_LR = np.expand_dims(img_LR, axis=2)
if self.opt['phase'] == 'train':
# if the image size is too small
# print("1...........")
H, W, _ = img_HR.shape
if H < HR_size or W < HR_size:
img_HR = cv2.resize(
np.copy(img_HR), (HR_size, HR_size), interpolation=cv2.INTER_LINEAR)
# using matlab imresize
img_LR = util.imresize_np(img_HR, 1 / scale, True)
if img_LR.ndim == 2:
img_LR = np.expand_dims(img_LR, axis=2)
H, W, C = img_LR.shape
LR_size = HR_size #
# randomly crop
rnd_h = random.randint(0, max(0, H - LR_size))
rnd_w = random.randint(0, max(0, W - LR_size))
img_LR = img_LR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
img_HR = img_HR[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]
# augmentation - flip, rotate
img_LR, img_HR = util.augment([img_LR, img_HR], self.opt['use_flip'], \
self.opt['use_rot'])
# change color space if necessary
if self.opt['color']:
img_LR = util.channel_convert(C, self.opt['color'], [img_LR])[0] # TODO during val no definetion
# BGR to RGB, HWC to CHW, numpy to tensor
if img_HR.shape[2] == 3:
img_HR = img_HR[:, :, [2, 1, 0]]
img_LR = img_LR[:, :, [2, 1, 0]]
img_HR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HR, (2, 0, 1)))).float()
img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float()
# print("2.................................")
#
# print(img_HR.shape)
if LR_path is None:
LR_path = HR_path
return {'LR': img_LR, 'HR': img_HR, 'LR_path': LR_path, 'HR_path': HR_path}
def __len__(self):
# print("EJecuta LEN")
return len(self.paths_HR)
I ran out of my python knowledge here. Where and how in this LRHR_dataset.py file should I make changes to add new lines?