# train.py (forked from google/retiming)
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training a layered neural renderer on a video.
You need to specify the dataset ('--dataroot') and experiment name ('--name').
Example:
python train.py --dataroot ./datasets/reflection --name reflection --gpu_ids 0,1
The script first creates a model, dataset, and visualizer given the options.
It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss
plot, and saves the model.
Use '--continue_train' to resume your previous training.
The default setting is to first train the base model, which produces the low-resolution result (256x448), and then
train the upsampling module to produce the 512x896 result. If the upsampling module is unnecessary, use
'--n_epochs_upsample 0'.
See options/base_options.py and options/train_options.py for more training options.
"""
import time
from options.train_options import TrainOptions
from third_party.data import create_dataset
from third_party.models import create_model
from third_party.util.visualizer import Visualizer
import torch
import numpy as np


def main():
    trainopt = TrainOptions()
    trainopt.parse()
    opt = trainopt.parse_dataset_meta()
    torch.manual_seed(opt.seed)
    np.random.seed(opt.seed)
    opt.do_upsampling = False  # train the low-resolution base network first

    dataset = create_dataset(opt, use_fast_loader=True)
    dataset_size = len(dataset)
    print('The number of training images = %d' % dataset_size)

    model = create_model(opt)
    model.setup(opt)  # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(opt)

    # Train the base model (produces the low-resolution output).
    train(model, dataset, visualizer, opt)

    # Optionally train the upsampling module.
    if opt.n_epochs_upsample > 0:
        opt.do_upsampling = True
        opt.batch_size = opt.batch_size_upsample
        # Load the dataset for upsampling training.
        dataset = create_dataset(opt, use_fast_loader=True)
        dataset_size = len(dataset)
        print('The number of training images = %d' % dataset_size)
        # Set the loss weights (lambdas) for upsampling training.
        opt.lambda_mask = 0
        opt.lambda_alpha_l0 = 0
        opt.lambda_alpha_l1 = 0
        opt.mask_loss_rolloff_epoch = -1
        opt.jitter_rgb = 0
        # Reinitialize optimizers, schedulers, and lambdas.
        model.setup_train(opt)
        # Freeze the base model and train only the upsampling module.
        model.freeze_basenet()
        model.setup(opt)
        # Update the epoch count to resume training after the base-model epochs.
        opt.epoch_count = opt.n_epochs + opt.n_epochs_decay + 1
        opt.n_epochs += opt.n_epochs_upsample
        train(model, dataset, visualizer, opt)


def train(model, dataset, visualizer, opt):
    dataset_size = len(dataset)
    total_iters = 0  # the total number of training iterations
    for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):  # outer loop over epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for the entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # the number of training iterations in the current epoch, reset to 0 every epoch
        model.update_lambdas(epoch)
        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time()  # timer for computation per iteration
            if i % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)
            model.optimize_parameters()
            if i % opt.print_freq == 0:  # print training losses and save logging information to disk
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
            iter_data_time = time.time()
        if epoch % opt.display_freq == 1:  # display images on visdom and save images to an HTML file
            save_result = epoch % opt.update_html_freq == 1
            model.compute_visuals()
            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
        if epoch % opt.save_latest_freq == 0:  # cache the latest model every <save_latest_freq> epochs
            print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
            save_suffix = 'epoch_%d' % epoch if opt.save_by_epoch else 'latest'
            model.save_networks(save_suffix)
        model.update_learning_rate()  # update learning rates at the end of every epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))


if __name__ == '__main__':
    main()