tf_upgrade_v2 on resnet and utils folders. #6154
Changes from 13 commits
4b921e5
ce73f81
39b1273
3ab2869
c4ece8a
17a777d
e57a082
9a4f074
d1f5d1a
8b18b5b
69ce0a2
682736c
020e374
bab39b4
9ca3ce7
@@ -52,7 +52,7 @@
 ###############################################################################
 def get_filenames(is_training, data_dir):
   """Returns a list of filenames."""
-  assert tf.gfile.Exists(data_dir), (
+  assert tf.io.gfile.exists(data_dir), (
       'Run cifar10_download_and_extract.py first to download and extract the '
       'CIFAR-10 data.')
 
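Most hunks in this PR are pure symbol renames. As a rough illustration of what such a rename pass does, here is a toy text-substitution sketch. The `RENAMES` table holds only the renames visible in this diff and is an assumption for illustration, not tf_upgrade_v2's real table; `upgrade_line` is a hypothetical helper.

```python
import re

# Renames copied from this PR's diff. Illustrative only: this is NOT the
# rename table that tf_upgrade_v2 actually ships with.
RENAMES = {
    "tf.gfile.Exists": "tf.io.gfile.exists",
    "tf.decode_raw": "tf.io.decode_raw",
    "tf.random_crop": "tf.image.random_crop",
    "tf.sparse_to_dense": "tf.compat.v1.sparse_to_dense",
    "tf.enable_eager_execution": "tf.compat.v1.enable_eager_execution",
    "tf.logging": "tf.compat.v1.logging",
}

# Try longer names first, and only match a name that is not embedded in a
# longer identifier or dotted path (no word char or "." before it, no word
# char after it).
_PATTERN = re.compile(
    r"(?<![\w.])(?:"
    + "|".join(re.escape(name) for name in sorted(RENAMES, key=len, reverse=True))
    + r")(?!\w)"
)


def upgrade_line(line: str) -> str:
    """Rewrite every known TF 1.x symbol in `line` to its replacement."""
    return _PATTERN.sub(lambda m: RENAMES[m.group(0)], line)


print(upgrade_line("tf.logging.set_verbosity(tf.logging.INFO)"))
# -> tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
```

Because every replacement target no longer matches any key, running the pass twice is harmless. The real script also rewrites call arguments (see the `tf.transpose` change), which plain text substitution cannot do.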
@@ -68,7 +68,7 @@ def get_filenames(is_training, data_dir):
 def parse_record(raw_record, is_training, dtype):
   """Parse CIFAR-10 image and label from a raw record."""
   # Convert bytes to a vector of uint8 that is record_bytes long.
-  record_vector = tf.decode_raw(raw_record, tf.uint8)
+  record_vector = tf.io.decode_raw(raw_record, tf.uint8)
 
   # The first byte represents the label, which we convert from uint8 to int32
   # and then to one-hot.
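`tf.io.decode_raw(raw_record, tf.uint8)` turns each byte of the record into one uint8 element. A plain-Python sketch of that step plus the label/image split, assuming the standard CIFAR-10 binary layout (one label byte, then a 32x32x3 image stored depth-major); `split_record` is a hypothetical helper, and the constant values are the usual CIFAR-10 sizes rather than values read from the file.

```python
# CIFAR-10 binary records: 1 label byte, then a 32x32x3 image stored
# depth-major ([depth, height, width]). Constant names mirror the file.
HEIGHT, WIDTH, NUM_CHANNELS = 32, 32, 3
IMAGE_BYTES = HEIGHT * WIDTH * NUM_CHANNELS  # 3072
RECORD_BYTES = 1 + IMAGE_BYTES  # 3073


def split_record(raw_record: bytes):
    """Return (label, depth_major_pixels) for one raw CIFAR-10 record.

    list(bytes) yields one int in [0, 255] per byte, the plain-Python
    counterpart of tf.io.decode_raw(raw_record, tf.uint8).
    """
    if len(raw_record) != RECORD_BYTES:
        raise ValueError(f"expected {RECORD_BYTES} bytes, got {len(raw_record)}")
    record_vector = list(raw_record)
    return record_vector[0], record_vector[1:]
```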
@@ -81,7 +81,7 @@ def parse_record(raw_record, is_training, dtype):
 
   # Convert from [depth, height, width] to [height, width, depth], and cast as
   # float32.
-  image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
+  image = tf.cast(tf.transpose(a=depth_major, perm=[1, 2, 0]), tf.float32)
 
   image = preprocess_image(image, is_training)
   image = tf.cast(image, dtype)
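The upgraded `tf.transpose` call only names its arguments; the behavior is unchanged. With `perm=[1, 2, 0]`, output axis `i` takes input axis `perm[i]`, so `[depth, height, width]` becomes `[height, width, depth]`. A nested-list sketch of those semantics (`transpose3d` is a hypothetical helper, not a TF API):

```python
def transpose3d(tensor, perm):
    """Transpose a 3-D nested list so that output axis i is input axis perm[i],
    mirroring the semantics of tf.transpose(a=tensor, perm=perm)."""
    shape = (len(tensor), len(tensor[0]), len(tensor[0][0]))
    out_shape = [shape[p] for p in perm]
    out = [[[None] * out_shape[2] for _ in range(out_shape[1])]
           for _ in range(out_shape[0])]
    for i in range(out_shape[0]):
        for j in range(out_shape[1]):
            for k in range(out_shape[2]):
                # Place each output index at the input axis it came from.
                idx = [0, 0, 0]
                idx[perm[0]], idx[perm[1]], idx[perm[2]] = i, j, k
                out[i][j][k] = tensor[idx[0]][idx[1]][idx[2]]
    return out


# depth=2, height=2, width=3; each value encodes its (d, h, w) position.
depth_major = [[[100 * d + 10 * h + w for w in range(3)] for h in range(2)]
               for d in range(2)]
image = transpose3d(depth_major, [1, 2, 0])  # now indexed as image[h][w][d]
```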
@@ -97,7 +97,7 @@ def preprocess_image(image, is_training):
         image, HEIGHT + 8, WIDTH + 8)
 
     # Randomly crop a [HEIGHT, WIDTH] section of the image.
-    image = tf.random_crop(image, [HEIGHT, WIDTH, NUM_CHANNELS])
+    image = tf.image.random_crop(image, [HEIGHT, WIDTH, NUM_CHANNELS])
 
     # Randomly flip the image horizontally.
     image = tf.image.random_flip_left_right(image)
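The augmentation here pads the image to `HEIGHT + 8` by `WIDTH + 8` and then takes a random `HEIGHT` by `WIDTH` crop. For a 32x32 CIFAR image this means zero-padding 4 pixels per side and picking a random 32x32 window from the 40x40 result. A 2-D plain-Python sketch of that pad-then-crop step, ignoring channels (`pad2d` and `random_crop2d` are hypothetical helpers, not TF APIs):

```python
import random


def pad2d(image, pad, fill=0):
    """Pad a 2-D nested list with `fill` by `pad` pixels on every side."""
    width = len(image[0]) + 2 * pad
    rows = [[fill] * width for _ in range(pad)]
    rows += [[fill] * pad + list(row) + [fill] * pad for row in image]
    rows += [[fill] * width for _ in range(pad)]
    return rows


def random_crop2d(image, height, width, rng):
    """Return a uniformly random height x width window of a 2-D nested list."""
    max_top = len(image) - height
    max_left = len(image[0]) - width
    if max_top < 0 or max_left < 0:
        raise ValueError("crop is larger than the image")
    top = rng.randint(0, max_top)    # randint is inclusive on both ends
    left = rng.randint(0, max_left)
    return [row[left:left + width] for row in image[top:top + height]]


rng = random.Random(0)
img = [[r * 32 + c for c in range(32)] for r in range(32)]
crop = random_crop2d(pad2d(img, 4), 32, 32, rng)
```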
@@ -253,8 +253,9 @@ def run_cifar(flags_obj):
     Dictionary of results. Including final accuracy.
   """
   if flags_obj.image_bytes_as_serving_input:
-    tf.logging.fatal('--image_bytes_as_serving_input cannot be set to True '
-                     'for CIFAR. This flag is only applicable to ImageNet.')
+    tf.compat.v1.logging.fatal(
+        '--image_bytes_as_serving_input cannot be set to True for CIFAR. '
+        'This flag is only applicable to ImageNet.')
     return
 
   input_function = (flags_obj.use_synthetic_data and
@@ -273,6 +274,6 @@ def main(_):
 
 
 if __name__ == '__main__':
-  tf.logging.set_verbosity(tf.logging.INFO)
+  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
   define_cifar_flags()
   absl_app.run(main)

Comment: I was checking the docs, and I think we are supposed to use absl for logging in 2.0? It looks like the upgrade script doesn't do this automatically. Can we do this in a subsequent PR, perhaps?
Reply: Yes, this is on our list. I was shocked that no logging code carries forward to TF 2.0 without a rewrite.
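`tf.compat.v1.logging` is a stopgap, as the comment above notes. One way to set verbosity without the compat module is through the standard library. This sketch assumes that recent TF releases log through `logging.getLogger("tensorflow")`, the same object `tf.get_logger()` returns. It does not import TensorFlow, and it does not show the absl route the reviewer mentions. `set_tf_verbosity` is a hypothetical helper.

```python
import logging


def set_tf_verbosity(level: int = logging.INFO) -> logging.Logger:
    """Set the level of the stdlib logger that TensorFlow is assumed to use.

    Assumption: recent TF releases log through logging.getLogger("tensorflow"),
    the object tf.get_logger() returns. TensorFlow is not imported here.
    """
    logger = logging.getLogger("tensorflow")
    logger.setLevel(level)
    return logger
```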
@@ -81,7 +81,7 @@ def parse_record_keras(raw_record, is_training, dtype):
     Tuple with processed image tensor and one-hot-encoded label tensor.
   """
   image, label = cifar_main.parse_record(raw_record, is_training, dtype)
-  label = tf.sparse_to_dense(label, (cifar_main.NUM_CLASSES,), 1)
+  label = tf.compat.v1.sparse_to_dense(label, (cifar_main.NUM_CLASSES,), 1)
   return image, label
 
 

Comment: Can you file a ticket to fix this? We should not need to do this anyway, and it might be better to figure out the right way to do it.
Reply: We are going to go back and file tickets for everything that uses compat.v1 in the Keras code path.
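With a scalar label, `sparse_to_dense(label, (NUM_CLASSES,), 1)` fills a length-`NUM_CLASSES` vector with the default value 0 and writes 1 at index `label`. In other words, it builds a one-hot vector. The plain-Python models below (hypothetical helpers handling only a scalar value, not the TF functions) show the two agree. That agreement suggests `tf.one_hot(label, NUM_CLASSES)` as a possible non-compat replacement, though `tf.one_hot` returns float32 by default, so the dtype would need checking.

```python
NUM_CLASSES = 10  # CIFAR-10


def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0):
    """1-D model of tf.compat.v1.sparse_to_dense with a scalar sparse_values:
    start from default_value everywhere, then write sparse_values at each index."""
    (size,) = output_shape
    if isinstance(sparse_indices, int):
        sparse_indices = [sparse_indices]
    dense = [default_value] * size
    for i in sparse_indices:
        dense[i] = sparse_values
    return dense


def one_hot(index, depth, on_value=1, off_value=0):
    """Model of a one-hot encoding of a single index (cf. tf.one_hot)."""
    return [on_value if i == index else off_value for i in range(depth)]
```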
@@ -98,7 +98,7 @@ def run(flags_obj):
     Dictionary of training and eval stats.
   """
   if flags_obj.enable_eager:
-    tf.enable_eager_execution()
+    tf.compat.v1.enable_eager_execution()
 
   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'fp16':

Comment: I think we should maybe also add an else clause here for the opposite case, so it can work in 2.0 too? WDYT?
Reply: We will make a 2.0 branch, because some things in 2.0 cannot be done in 1.x, such as saving a model in the 2.0 format; the call just does not exist.
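The reviewer's else-clause suggestion can be sketched as a function that always calls exactly one of two mode switches. The switches are passed in so the logic runs without TensorFlow. In the real script they would presumably be `tf.compat.v1.enable_eager_execution` and `tf.compat.v1.disable_eager_execution`, but that wiring is an assumption. `configure_execution_mode` is a hypothetical helper.

```python
def configure_execution_mode(enable_eager, enable_fn, disable_fn):
    """Call exactly one of the two mode switches and report the chosen mode.

    enable_fn/disable_fn are injected so this sketch runs without TensorFlow.
    """
    if enable_eager:
        enable_fn()
        return "eager"
    disable_fn()
    return "graph"
```

Both branches matter because TF 1.x defaults to graph mode and TF 2.0 defaults to eager execution, so each call is a no-op on one of the two versions.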
@@ -194,7 +194,7 @@ def main(_):
 
 
 if __name__ == '__main__':
-  tf.logging.set_verbosity(tf.logging.INFO)
+  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
   cifar_main.define_cifar_flags()
   keras_common.define_keras_flags()
   absl_app.run(main)
@@ -88,7 +88,7 @@ def run(flags_obj):
     ValueError: If fp16 is passed as it is not currently supported.
   """
   if flags_obj.enable_eager:
-    tf.enable_eager_execution()
+    tf.compat.v1.enable_eager_execution()
 
   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'fp16':

Comment: Same here re: disabling eager execution in the else case.
@@ -187,7 +187,7 @@ def main(_):
 
 
 if __name__ == '__main__':
-  tf.logging.set_verbosity(tf.logging.INFO)
+  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
   imagenet_main.define_imagenet_flags()
   keras_common.define_keras_flags()
   absl_app.run(main)

Comment: Can't we use absl's app.run here (similar to cifar10_main.py)?
Reply: We can ask and then see if they can fix it. We can also change it manually.