aaljabari committed on
Commit
8cc99ac
·
verified ·
1 Parent(s): 367883f

Create train.py

Browse files
Files changed (1) hide show
  1. Nested/bin/train.py +222 -0
Nested/bin/train.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import json
4
+ import argparse
5
+ import torch.utils.tensorboard
6
+ from torchvision import *
7
+ import pickle
8
+ from Nested.utils.data import get_dataloaders, parse_conll_files
9
+ from Nested.utils.helpers import logging_config, load_object, make_output_dirs, set_seed
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
def parse_args(argv=None):
    """Build the training command-line parser and parse the arguments.

    The JSON-valued options (``--data_config``, ``--trainer_config``,
    ``--network_config``, ``--optimizer``, ``--lr_scheduler``, ``--loss``) are
    decoded with ``json.loads``; argparse applies the same conversion to their
    string defaults, so the returned namespace always holds dicts of the form
    ``{"fn": <dotted path>, "kwargs": {...}}`` for them.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Required filesystem locations.
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path",
    )

    parser.add_argument(
        "--train_path",
        type=str,
        required=True,
        help="Path to training data",
    )

    parser.add_argument(
        "--val_path",
        type=str,
        required=True,
        help="Path to validation data",
    )

    parser.add_argument(
        "--test_path",
        type=str,
        required=True,
        help="Path to test data",
    )

    parser.add_argument(
        "--bert_model",
        type=str,
        default="aubmindlab/bert-base-arabertv2",
        help="BERT model",
    )

    # Runtime / hardware settings.
    parser.add_argument(
        "--gpus",
        type=int,
        nargs="+",
        default=[0],
        help="GPU IDs to train on",
    )

    parser.add_argument(
        "--log_interval",
        type=int,
        default=10,
        help="Log results every that many timesteps",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size",
    )

    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="Dataloader number of workers",
    )

    # JSON component configurations: each names a dotted path ("fn") and the
    # keyword arguments ("kwargs") to build that component with.
    parser.add_argument(
        "--data_config",
        type=json.loads,
        default='{"fn": "Nested.data.datasets.DefaultDataset", "kwargs": {"max_seq_len": 512}}',
        help="Dataset configurations",
    )

    parser.add_argument(
        "--trainer_config",
        type=json.loads,
        default='{"fn": "Nested.trainers.BertTrainer", "kwargs": {"max_epochs": 50}}',
        help="Trainer configurations",
    )

    parser.add_argument(
        "--network_config",
        type=json.loads,
        default='{"fn": "Nested.nn.BertSeqTagger", "kwargs": '
                '{"dropout": 0.1, "bert_model": "aubmindlab/bert-base-arabertv2"}}',
        help="Network configurations",
    )

    parser.add_argument(
        "--optimizer",
        type=json.loads,
        default='{"fn": "torch.optim.AdamW", "kwargs": {"lr": 0.0001}}',
        help="Optimizer configurations",
    )

    parser.add_argument(
        "--lr_scheduler",
        type=json.loads,
        default='{"fn": "torch.optim.lr_scheduler.ExponentialLR", "kwargs": {"gamma": 1}}',
        help="Learning rate scheduler configurations",
    )

    parser.add_argument(
        "--loss",
        type=json.loads,
        default='{"fn": "torch.nn.CrossEntropyLoss", "kwargs": {}}',
        help="Loss function configurations",
    )

    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite output directory",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Seed for random initialization",
    )

    return parser.parse_args(argv)
142
+
143
+
144
def main(args):
    """Run a full training job from a parsed argument namespace.

    Steps, in order: create the output directories, seed the RNGs, configure
    logging and TensorBoard, restrict the visible GPUs, parse the CoNLL files,
    persist the tag vocabulary and the run configuration, build the
    dataloaders, model, optimizer, LR scheduler and loss, and finally hand
    everything to the configured trainer and call its ``train()`` method.

    Args:
        args: Namespace as produced by ``parse_args``. The ``*_config``,
            ``optimizer``, ``lr_scheduler`` and ``loss`` attributes are dicts
            with ``"fn"`` and ``"kwargs"`` keys; their ``"kwargs"`` dicts are
            mutated in place with the runtime objects built here.

    Raises:
        ValueError: If the LR scheduler kwargs contain ``"num_training_steps"``
            but no ``max_epochs`` value can be found to compute it from.
    """
    make_output_dirs(
        args.output_path,
        subdirs=("tensorboard", "checkpoints"),
        overwrite=args.overwrite,
    )

    # Set the seed for randomization
    set_seed(args.seed)

    logging_config(os.path.join(args.output_path, "train.log"))
    summary_writer = torch.utils.tensorboard.SummaryWriter(
        os.path.join(args.output_path, "tensorboard")
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in args.gpus)

    # Get the datasets and vocab for tags and tokens
    datasets, vocab = parse_conll_files((args.train_path, args.val_path, args.test_path))

    # "Nested" networks receive one label count per entry of vocab.tags[1:];
    # any other network receives the size of vocab.tags[0] only.
    if "Nested" in args.network_config["fn"]:
        args.network_config["kwargs"]["num_labels"] = [len(v) for v in vocab.tags[1:]]
    else:
        args.network_config["kwargs"]["num_labels"] = len(vocab.tags[0])

    # The dataset is given the same BERT model name as the network.
    # NOTE(review): args.bert_model is never consulted here; only the value
    # inside network_config is used -- confirm that this is intended.
    args.data_config["kwargs"]["bert_model"] = args.network_config["kwargs"]["bert_model"]

    # Save tag vocab to disk
    with open(os.path.join(args.output_path, "tag_vocab.pkl"), "wb") as fh:
        pickle.dump(vocab.tags, fh)

    # Write config to file. This must happen before the kwargs dicts below are
    # filled with live objects (parameters, optimizer, ...), which are not
    # JSON serializable.
    args_file = os.path.join(args.output_path, "args.json")
    with open(args_file, "w") as fh:
        logger.info("Writing config to %s", args_file)
        json.dump(args.__dict__, fh, indent=4)

    # From the datasets generate the dataloaders
    train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
        datasets, vocab, args.data_config, args.batch_size, args.num_workers
    )

    model = load_object(args.network_config["fn"], args.network_config["kwargs"])
    # Device ids are 0..n-1 rather than args.gpus: CUDA_VISIBLE_DEVICES was set
    # above, which renumbers the selected GPUs starting from zero.
    model = torch.nn.DataParallel(model, device_ids=range(len(args.gpus)))

    if torch.cuda.is_available():
        model = model.cuda()

    args.optimizer["kwargs"]["params"] = model.parameters()
    optimizer = load_object(args.optimizer["fn"], args.optimizer["kwargs"])

    args.lr_scheduler["kwargs"]["optimizer"] = optimizer
    if "num_training_steps" in args.lr_scheduler["kwargs"]:
        # The epoch budget is configured through the trainer kwargs; `args`
        # normally has no `max_epochs` attribute, so reading it directly
        # raised AttributeError. An explicit attribute still takes precedence
        # for programmatic callers that provide one.
        max_epochs = getattr(args, "max_epochs", None)
        if max_epochs is None:
            max_epochs = args.trainer_config["kwargs"].get("max_epochs")
        if max_epochs is None:
            raise ValueError(
                'lr_scheduler kwargs request "num_training_steps" but no '
                '"max_epochs" is set in trainer_config kwargs'
            )
        args.lr_scheduler["kwargs"]["num_training_steps"] = max_epochs * len(
            train_dataloader
        )

    scheduler = load_object(args.lr_scheduler["fn"], args.lr_scheduler["kwargs"])
    loss = load_object(args.loss["fn"], args.loss["kwargs"])

    args.trainer_config["kwargs"].update({
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler,
        "loss": loss,
        "train_dataloader": train_dataloader,
        "val_dataloader": val_dataloader,
        "test_dataloader": test_dataloader,
        "log_interval": args.log_interval,
        "summary_writer": summary_writer,
        "output_path": args.output_path,
    })

    trainer = load_object(args.trainer_config["fn"], args.trainer_config["kwargs"])
    trainer.train()
219
+
220
+
221
if __name__ == "__main__":
    # Script entry point: parse the command line, then launch training.
    cli_args = parse_args()
    main(cli_args)