-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathdataloader.py
More file actions
266 lines (230 loc) · 10.9 KB
/
dataloader.py
File metadata and controls
266 lines (230 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
"""
Copyright (c) 2025 Samsung Electronics Co., Ltd.
Author(s):
Mahmoud Afifi (m.afifi1@samsung.com, m.3afifi@gmail.com)
Licensed under the Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) License, (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at https://creativecommons.org/licenses/by-nc/4.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
For conditions of distribution and use, see the accompanying LICENSE.md file.
Dataloader for training/testing our illuminant estimation model.
"""
from torch.utils.data import Dataset
from utils import *
class IlluminantEstimationDataLoader(Dataset):
    """In-memory dataset of capture data, 2D chroma histograms, and ground-truth illuminants.

    Expected layout of the selected split directory (<data_dir>/train, <data_dir>/test, or <data_dir>/val):
      - data/<name>.json: per-example JSON with the ground-truth illuminant(s) and capture data.
      - <WITH_MASK_SUB_FOLDER or WITHOUT_MASK_SUB_FOLDER>/<name>.png: the corresponding image.

    Every example is loaded and cached at construction time. Images are not kept in the cache: each one is replaced
    by its histogram once the histogram boundaries are known (given for test/val, computed from the images for train).
    """

    def __init__(self, data_dir: str, capture_data: List[str],
                 user_pref: Optional[bool] = False,
                 target_size: Optional[List[int]] = None,
                 train: Optional[bool] = False,
                 test: Optional[bool] = False,
                 valid: Optional[bool] = False,
                 hist_bins: Optional[int] = 32,
                 normalization_data: Optional[Dict[str, Any]] = None,
                 hist_boundaries: Optional[List[float]] = None,
                 without_mask: Optional[bool] = False):
        """Initialization.

        Args:
          data_dir: The path to the dataset directory. It must contain the subdirectory of the selected split:
            'train', 'test', or 'val'. Each split holds a 'data' folder of JSON files plus an image folder.
          capture_data: A list of capture data types to load. Options include:
            'iso', 'shutter_speed', 'flash', 'time', 'noise_stats', 'snr_stats'.
          user_pref: If True, use the user-preference ground-truth illuminant colors ('pref_illum' in the JSON)
            instead of 'gt_illum'.
          target_size: Target image size [height, width], given as a list or tuple of two positive ints. If not
            provided, the original image size is used.
          train: Load the 'train' split. Histogram boundaries are computed from its images.
          test: Load the 'test' split.
          valid: Load the 'val' split. Exactly one of 'train', 'test', or 'valid' must be True.
          hist_bins: Number of histogram bins.
          normalization_data: A dict with 'capture-data-min' and 'capture-data-max' entries, e.g., the output of
            get_normalization_values() on the training set. These are used to min-max normalize the capture data.
            If None, the min/max values are computed from the loaded split itself, so provide them for 'test' and
            'valid'.
          hist_boundaries: Histogram boundaries [min_u, min_v, max_u, max_v], e.g., the output of
            get_hist_boundaries() on the training set. Required if 'test' or 'valid' is True; ignored for 'train'.
          without_mask: A flag indicating whether to use images without masks applied. Ensure that the constants
            WITHOUT_MASK_SUB_FOLDER and WITH_MASK_SUB_FOLDER in constants.py are correctly configured to specify the
            expected folder paths for images with and without masks.

        Raises:
          ValueError: If not exactly one split is selected, if histogram boundaries are missing for test/valid, if
            target_size is invalid, or if statistics must be computed from an empty split.
        """
        if train + test + valid > 1 or train + test + valid == 0:
            raise ValueError('Choose only one of the options: train, test, valid.')
        if train:
            split, self._training = 'train', True
        elif test:
            split, self._training = 'test', False
        else:  # valid
            split, self._training = 'val', False
        if not self._training and (hist_boundaries is None or len(hist_boundaries) < 4):
            raise ValueError('Histogram boundaries (min_u, min_v, max_u, max_v) should be provided.')
        if target_size is not None:
            # Explicit exceptions rather than asserts, so the validation still runs under `python -O`.
            if not isinstance(target_size, (list, tuple)):
                raise ValueError('Invalid target size: must be a list.')
            if len(target_size) != 2:
                raise ValueError('Invalid target size: must have two elements [height, width].')
            if not all(isinstance(x, int) and x > 0 for x in target_size):
                raise ValueError('Invalid target size: elements must be positive integers.')
            target_size = list(target_size)

        self._data_dir = os.path.join(data_dir, split)
        self._without_mask = without_mask
        self._hist_bins = hist_bins
        self._capture_data = capture_data
        self._target_size = target_size
        self._user_pref = user_pref
        if self._training:
            # Unknown until the training images are cached; set by _compute_histogram_boundary() below.
            self._u_min = self._v_min = self._u_max = self._v_max = None
        else:
            self._u_min = hist_boundaries[0]
            self._v_min = hist_boundaries[1]
            self._u_max = hist_boundaries[2]
            self._v_max = hist_boundaries[3]

        # Sorted so example indices are deterministic (os.listdir returns entries in arbitrary order).
        self._filenames = sorted(
            f for f in os.listdir(os.path.join(self._data_dir, 'data')) if f.endswith('.json'))
        print(f'Creating dataset with {len(self._filenames)} examples')
        print('Done!')
        print('Caching dataset ...')
        self._dataset = [self._data_loading(i) for i in range(len(self._filenames))]
        if normalization_data is None:
            self._compute_data_normalization_values()
        else:
            self._capture_data_min_vector = np.array(normalization_data['capture-data-min'])
            self._capture_data_max_vector = np.array(normalization_data['capture-data-max'])
        if self._training:
            print('Computing histogram boundaries...')
            self._compute_histogram_boundary()
            self._add_histograms()

    def get_hist_boundaries(self):
        """Returns histogram boundaries as [u_min, v_min, u_max, v_max]."""
        return [self._u_min, self._v_min, self._u_max, self._v_max]

    def _add_histograms(self):
        """Replaces each cached image with its histogram (used for the training split)."""
        print('Computing histograms...')
        for data in self._dataset:
            data['hist'] = self._compute_histogram(data['img'])
            # Drop the image; only the histogram is needed after this point.
            data.pop('img', None)

    def _compute_histogram_boundary(self):
        """Sets the histogram boundaries from the cached images.

        Each image's chroma values (rgb_to_rgbg output reshaped to (-1, 2)) are pooled over all images. The lower
        boundaries are the per-channel 10th percentile and the upper boundaries the 95th percentile.

        Raises:
          ValueError: If the dataset is empty.
        """
        if not self._dataset:
            raise ValueError('Cannot compute histogram boundaries: the dataset is empty.')
        # np.concatenate pools the pixels even when the images have different sizes (e.g., without target_size);
        # building one array with np.array(...) would fail or create a ragged object array in that case.
        chroma_values = np.concatenate(
            [rgb_to_rgbg(data['img']).reshape([-1, 2]) for data in self._dataset], axis=0)
        min_uv = np.quantile(chroma_values, 0.1, axis=0)
        max_uv = np.quantile(chroma_values, 0.95, axis=0)
        self._u_min, self._v_min = min_uv[0], min_uv[1]
        self._u_max, self._v_max = max_uv[0], max_uv[1]
        print(f'Histogram boundaries are: ({self._u_min}, {self._u_max}), ({self._v_min}, {self._v_max}).')

    def _compute_histogram(self, img: np.ndarray) -> np.ndarray:
        """Returns the 2D histogram of the given image(s) using the current boundaries and bin count."""
        return compute_2d_rgbg_histogram(
            img=img, hist_boundaries=self.get_hist_boundaries(),
            hist_bins=self._hist_bins, edge_hist=True, uv_coord=True)

    def _get_capture_data(self, data: Dict[str, Any]) -> np.ndarray:
        """Returns the flattened capture-data feature vector of one example.

        Features are appended in the order of `self._capture_data`:
          - 'noise_stats': data['noise_stats']['rgb_mean'] + data['noise_stats']['rgb_std'].
          - 'snr_stats': data['snr_stats']['mean'] + data['snr_stats']['std'].
          - 'time': square roots of the six 'prob_<event>' values, followed by the six 'is_before_<event>' values
            from data['capture_metadata'] (events: sunrise, sunset, dusk, dawn, noon, midnight).
          - 'flash': data['capture_metadata']['flash'] as-is.
          - any other key: natural log of data['capture_metadata'][key].

        NOTE(review): the stats entries are assumed to be Python lists, so `+` concatenates mean and std. If they
        were NumPy arrays, `+` would add them element-wise instead; confirm against the JSON files.
        """
        time_events = ('sunrise', 'sunset', 'dusk', 'dawn', 'noon', 'midnight')
        # Seeded with an empty float64 array so the result is floating point even if every part is, e.g., boolean.
        parts = [np.array([])]
        for key in self._capture_data:
            if key == 'noise_stats':
                stats = data['noise_stats']
                data_i = np.array(stats['rgb_mean'] + stats['rgb_std'])
            elif key == 'snr_stats':
                stats = data['snr_stats']
                data_i = np.array(stats['mean'] + stats['std'])
            elif key == 'time':
                metadata = data['capture_metadata']
                prob = np.stack([metadata[f'prob_{event}'] for event in time_events], axis=-1)
                is_before = np.stack([metadata[f'is_before_{event}'] for event in time_events], axis=-1)
                data_i = np.concatenate([np.sqrt(prob), is_before])
            elif key == 'flash':
                data_i = np.array([data['capture_metadata'][key]])
            else:
                data_i = np.log(data['capture_metadata'][key])
            # np.log of a scalar yields a NumPy scalar; promote it to a 1-element array before concatenation.
            if np.isscalar(data_i):
                data_i = np.array([data_i])
            parts.append(data_i.flatten())
        return np.concatenate(parts, axis=0)

    def __len__(self):
        """Returns number of images in the set."""
        return len(self._filenames)

    def get_normalization_values(self):
        """Returns normalization values as {'capture-data-min': ..., 'capture-data-max': ...}."""
        return {'capture-data-min': self._capture_data_min_vector,
                'capture-data-max': self._capture_data_max_vector}

    def _compute_data_normalization_values(self):
        """Computes the element-wise min/max of the cached capture-data vectors.

        Raises:
          ValueError: If the dataset is empty.
        """
        print('Computing normalization values....')
        if not self._dataset:
            raise ValueError('Cannot compute normalization values: the dataset is empty.')
        # Cached vectors are float32 (see _data_loading); the stored min/max vectors are kept as float64.
        all_capture_data = np.stack(
            [data['capture_data'] for data in self._dataset], axis=0).astype(np.float64)
        self._capture_data_min_vector = all_capture_data.min(axis=0)
        self._capture_data_max_vector = all_capture_data.max(axis=0)

    def _data_loading(self, i):
        """Loads the i-th example from disk.

        Returns:
          A dict with 'capture_data' (float32), 'gt' (float32), and 'filename'. For the training split it also
          holds 'img' (float32). When histogram boundaries are already known (test/val), it holds 'hist' instead.
        """
        filename = self._filenames[i]
        data = read_json_file(os.path.join(self._data_dir, 'data', filename))
        gt = np.array(data['pref_illum'] if self._user_pref else data['gt_illum'])
        capture_data = self._get_capture_data(data).astype(np.float32)
        sub_dir = WITHOUT_MASK_SUB_FOLDER if self._without_mask else WITH_MASK_SUB_FOLDER
        # Swap only the '.json' extension (filenames are filtered on it); str.replace would hit every occurrence.
        img_name = os.path.splitext(filename)[0] + '.png'
        img = imread(os.path.join(self._data_dir, sub_dir, img_name))
        if self._target_size is not None:
            height, width = self._target_size
            if img.shape[0] != height or img.shape[1] != width:
                # cv2.resize takes (width, height).
                img = cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)
        example = {'capture_data': capture_data,
                   'gt': gt.astype(np.float32),
                   'filename': filename}
        if self._training:
            example['img'] = img.astype(np.float32)
        if self._u_min is not None:
            example['hist'] = self._compute_histogram(img)
        return example

    def __getitem__(self, i):
        """Returns the i-th example as tensors.

        Returns:
          A dict with:
            'capture_data': float32 tensor of the capture data after `min_max_normalization` with this dataset's
              min/max vectors.
            'hist': float32 tensor of the cached histogram with its last axis moved first (transpose(2, 0, 1)).
            'gt': float32 tensor of the ground-truth illuminant.
            'filename': the example's JSON filename.
        """
        data = self._dataset[i]
        capture_data = min_max_normalization(
            data['capture_data'], self._capture_data_min_vector, self._capture_data_max_vector)
        return {
            'capture_data': torch.from_numpy(capture_data.astype(np.float32)),
            'hist': torch.from_numpy(data['hist'].transpose(2, 0, 1).astype(np.float32)),
            'gt': torch.from_numpy(data['gt']),
            'filename': data['filename']
        }