-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathmain.py
More file actions
139 lines (116 loc) · 4.94 KB
/
main.py
File metadata and controls
139 lines (116 loc) · 4.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Entry-point script: parse CLI args, build DataLoaders, and launch training.
import torch
import argparse
from torch.utils.data import DataLoader
from src.utils import *
from src import train
import os
# Make CUDA device indices follow physical PCI bus order so the IDs below
# refer to the same physical GPUs across runs/machines.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# Restrict this process to GPUs 1 and 2.
# NOTE(review): these env vars are set AFTER `import torch`; this relies on
# CUDA initialization being lazy (no CUDA call before this point) — confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="1,2"
# Command-line interface. Options are kept as (flag, kwargs) pairs and
# registered in order, so --help output matches the original layout.
parser = argparse.ArgumentParser(description='Meme Hatefulness detection')

_CLI_OPTIONS = [
    # Fixed
    ('--model', dict(type=str, default='AverageBERT',
                     help='name of the model to use (Transformer, etc.)')),
    # Tasks
    ('--dataset', dict(type=str, default='meme_dataset',
                       help='dataset to use (default: meme_dataset)')),
    ('--data_path', dict(type=str, default='data',
                         help='path for storing the dataset')),
    # Dropouts
    ('--mlp_dropout', dict(type=float, default=0.0,
                           help='fully connected layers dropout')),
    # Architecture
    ('--bert_model', dict(type=str, default="bert-base-cased",
                          help='pretrained bert model to use')),
    ('--cnn_model', dict(type=str, default="vgg16",
                         help='pretrained CNN to use for image feature extraction')),
    ('--image_feature_size', dict(type=int, default=4096,
                                  help='image feature size extracted from pretrained CNN (default: 4096)')),
    ('--bert_hidden_size', dict(type=int, default=768,
                                help='bert hidden size for each word token (default: 768)')),
    # Tuning
    ('--batch_size', dict(type=int, default=8, metavar='N',
                          help='batch size (default: 8)')),
    ('--max_token_length', dict(type=int, default=50,
                                help='max number of tokens per sentence (default: 50)')),
    ('--clip', dict(type=float, default=0.8,
                    help='gradient clip value (default: 0.8)')),
    ('--lr', dict(type=float, default=2e-5,
                  help='initial learning rate (default: 2e-5)')),
    ('--optim', dict(type=str, default='AdamW',
                     help='optimizer to use (default: AdamW)')),
    ('--num_epochs', dict(type=int, default=3,
                          help='number of epochs (default: 3)')),
    ('--when', dict(type=int, default=2,
                    help='when to decay learning rate (default: 2)')),
    # Logistics
    ('--log_interval', dict(type=int, default=100,
                            help='frequency of result logging (default: 100)')),
    ('--seed', dict(type=int, default=1111,
                    help='random seed')),
    ('--no_cuda', dict(action='store_true',
                       help='do not use cuda')),
    ('--name', dict(type=str, default='model',
                    help='name of the trial (default: "model")')),
    ('--num_workers', dict(type=int, default=6,
                           help='number of workers to use for DataLoaders (default: 6)')),
]
for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)

args = parser.parse_args()
# Reproducibility: seed the CPU RNG before any model/data construction.
torch.manual_seed(args.seed)

# Normalize the dataset name so the lookups below are case/whitespace-insensitive.
dataset = args.dataset.strip().lower()
print(dataset)

# Per-dataset classifier head size.
output_dim_dict = {
    'meme_dataset': 2,
    'mmimdb': 27
}
# Per-dataset loss function (looked up by name downstream).
criterion_dict = {
    'meme_dataset': 'CrossEntropyLoss',
    'mmimdb': 'BCEWithLogitsLoss'
}

# Default to CPU float tensors; switch to CUDA below when available and allowed.
torch.set_default_tensor_type('torch.FloatTensor')
use_cuda = False
if torch.cuda.is_available():
    if args.no_cuda:
        print("WARNING: You have a CUDA device, so you should probably not run with --no_cuda")
    else:
        # Seed the GPU RNG as well, and make new tensors live on the GPU by default.
        torch.cuda.manual_seed(args.seed)
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        use_cuda = True
####################################################################
#
# Load the dataset
#
####################################################################
print("Start loading the data....")

train_data = get_data(args, dataset, 'train')
valid_data = get_data(args, dataset, 'dev')
test_data = get_data(args, dataset, 'test')

# Every split uses the same loader configuration.
_loader_kwargs = dict(batch_size=args.batch_size,
                      shuffle=True,
                      num_workers=args.num_workers)
train_loader = DataLoader(train_data, **_loader_kwargs)
valid_loader = DataLoader(valid_data, **_loader_kwargs)
# Some datasets have no labeled test split; keep the loader as None then.
test_loader = None if test_data is None else DataLoader(test_data, **_loader_kwargs)

print('Finish loading the data....')
####################################################################
#
# Hyperparameters
#
####################################################################
# NOTE: hyp_params is an alias of the argparse Namespace — the assignments
# below mutate `args` in place rather than working on a copy.
hyp_params = args
hyp_params.use_cuda = use_cuda
hyp_params.dataset = dataset
hyp_params.n_train, hyp_params.n_valid = len(train_data), len(valid_data)
# Only record a test-set size when a test split actually exists
# (replaces the original `if ... is None: pass / else:` dead branch).
if test_data is not None:
    hyp_params.n_test = len(test_data)
hyp_params.model = args.model.strip()
# .get() yields None for an unrecognized dataset name; presumably
# train.initiate handles (or should validate) that — TODO confirm.
hyp_params.output_dim = output_dim_dict.get(dataset)
hyp_params.criterion = criterion_dict.get(dataset)

if __name__ == '__main__':
    test_loss = train.initiate(hyp_params, train_loader, valid_loader, test_loader)