-
Notifications
You must be signed in to change notification settings - Fork 125
/
Copy pathrun_bert_multitask.py
411 lines (352 loc) · 16.4 KB
/
run_bert_multitask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
# AUTOGENERATED! DO NOT EDIT! File to edit: source_nbs/14_run_bert_multitask.ipynb (unless otherwise specified).
__all__ = ['LOGGER', 'create_keras_model', 'get_params_ready', 'train_bert_multitask', 'trim_checkpoint_for_prediction',
'eval_bert_multitask', 'predict_bert_multitask']
# Cell
import argparse
import os
import time
from typing import Dict, Callable
from shutil import copytree, ignore_patterns, rmtree
import tensorflow as tf
from tensorflow.python.framework.errors_impl import NotFoundError as TFNotFoundError
from .input_fn import predict_input_fn, train_eval_input_fn
from .model_fn import BertMultiTask
from .params import DynamicBatchSizeParams, BaseParams
from .special_tokens import EVAL
# Module-level logger, obtained from TensorFlow through tf.get_logger().
# Propagation to ancestor loggers is switched off to fix duplicated log lines.
LOGGER = tf.get_logger()
LOGGER.propagate = False
# Cell
def create_keras_model(
        mirrored_strategy: tf.distribute.MirroredStrategy,
        params: BaseParams,
        mode='train',
        inputs_to_build_model=None,
        model=None):
    """Create (or reuse) a BertMultiTask model and initialise it for ``mode``.

    Supported modes:
        train: no checkpoint is loaded by this function, the model is only
            compiled (pretrained weights are expected to come from
            huggingface - presumably handled inside BertMultiTask).
        resume: the model is compiled, built, and then restored from
            params.ckpt_dir including the optimizers' states. If no matching
            checkpoint is found a warning is logged and the model is
            returned without restored weights.
        transfer: the model is built, weights are loaded from
            params.init_checkpoint (a checkpoint saved by
            bert-multitask-learning), then the model is compiled.
        predict: the model is built and weights are loaded from
            params.ckpt_dir; the model is not compiled.
        eval: same as predict, and the model is compiled afterwards.
    Any other value is treated like 'train'.

    Args:
        mirrored_strategy (tf.distribute.MirroredStrategy): strategy whose
            scope wraps the whole initialisation. Pass None to initialise
            the model outside of any strategy scope.
        params (BaseParams): params
        mode (str, optional): Mode, see above. Defaults to 'train'.
        inputs_to_build_model (Dict, optional): A batch of data used for the
            forward pass that builds the model. Not used in 'train' mode.
            Defaults to None.
        model (Model, optional): Keras model to initialise. A new
            BertMultiTask(params) is created when None. Defaults to None.

    Returns:
        model: the initialised model
    """
    def _initialise(model):
        if model is None:
            model = BertMultiTask(params)

        def _build():
            # Forward pass in PREDICT mode to build the model before any
            # weights are loaded into it.
            model(inputs_to_build_model, mode=tf.estimator.ModeKeys.PREDICT)

        if mode == 'resume':
            # Compile before building so that ALL vars, including the
            # optimizers' states, can be restored.
            model.compile()
            _build()
            try:
                model.load_weights(
                    os.path.join(params.ckpt_dir, 'model'),
                    skip_mismatch=False)
            except TFNotFoundError:
                # Logger.warn is a deprecated alias; use Logger.warning.
                LOGGER.warning('Not resuming since no matching ckpt found')
        elif mode == 'transfer':
            # Build and load first, compile last: the optimizers' vars are
            # not loaded from the init checkpoint.
            _build()
            model.load_weights(os.path.join(params.init_checkpoint, 'model'))
            model.compile()
        elif mode == 'predict':
            _build()
            model.load_weights(os.path.join(params.ckpt_dir, 'model'))
        elif mode == 'eval':
            _build()
            model.load_weights(os.path.join(params.ckpt_dir, 'model'))
            model.compile()
        else:
            model.compile()
        return model

    if mirrored_strategy is None:
        return _initialise(model)
    with mirrored_strategy.scope():
        return _initialise(model)
