forked from fracozzi/ABM-Graph-Diffusion-Network
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_model.py
More file actions
90 lines (74 loc) · 4.25 KB
/
train_model.py
File metadata and controls
90 lines (74 loc) · 4.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
import pickle
import os
import sys
import gzip
import torch
import pyrootutils
ROOT_DIR = pyrootutils.setup_root(__file__, indicator="README.md", pythonpath=True)
from models.surrogate.nnmodel import NNModel
from models.surrogate.nnmodel_ablation import NNModel_ablation
from models.abm.predpreyfeaturizer import PredPreyFeaturizer
from models.abm.schellingfeaturizer import SchellingFeaturizer
# Work from the repository root so that relative paths resolve the same way
# regardless of the directory the script was launched from.
os.chdir(ROOT_DIR)

# Input datasets are read from here; trained checkpoints are written to MODELS_DIR.
RAMIFICATIONS_DIR = os.path.join(ROOT_DIR, "ramifications")
MODELS_DIR = os.path.join(ROOT_DIR, "trained_models")

# Training cannot proceed without the ramification datasets: abort early.
if not os.path.exists(RAMIFICATIONS_DIR):
    print(f"Ramifications datasets in {RAMIFICATIONS_DIR} does not exist. Please run the scripts to generate them.")
    sys.exit(1)
# Values accepted on the command line; anything else is rejected by argparse
# before any data is loaded.
_ABM_MODELS = ('predatorprey', 'schelling')
_MODEL_TYPES = ('surrogate', 'ablation')


def _build_model(abm_model, model_type, learning_rate, T_diffusion, ramification_training):
    """Instantiate the featurizer and the model for the requested configuration.

    Args:
        abm_model: 'predatorprey' or 'schelling'.
        model_type: 'surrogate' (NNModel) or 'ablation' (NNModel_ablation).
        learning_rate: learning rate forwarded to the model constructor.
        T_diffusion: number of diffusion steps forwarded to the model constructor.
        ramification_training: loaded training dataset; only
            ``ramification_training[0][0]`` is read here, to size the ablation model.

    Returns:
        The freshly constructed (untrained) model.

    Raises:
        ValueError: if ``abm_model`` or ``model_type`` is not a supported value.
    """
    # Per-ABM settings: (featurizer class, n_features, aggregation used by the surrogate).
    abm_settings = {
        'predatorprey': (PredPreyFeaturizer, 6, 'add'),
        'schelling': (SchellingFeaturizer, 2, 'mean'),
    }
    if abm_model not in abm_settings:
        raise ValueError(f"Unknown abm_model {abm_model!r}; expected one of {_ABM_MODELS}")
    featurizer_cls, n_features, aggregation = abm_settings[abm_model]
    featurizer = featurizer_cls()

    if model_type == 'surrogate':
        return NNModel(n_features=n_features, learning_rate=learning_rate,
                       abm_featurizer=featurizer, diffusion_timesteps=T_diffusion,
                       aggregation=aggregation)
    if model_type == 'ablation':
        # The ablation model is sized from the flattened, scaled first state of the dataset.
        domain_dim = featurizer.scale_abm_state(ramification_training[0][0]).flatten().shape[0]
        return NNModel_ablation(n_features=n_features, learning_rate=learning_rate,
                                abm_featurizer=featurizer, diffusion_timesteps=T_diffusion,
                                domain_dim=domain_dim)
    raise ValueError(f"Unknown model_type {model_type!r}; expected one of {_MODEL_TYPES}")


def _save_model(model, model_type, save_dir, parameter):
    """Write the trained model's checkpoint into ``save_dir`` and return its path."""
    save_path = os.path.join(save_dir, f"model_{model_type}_{parameter}.pth")
    if model_type == 'surrogate':
        checkpoint = {'ld_model_state_dict': model.ld_model.state_dict(),
                      'graph_model_state_dict': model.graph_model.state_dict(),
                      'diffusion_timesteps': model.diffusion_timesteps,
                      'aggregation': model.aggregation,
                      'learning_rate': model.lr_ld,
                      'losses': model.losses
                      }
    else:
        # Ablation checkpoints only store the ld_model weights and the loss history.
        checkpoint = {'ld_model_state_dict': model.ld_model.state_dict(),
                      'losses': model.losses
                      }
    torch.save(checkpoint, save_path)
    return save_path


def main():
    """Train a surrogate or ablation model on an ABM ramification dataset and save it.

    Command-line options select the ABM, the model type, the parameter whose
    dataset is loaded, the learning rate, the number of diffusion steps and the
    number of training epochs. The dataset is read from
    ``RAMIFICATIONS_DIR/<abm_model>/ramification_training_<parameter>.pickle.gz``
    and the checkpoint is written under ``MODELS_DIR/<abm_model>/<model_type>/``.
    """
    parser = argparse.ArgumentParser(description='Train surrogate model on an ABM')
    # `choices` makes argparse reject unsupported values up front, instead of
    # failing later on an unbound `model` after the dataset has been loaded.
    parser.add_argument('--abm_model', type=str, default='predatorprey', choices=_ABM_MODELS, help='choose between predatorprey and schelling (default: predatorprey)')
    parser.add_argument('--model_type', type=str, default='surrogate', choices=_MODEL_TYPES, help='choose between surrogate and ablation (default: surrogate)')
    parser.add_argument('--parameter', type=str, default='psi1', help='choose parameter between: xi1, xi2, xi3 for schelling; and psi1, psi2, psi3, psi4 for predator-prey (default: psi1)')
    parser.add_argument('--learning_rate', type=float, default=1e-5, help='learning rate for the model (default: 1e-5)')
    parser.add_argument('--T_diffusion', type=int, default=100, help='number of diffusion steps (default: 100)')
    parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs for training (default: 100)')
    args = parser.parse_args()

    # Load the configuration
    abm_model = args.abm_model
    model_type = args.model_type
    parameter = args.parameter

    # Load the training data
    # NOTE(security): pickle can execute arbitrary code while loading; only use
    # dataset files that were generated locally or come from a trusted source.
    with gzip.open(f'{RAMIFICATIONS_DIR}/{abm_model}/ramification_training_{parameter}.pickle.gz', 'rb') as f:
        ramification_training = pickle.load(f)

    # Initialize the model
    model = _build_model(abm_model, model_type, args.learning_rate, args.T_diffusion, ramification_training)

    model.train(ramification_training, n_epochs=args.n_epochs)

    # Save the model
    save_dir = f'{MODELS_DIR}/{abm_model}/{model_type}/'
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(save_dir, exist_ok=True)
    save_path = _save_model(model, model_type, save_dir, parameter)
    print(f'Model saved to {save_path}')
# Run the training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()