-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathConfig.py
More file actions
68 lines (59 loc) · 2.56 KB
/
Config.py
File metadata and controls
68 lines (59 loc) · 2.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import yaml
import os
class train:
    """Attribute holder for the training settings.

    Instantiating it only creates the attributes listed in ``__init__``, each
    set to ``None``; nothing in this class fills them in.
    """

    def __init__(self):
        # Listed in one place so the full attribute set is visible at a
        # glance; attributes are created in exactly this order.
        attribute_names = (
            "train_dir",
            "valid_dir",
            "use_valid_data",
            "output_dir",
            "train_split",
            "features",
            "add_range",
            "labels",
            "normalize",
            "binary",
            "device",
            "batch_size",
            "epochs",
            "init_lr",
            "output_classes",
            "epoch_timeout",
            "threshold_method",
            "termination_criteria",
        )
        for attribute_name in attribute_names:
            setattr(self, attribute_name, None)
class test:
    """Attribute holder for the test settings.

    Every attribute starts as ``None`` except ``batch_size``, which starts
    at 1.
    """

    def __init__(self):
        # Attribute name -> value it is created with, in creation order.
        initial_values = {
            "test_dir": None,
            "device": None,
            "batch_size": 1,
            "save_pred_clouds": None,
        }
        for attribute_name, initial_value in initial_values.items():
            setattr(self, attribute_name, initial_value)
class Config():
    """Load a YAML configuration file and expose it through attribute objects.

    After construction:

    * ``self.config`` holds the mapping returned by ``yaml.safe_load``.
    * ``self.train`` is a ``train`` instance filled from ``config["train"]``.
    * ``self.test`` is a ``test`` instance filled from ``config["test"]``.

    Args:
        root_dir: Path of the YAML file to read. Despite the name, the value
            is passed straight to ``open``, so it must name a file.

    Raises:
        OSError: If the file cannot be opened (this includes the default
            ``''``).
        yaml.YAMLError: If the file is not valid YAML.
        TypeError: If the document, or one of its sections, is not a mapping
            (for example an empty file).
        KeyError: If a section or one of the keys listed below is missing.
    """

    # (attribute set on ``self.train``, key read from the "train" section)
    # NOTE(review): ``train`` also declares ``add_range``, but no key is read
    # for it here, so it keeps the value ``train()`` gave it -- confirm
    # whether that is intended.
    _TRAIN_KEYS = (
        ("train_dir", "TRAIN_DIR"),
        ("valid_dir", "VALID_DIR"),
        ("use_valid_data", "USE_VALID_DATA"),
        ("output_dir", "OUTPUT_DIR"),
        ("train_split", "TRAIN_SPLIT"),
        ("features", "FEATURES"),
        ("labels", "LABELS"),
        ("normalize", "NORMALIZE"),
        ("binary", "BINARY"),
        ("device", "DEVICE"),
        ("batch_size", "BATCH_SIZE"),
        ("epochs", "EPOCHS"),
        ("init_lr", "LR"),
        ("output_classes", "OUTPUT_CLASSES"),
        ("threshold_method", "THRESHOLD_METHOD"),
        ("termination_criteria", "TERMINATION_CRITERIA"),
        ("epoch_timeout", "EPOCH_TIMEOUT"),
    )

    # (attribute set on ``self.test``, key read from the "test" section)
    _TEST_KEYS = (
        ("test_dir", "TEST_DIR"),
        ("device", "DEVICE"),
        ("batch_size", "BATCH_SIZE"),
        ("save_pred_clouds", "SAVE_PRED_CLOUDS"),
    )

    def __init__(self, root_dir = ''):
        # Binary mode hands the raw bytes to PyYAML, which detects the
        # encoding itself (BOM, else UTF-8) instead of depending on the
        # platform's locale encoding.
        with open(root_dir, "rb") as file:
            self.config = yaml.safe_load(file)

        # An empty file loads as None; fail with a message that says so
        # rather than with "'NoneType' object is not subscriptable".
        if not isinstance(self.config, dict):
            raise TypeError(
                f"config file {root_dir!r} must contain a YAML mapping, "
                f"got {type(self.config).__name__}"
            )

        # ---------------------------------------------------------------------#
        # TRAIN CONFIGURATION
        self.train = train()
        self._fill(self.train, self.config, "train", self._TRAIN_KEYS)

        # ---------------------------------------------------------------------#
        # TEST CONFIGURATION
        self.test = test()
        self._fill(self.test, self.config, "test", self._TEST_KEYS)

    @staticmethod
    def _fill(target, config, section_name, attr_to_key):
        """Copy the listed keys of one config section onto ``target``."""
        try:
            section = config[section_name]
        except KeyError:
            raise KeyError(
                f"config is missing the '{section_name}' section"
            ) from None
        if not isinstance(section, dict):
            raise TypeError(
                f"config section '{section_name}' must be a mapping, "
                f"got {type(section).__name__}"
            )
        for attr_name, key in attr_to_key:
            try:
                value = section[key]
            except KeyError:
                raise KeyError(
                    f"config section '{section_name}' is missing key '{key}'"
                ) from None
            setattr(target, attr_name, value)