# Cell
def _train_bert_multitask_keras_model(train_dataset: tf.data.Dataset,
                                      eval_dataset: tf.data.Dataset,
                                      model: tf.keras.Model,
                                      params: BaseParams,
                                      mirrored_strategy: tf.distribute.MirroredStrategy = None):
    """Fit ``model`` on ``train_dataset`` and validate on ``eval_dataset``.

    Weights are checkpointed to ``<params.ckpt_dir>/model`` and TensorBoard
    logs are written to ``params.ckpt_dir``. The number of epochs and the
    steps per epoch are read from ``params.train_epoch`` and
    ``params.train_steps_per_epoch``. When ``mirrored_strategy`` is given,
    ``fit`` runs inside its scope. The model summary is printed afterwards.

    Args:
        train_dataset (tf.data.Dataset): training data
        eval_dataset (tf.data.Dataset): validation data
        model (tf.keras.Model): model to fit
        params (BaseParams): params
        mirrored_strategy (tf.distribute.MirroredStrategy, optional):
            strategy to fit under. Defaults to None.
    """
    # Only the weights are saved: the whole model can't be saved with the
    # model subclassing api due to a tf bug, see
    # https://github.com/tensorflow/tensorflow/issues/42741
    # https://github.com/tensorflow/tensorflow/issues/40366
    weights_path = os.path.join(params.ckpt_dir, 'model')
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
        filepath=weights_path,
        save_weights_only=True,
        monitor='val_mean_acc',
        mode='auto',
        save_best_only=False)
    board_cb = tf.keras.callbacks.TensorBoard(log_dir=params.ckpt_dir)

    fit_kwargs = dict(
        x=train_dataset,
        validation_data=eval_dataset,
        epochs=params.train_epoch,
        callbacks=[checkpoint_cb, board_cb],
        steps_per_epoch=params.train_steps_per_epoch,
    )

    if mirrored_strategy is None:
        model.fit(**fit_kwargs)
    else:
        with mirrored_strategy.scope():
            model.fit(**fit_kwargs)
    model.summary()
# Cell
def get_params_ready(problem, num_gpus, model_dir, params, problem_type_dict, processing_fn_dict, mode='train', json_path=''):
    """Prepare a params object for ``problem``.

    Steps, in order:
        1. Create a DynamicBatchSizeParams when ``params`` is None.
        2. Make sure a ``models`` directory exists in the current working
           directory.
        3. Register extra problems when ``problem_type_dict`` is non-empty.
        4. In 'train' mode: assign the problem, then dump params to json.
           In any other mode: load params from ``json_path``, then assign
           the problem.

    Args:
        problem (str): problem string passed to ``params.assign_problem``
        num_gpus (int or str): number of gpus, converted with ``int()``
        model_dir (str): model dir; split into base dir and dir name. When
            empty, None is passed for both.
        params (BaseParams): params, or None to create a default one
        problem_type_dict (dict): Key: problem name, value: problem type
        processing_fn_dict (dict): Key: problem name, value: preprocessing fn
        mode (str, optional): 'train' or anything else. Defaults to 'train'.
        json_path (str, optional): json file to load params from when mode
            is not 'train'. Defaults to ''.

    Returns:
        the prepared params object
    """
    if params is None:
        params = DynamicBatchSizeParams()

    # EAFP instead of exists()-then-mkdir(): no race when another process
    # creates the directory between the check and the creation.
    try:
        os.mkdir('models')
    except FileExistsError:
        pass

    base_dir, dir_name = os.path.split(model_dir) if model_dir else (None, None)

    # add new problem to params if problem_type_dict and processing_fn_dict provided
    if problem_type_dict:
        params.add_multiple_problems(
            problem_type_dict=problem_type_dict, processing_fn_dict=processing_fn_dict)

    is_training = mode == 'train'
    if not is_training:
        # Stored params are loaded before the problem is assigned.
        params.from_json(json_path)
    params.assign_problem(problem, gpu=int(num_gpus),
                          base_dir=base_dir, dir_name=dir_name)
    if is_training:
        params.to_json()
    return params
