-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain_model.py
More file actions
120 lines (103 loc) · 5.86 KB
/
train_model.py
File metadata and controls
120 lines (103 loc) · 5.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
import pickle
import os
import sys
import gzip
import torch
import pyrootutils
# Locate the project root (the directory marked by README.md) and make it the
# working directory so every path below can be built relative to it.
# NOTE(review): pythonpath=True presumably puts the root on the import path so
# the `models.*` packages resolve -- confirm against the pyrootutils docs.
ROOT_DIR = pyrootutils.setup_root(
    __file__,
    indicator="README.md",
    pythonpath=True,
)
os.chdir(ROOT_DIR)

# Input datasets are read from here; trained checkpoints are written below
# MODELS_DIR.
RAMIFICATIONS_DIR = os.path.join(ROOT_DIR, "ramifications")
MODELS_DIR = os.path.join(ROOT_DIR, "trained_models")

# Training cannot start without the pre-generated ramification datasets, so
# stop right away with a non-zero exit status when the folder is missing.
if not os.path.exists(RAMIFICATIONS_DIR):
    print(f"Ramifications datasets in {RAMIFICATIONS_DIR} does not exist. Please run the scripts to generate them.")
    sys.exit(1)
# Per-ABM settings shared by every surrogate architecture: the number of
# features and the aggregation mode handed to the model constructors.
_ABM_SETTINGS = {
    'predatorprey': {'n_features': 6, 'aggregation': 'add'},
    'schelling': {'n_features': 2, 'aggregation': 'mean'},
}

# Surrogate architectures selectable through --model_type.
_MODEL_TYPES = ('GDN', 'diffusion-only', 'gnn-only')


def _build_featurizer(abm_model):
    """Instantiate the featurizer matching ``abm_model`` (already validated)."""
    # Imports stay local so that only the modules for the selected ABM are loaded.
    if abm_model == 'predatorprey':
        from models.abm.predpreyfeaturizer import PredPreyFeaturizer
        return PredPreyFeaturizer()
    from models.abm.schellingfeaturizer import SchellingFeaturizer
    return SchellingFeaturizer()


def _build_model(abm_model, model_type, learning_rate, T_diffusion, ramification_training):
    """Create the untrained surrogate model for the requested ABM and architecture."""
    settings = _ABM_SETTINGS[abm_model]
    n_features = settings['n_features']
    aggregation = settings['aggregation']
    featurizer = _build_featurizer(abm_model)

    if model_type == 'GDN':
        from models.surrogate.nnmodel import NNModel
        return NNModel(n_features=n_features, learning_rate=learning_rate,
                       abm_featurizer=featurizer, diffusion_timesteps=T_diffusion,
                       aggregation=aggregation)

    if model_type == 'diffusion-only':
        from models.surrogate.nnmodel_diffusion_only import NNModel_diffusion_only
        # The domain dimension is the length of the first training state once
        # it has been scaled by the featurizer and flattened.
        domain_dim = featurizer.scale_abm_state(ramification_training[0][0]).flatten().shape[0]
        return NNModel_diffusion_only(n_features=n_features, learning_rate=learning_rate,
                                      abm_featurizer=featurizer, diffusion_timesteps=T_diffusion,
                                      domain_dim=domain_dim)

    # Only 'gnn-only' remains: model_type is validated before this helper runs.
    from models.surrogate.nnmodel_gnn_only import GNNModel
    return GNNModel(n_features=n_features, abm_featurizer=featurizer, aggregation=aggregation)


def _checkpoint_payload(model, model_type):
    """Collect the entries stored in the checkpoint for the given architecture."""
    if model_type == 'GDN':
        return {'ld_model_state_dict': model.ld_model.state_dict(),
                'graph_model_state_dict': model.graph_model.state_dict(),
                'diffusion_timesteps': model.diffusion_timesteps,
                'aggregation': model.aggregation,
                'learning_rate': model.lr_ld,
                'losses': model.losses}
    if model_type == 'diffusion-only':
        return {'ld_model_state_dict': model.ld_model.state_dict(),
                'diffusion_timesteps': model.diffusion_timesteps,
                'learning_rate': model.lr_ld,
                'losses': model.losses}
    return {'graph_model_state_dict': model.graph_model.state_dict(),
            'learning_rate': model.lr_gnn,
            'aggregation': model.aggregation,
            'losses': model.losses}


def main():
    """Train a surrogate model on ABM ramification data and save a checkpoint.

    Command-line options select the ABM (``--abm_model``), the surrogate
    architecture (``--model_type``), the parameter whose training dataset is
    used (``--parameter``), and the training hyperparameters. The dataset is
    read from
    ``RAMIFICATIONS_DIR/<abm_model>/ramification_training_<parameter>.pickle.gz``
    and the checkpoint is written to
    ``MODELS_DIR/<abm_model>/<model_type>/model_<model_type>_<parameter>.pth``.

    Raises:
        ValueError: if ``--abm_model`` or ``--model_type`` is not one of the
            supported values. Both are checked before any data is loaded.
    """
    parser = argparse.ArgumentParser(description='Train surrogate model on an ABM')
    parser.add_argument('--abm_model', type=str, default='predatorprey', help='choose between predatorprey and schelling (default: predatorprey)')
    parser.add_argument('--model_type', type=str, default='GDN', help='choose between GDN, diffusion-only and gnn-only (default: GDN)')
    parser.add_argument('--parameter', type=str, default='psi1', help='choose parameter between: xi1, xi2, xi3 for schelling; and psi1, psi2, psi3, psi4 for predator-prey (default: psi1)')
    parser.add_argument('--learning_rate', type=float, default=1e-5, help='learning rate for the model (default: 1e-5)')
    parser.add_argument('--T_diffusion', type=int, default=100, help='number of diffusion steps (default: 100)')
    parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs for training (default: 100)')
    # Number of ramifications to use for training (500 advised, but can be
    # reduced for testing). Previously a constant that was never passed on.
    parser.add_argument('--n_ramifications', type=int, default=500, help='number of ramifications used for training (default: 500)')
    args = parser.parse_args()

    abm_model = args.abm_model
    model_type = args.model_type
    parameter = args.parameter

    # Fail fast on unsupported choices, before the dataset is read from disk.
    if abm_model not in _ABM_SETTINGS:
        raise ValueError(
            f"abm_model '{abm_model}' not recognized. Choose among "
            "['predatorprey', 'schelling']."
        )
    if model_type not in _MODEL_TYPES:
        raise ValueError(
            f"model_type '{model_type}' not recognized. Choose among "
            "['GDN', 'diffusion-only', 'gnn-only']."
        )

    # Load the training data.
    # NOTE: pickle.load can execute arbitrary code; only load dataset files
    # that were generated locally by this project's scripts.
    with gzip.open(f'{RAMIFICATIONS_DIR}/{abm_model}/ramification_training_{parameter}.pickle.gz', 'rb') as f:
        ramification_training = pickle.load(f)

    # Initialize and train the model.
    model = _build_model(abm_model, model_type, args.learning_rate, args.T_diffusion, ramification_training)
    model.train(ramifications=ramification_training, n_ramifications=args.n_ramifications, n_epochs=args.n_epochs)

    # Save the model.
    save_dir = f'{MODELS_DIR}/{abm_model}/{model_type}/'
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"model_{model_type}_{parameter}.pth")
    torch.save(_checkpoint_payload(model, model_type), save_path)
    print(f'Model saved to {save_path}')
# Start training only when this file is executed as a script; importing it
# as a module does not call main().
if __name__ == "__main__":
    main()