-
Notifications
You must be signed in to change notification settings - Fork 2
/
trainer.py
59 lines (44 loc) · 1.9 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import time
import utils
from utils import write_to_tensorboard
def _run_validation(estimator, model, val_dataset, **kwargs):
    """Run one full pass over ``val_dataset`` and return the aggregated metrics.

    Per-batch metrics returned by ``estimator.val_step`` are folded into one
    dict with ``utils.sum_metrics``; ``utils.scale_metrics`` is then applied
    with the dataset length (presumably averaging over batches — confirm in
    ``utils``).
    """
    val_metrics = {}
    # NOTE(review): `.numpy()` assumes `__len__()` returns a tensor, as an
    # eager tf.data.Dataset does — TODO confirm for other dataset types.
    val_len = val_dataset.__len__().numpy()
    for val_batch in val_dataset:
        # The per-batch validation loss is not used; only the metrics dict is.
        _, batch_metrics = estimator.val_step(val_batch, model, **kwargs)
        # presumably accumulates batch_metrics into val_metrics in place — confirm in utils
        utils.sum_metrics(val_metrics, batch_metrics)
    utils.scale_metrics(val_metrics, val_len)
    return val_metrics


def train(model,
          estimator,
          optimizer,
          train_dataset,
          val_dataset,
          epochs,
          eval_every=100,
          performance_key='elbo',
          tensorboard=False,
          higher_is_better=True,
          **kwargs):
    """Train ``model`` for ``epochs`` epochs, validating and checkpointing periodically.

    Every ``eval_every`` global steps (including step 0) the latest training
    metrics are optionally written to TensorBoard, a full validation pass is
    run, and the model weights are saved whenever the validation value of
    ``performance_key`` improves on the best value seen so far.

    Args:
        model: Passed to the estimator's step functions; must provide
            ``save_weights(filepath=..., epoch=..., step=...)``.
        estimator: Must provide ``train_step(batch, model, optimizer, **kwargs)``
            and ``val_step(batch, model, **kwargs)`` (each returning a
            ``(loss, metrics)`` pair), ``save_dir``, and
            ``print(metrics, val_metrics, epoch, epochs, took)``; when
            ``tensorboard`` is true also ``train_summary_writer`` and
            ``test_summary_writer`` exposing ``as_default()``.
        optimizer: Forwarded to ``estimator.train_step``.
        train_dataset: Iterable of training batches supporting ``__len__()``.
        val_dataset: Iterable of validation batches whose ``__len__()`` result
            has a ``.numpy()`` method (presumably a ``tf.data.Dataset`` —
            TODO confirm).
        epochs: Number of passes over ``train_dataset``.
        eval_every: Evaluation period in global steps; must be non-zero
            (it is used as a modulus).
        performance_key: Validation-metric key used to select checkpoints.
        tensorboard: If true, write train/validation metrics to TensorBoard.
        higher_is_better: If true (the default, suited to e.g. an ELBO), a
            larger ``performance_key`` value counts as an improvement; pass
            False for loss-like metrics.
        **kwargs: Forwarded to both ``estimator.train_step`` and
            ``estimator.val_step``.

    Raises:
        KeyError: If ``performance_key`` is missing from the validation
            metrics (for example because the validation set is empty).
    """
    steps_pr_epoch = train_dataset.__len__()

    # ---- train
    start = time.perf_counter()  # monotonic clock for elapsed-time measurement
    best = -float("inf") if higher_is_better else float("inf")

    for epoch in range(epochs):
        for _step, train_batch in enumerate(train_dataset):
            # Global step index, continuous across epochs.
            step = _step + steps_pr_epoch * epoch

            # ---- one training step (the returned loss is not used here)
            _, metrics = estimator.train_step(train_batch, model, optimizer, **kwargs)

            if step % eval_every == 0:
                # Time since the previous evaluation; this includes the
                # previous evaluation's own validation pass.
                now = time.perf_counter()
                took = now - start
                start = now

                # ---- write training to tensorboard
                if tensorboard:
                    with estimator.train_summary_writer.as_default():
                        write_to_tensorboard(metrics, step)

                # ---- monitor the val-set
                val_metrics = _run_validation(estimator, model, val_dataset, **kwargs)
                if tensorboard:
                    with estimator.test_summary_writer.as_default():
                        write_to_tensorboard(val_metrics, step)

                # ---- checkpoint on improvement
                if performance_key not in val_metrics:
                    raise KeyError(
                        "performance_key {!r} not found in validation metrics {}; "
                        "is the validation set empty?".format(performance_key,
                                                              list(val_metrics)))
                score = val_metrics[performance_key]
                improved = score > best if higher_is_better else score < best
                if improved:
                    best = score
                    model.save_weights(filepath=estimator.save_dir, epoch=epoch, step=step)
                    print("Performance: {} is {:.4f}".format(performance_key, best))

                estimator.print(metrics, val_metrics, epoch, epochs, took)