# Cell
def train_bert_multitask(
        problem='weibo_ner',
        num_gpus=1,
        num_epochs=10,
        model_dir='',
        params: BaseParams = None,
        problem_type_dict: Dict[str, str] = None,
        processing_fn_dict: Dict[str, Callable] = None,
        model: tf.keras.Model = None,
        create_tf_record_only=False,
        steps_per_epoch=None,
        warmup_ratio=0.1,
        continue_training=False,
        mirrored_strategy=None):
    """Train Multi-task Bert model

    About problem:
        There are two types of chaining operations that can be used to chain problems.
            - `&`. If two problems have the same inputs, they can be chained using `&`.
                Problems chained by `&` will be trained at the same time.
            - `|`. If two problems don't have the same inputs, they need to be chained using `|`.
                Problems chained by `|` will be sampled to train at every instance.

        For example, `cws|NER|weibo_ner&weibo_cws`, one problem will be sampled at each turn,
        say `weibo_ner&weibo_cws`, then `weibo_ner` and `weibo_cws` will be trained for this
        turn together. Therefore, in a particular batch, some tasks might not be sampled,
        and their loss could be 0 in this batch.

    About problem_type_dict and processing_fn_dict:
        If the problem is not predefined, you need to tell the model what's the new
        problem's problem_type and preprocessing function.
            For example, a new problem: fake_classification
            problem_type_dict = {'fake_classification': 'cls'}
            processing_fn_dict = {'fake_classification': my_preprocessing_fn}

        Available problem type:
            cls: Classification
            seq_tag: Sequence Labeling
            seq2seq_tag: Sequence to Sequence tag problem
            seq2seq_text: Sequence to Sequence text generation problem

        Preprocessing function example:
        Please refer to https://github.com/JayYip/bert-multitask-learning/blob/master/README.md

    Keyword Arguments:
        problem {str} -- Problems to train (default: {'weibo_ner'})
        num_gpus {int} -- Number of GPU to use (default: {1})
        num_epochs {int} -- Number of epochs to train (default: {10})
        model_dir {str} -- model dir (default: {''})
        params {BaseParams} -- Params to define training and models (default: {DynamicBatchSizeParams()})
        problem_type_dict {dict} -- Key: problem name, value: problem type (default: {{}})
        processing_fn_dict {dict} -- Key: problem name, value: problem data preprocessing fn (default: {{}})
        model {tf.keras.Model} -- Currently unused; a model is always created
            by create_keras_model (default: {None})
        create_tf_record_only {bool} -- Return right after the train and eval
            input functions have been called, without training (default: {False})
        steps_per_epoch {int} -- Training steps per epoch. When None, the
            training dataset is iterated once to count them (default: {None})
        warmup_ratio {float} -- Passed to params.update_train_steps (default: {0.1})
        continue_training {bool} -- Resume from params.ckpt_dir when it holds
            a checkpoint (default: {False})
        mirrored_strategy -- Distribution strategy. None creates a new
            tf.distribute.MirroredStrategy; False disables the strategy
            (default: {None})

    Returns:
        The trained model, or None when create_tf_record_only is set.
    """
    params = get_params_ready(problem, num_gpus, model_dir,
                              params, problem_type_dict, processing_fn_dict)
    params.train_epoch = num_epochs

    # Both input functions are called before the early return below, so
    # create_tf_record_only still triggers them (presumably this is what
    # writes the TF records - confirm in input_fn).
    train_dataset = train_eval_input_fn(params)
    eval_dataset = train_eval_input_fn(params, mode=EVAL)
    if create_tf_record_only:
        return None

    # get train_steps and update params
    if steps_per_epoch is not None:
        train_steps = steps_per_epoch
    else:
        train_steps = sum(1 for _ in train_dataset)
    params.update_train_steps(train_steps, warmup_ratio=warmup_ratio)

    # The training dataset is re-created after the params were updated.
    train_dataset = train_eval_input_fn(params)
    # NOTE(review): the repeat count is hard-coded to 10 and does not follow
    # num_epochs - confirm this is enough when num_epochs > 10.
    train_dataset = train_dataset.repeat(10)

    one_batch = next(train_dataset.as_numpy_iterator())

    # None -> default MirroredStrategy, False -> no strategy at all.
    if mirrored_strategy is None:
        mirrored_strategy = tf.distribute.MirroredStrategy()
    elif mirrored_strategy is False:
        mirrored_strategy = None

    # The strategy has been normalised above, so "disabled" is None here
    # (comparing against False was always true and crashed on None).
    # int() mirrors get_params_ready, which also accepts a str num_gpus.
    if int(num_gpus) > 1 and mirrored_strategy is not None:
        train_dataset = mirrored_strategy.experimental_distribute_dataset(
            train_dataset)
        eval_dataset = mirrored_strategy.experimental_distribute_dataset(
            eval_dataset)

    # restore priority: self > transfer > huggingface
    if continue_training and tf.train.latest_checkpoint(params.ckpt_dir):
        mode = 'resume'
    elif tf.train.latest_checkpoint(params.init_checkpoint):
        mode = 'transfer'
    else:
        mode = 'train'

    model = create_keras_model(
        mirrored_strategy=mirrored_strategy, params=params, mode=mode, inputs_to_build_model=one_batch)

    _train_bert_multitask_keras_model(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        model=model,
        params=params,
        mirrored_strategy=mirrored_strategy
    )
    return model
# Cell
def trim_checkpoint_for_prediction(problem: str,
                                   input_dir: str,
                                   output_dir: str,
                                   problem_type_dict: Dict[str, str] = None,
                                   overwrite=True,
                                   fake_input_list=None,
                                   params=None):
    """Minimize checkpoint size for prediction.

    Since the original checkpoint contains the optimizer's variables,
    for instance, if you use adam, the checkpoint size will
    be three times the size of the model weights. This function
    removes those variables that are unused in prediction to save space.

    Everything in input_dir except the checkpoint files is copied to
    output_dir, then the model weights alone are loaded from input_dir and
    saved to params.ckpt_dir, and the params are dumped to json.

    Note: if the model is a multimodal model, you have to provide fake_input_list that
    mimics the structure of real input.

    Args:
        problem (str): problem
        input_dir (str): input dir
        output_dir (str): output dir
        problem_type_dict (Dict[str, str], optional): problem type dict. Defaults to None.
        overwrite (bool, optional): remove an existing output_dir before
            copying. Defaults to True.
        fake_input_list (List): fake input list to create dummy dataset.
            Defaults to None, in which case ['fake'] is used.
        params (BaseParams, optional): params. A DynamicBatchSizeParams is
            created when None. Defaults to None.
    """
    if overwrite and os.path.exists(output_dir):
        rmtree(output_dir)

    # Copy everything but the checkpoint files themselves.
    copytree(input_dir, output_dir, ignore=ignore_patterns(
        'checkpoint', '*.index', '*.data-000*'))

    base_dir, dir_name = os.path.split(output_dir)
    if params is None:
        params = DynamicBatchSizeParams()
    # Only register extra problems when some were given, consistent with
    # get_params_ready (problem_type_dict defaults to None).
    if problem_type_dict:
        params.add_multiple_problems(problem_type_dict=problem_type_dict)
    params.from_json(os.path.join(input_dir, 'params.json'))
    params.assign_problem(problem, base_dir=base_dir,
                          dir_name=dir_name, predicting=True)

    model = BertMultiTask(params)
    sample_inputs = ['fake'] if fake_input_list is None else fake_input_list
    dummy_dataset = predict_input_fn(sample_inputs * 5, params)
    # Forward pass on one dummy batch to build the model before loading weights.
    _ = model(next(dummy_dataset.as_numpy_iterator()),
              mode=tf.estimator.ModeKeys.PREDICT)
    model.load_weights(os.path.join(input_dir, 'model'))
    model.save_weights(os.path.join(params.ckpt_dir, 'model'))
    params.to_json()
# Cell
def eval_bert_multitask(
        problem='weibo_ner',
        num_gpus=1,
        model_dir='',
        params=None,
        problem_type_dict=None,
        processing_fn_dict=None,
        model=None):
    """Evaluate Multi-task Bert model

    Params are loaded from ``<model_dir>/params.json``, the model weights are
    restored in 'eval' mode under a new MirroredStrategy and the model is
    evaluated on the eval dataset.

    Keyword Arguments:
        problem {str} -- problems to evaluate (default: {'weibo_ner'})
        num_gpus {int} -- number of gpu to use (default: {1})
        model_dir {str} -- model dir. When empty and params is given,
            params.ckpt_dir is used (default: {''})
        params {Params} -- params to define model (default: {DynamicBatchSizeParams()})
        problem_type_dict {dict} -- Key: problem name, value: problem type (default: {{}})
        processing_fn_dict {dict} -- Key: problem name, value: problem data preprocessing fn (default: {{}})
        model -- not used; the model is always re-created from the
            checkpoint (default: {None})

    Returns:
        dict -- result of model.evaluate(..., return_dict=True)
    """
    if params is not None and not model_dir:
        model_dir = params.ckpt_dir

    params_json = os.path.join(model_dir, 'params.json')
    params = get_params_ready(problem, num_gpus, model_dir,
                              params, problem_type_dict, processing_fn_dict,
                              mode='predict', json_path=params_json)

    # One batch is pulled from a throwaway dataset to build the model; a
    # fresh dataset is created for the actual evaluation.
    sample_batch = next(
        train_eval_input_fn(params, mode=EVAL).as_numpy_iterator())
    eval_dataset = train_eval_input_fn(params, mode=EVAL)

    strategy = tf.distribute.MirroredStrategy()
    model = create_keras_model(
        mirrored_strategy=strategy, params=params, mode='eval', inputs_to_build_model=sample_batch)
    return model.evaluate(eval_dataset, return_dict=True)
# Cell
def predict_bert_multitask(
        inputs,
        problem='weibo_ner',
        model_dir='',
        params: BaseParams = None,
        problem_type_dict: Dict[str, str] = None,
        processing_fn_dict: Dict[str, Callable] = None,
        model: tf.keras.Model = None,
        return_model=False):
    """Predict with Multi-task Bert model

    Params are loaded from ``<model_dir>/params.json``. Unless a model is
    passed in, one is created in 'predict' mode (weights restored from
    params.ckpt_dir) under a new MirroredStrategy. Prediction always runs
    inside that strategy's scope.

    Arguments:
        inputs -- inputs handed to predict_input_fn

    Keyword Arguments:
        problem {str} -- problems to predict (default: {'weibo_ner'})
        model_dir {str} -- model dir. When empty, params.ckpt_dir is used (default: {''})
        params {Params} -- params to define model (default: {DynamicBatchSizeParams()})
        problem_type_dict {dict} -- Key: problem name, value: problem type (default: {{}})
        processing_fn_dict {dict} -- Key: problem name, value: problem data preprocessing fn (default: {{}})
        model {tf.keras.Model} -- model to predict with; created from the
            checkpoint when None (default: {None})
        return_model {bool} -- also return the model (default: {False})

    Returns:
        The output of model.predict, or a (prediction, model) tuple when
        return_model is set.
    """
    if params is None:
        params = DynamicBatchSizeParams()
    # params is never None at this point, only model_dir needs checking.
    if not model_dir:
        model_dir = params.ckpt_dir

    params = get_params_ready(problem, 1, model_dir,
                              params, problem_type_dict, processing_fn_dict,
                              mode='predict', json_path=os.path.join(model_dir, 'params.json'))

    # The former unconditional 3 second sleep after this message was dropped:
    # it only added latency to every prediction call.
    LOGGER.info('Checkpoint dir: %s', params.ckpt_dir)

    # One batch is pulled from a throwaway dataset to build the model; a
    # fresh dataset is created for the actual prediction.
    pred_dataset = predict_input_fn(inputs, params)
    one_batch_data = next(pred_dataset.as_numpy_iterator())
    pred_dataset = predict_input_fn(inputs, params)

    mirrored_strategy = tf.distribute.MirroredStrategy()
    if model is None:
        model = create_keras_model(
            mirrored_strategy=mirrored_strategy, params=params, mode='predict', inputs_to_build_model=one_batch_data)

    with mirrored_strategy.scope():
        pred = model.predict(pred_dataset)

    if return_model:
        return pred, model
    return pred