sparsity.prune_low_magnitude fails with mixed precision policy mixed_float16
Describe the bug
When using tf.keras.mixed_precision.experimental.Policy("mixed_float16", loss_scale="dynamic"), sparsity.prune_low_magnitude fails during tensor type conversion with the error: Tensor conversion requested dtype float32 for Tensor with dtype float16: <tf.Tensor 'pruning_ops/Cast_2:0' shape=(3, 3, 1, 12) dtype=float16>. Everything works fine when the precision is left at the default float32. It looks like some piece of code is not respecting the layer's dtype.
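The failure mode can be sketched without TensorFlow at all: the pruning math runs in the policy's compute dtype (float16), but the result is assigned back to the float32 storage variable, whose strict dtype check then raises. A minimal NumPy-only illustration; `ToyVariable` is a hypothetical stand-in for a float32 variable behind a mixed_float16 policy, not a TensorFlow API:

```python
import numpy as np

class ToyVariable:
    """Toy stand-in for a float32 variable under a mixed_float16 policy."""
    def __init__(self, value):
        self.value = np.asarray(value, dtype=np.float32)

    def assign(self, new_value):
        new_value = np.asarray(new_value)
        # Mirrors ops.convert_to_tensor(value, dtype=self.dtype): the incoming
        # tensor must already match the variable's storage dtype.
        if new_value.dtype != self.value.dtype:
            raise ValueError(
                "Tensor conversion requested dtype %s for Tensor with dtype %s"
                % (self.value.dtype, new_value.dtype))
        self.value = new_value

weights = ToyVariable([1.0, -0.01, 0.5])
mask = np.ones(3, dtype=np.float16)
# Under mixed_float16 the masking runs in the compute dtype (float16),
# so the masked weights come out float16:
masked = weights.value.astype(np.float16) * mask
try:
    weights.assign(masked)                 # raises, like the pruning wrapper
except ValueError as e:
    print("ValueError:", e)
weights.assign(masked.astype(np.float32))  # an explicit cast back would fix it
```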
System information
TensorFlow installed from (source or binary): pip3
TensorFlow version: 2.2.0
TensorFlow Model Optimization version:
Python version: 3.6
Describe the expected behavior
prune_low_magnitude should work with layers that use the mixed_float16 policy.
Describe the current behavior
Throws the error described above.
Code to reproduce the issue: see this colab link.
@alanchiao any insights on when this will be fixed?
@alanchiao Is this a major change? From the stack trace it looked more like a minor bug where some piece of code is not properly respecting the dtype.
Hi @dd1923 , sorry for the really late response.
Just want to check whether this still bugs you before picking this up again.
Hi, I can confirm that the issue still exists:
ValueError: Tensor conversion requested dtype float32 for Tensor with dtype float16: <tf.Tensor 'prune_low_magnitude_conv1/Mul:0' shape=(7, 7, 3, 64) dtype=float16>
Complete stack trace:
Traceback (most recent call last):
File "official/vision/image_classification/classifier_trainer.py", line 530, in <module>
app.run(main)
File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 303, in run
_run_main(main, args)
File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 251, in _run_main
sys.exit(main(argv))
File "official/vision/image_classification/classifier_trainer.py", line 517, in main
stats = run(flags.FLAGS)
File "official/vision/image_classification/classifier_trainer.py", line 509, in run
return train_and_eval(params, strategy_override)
File "official/vision/image_classification/classifier_trainer.py", line 435, in train_and_eval
clone_function=apply_pruning,
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/models.py", line 431, in clone_model
model, input_tensors=input_tensors, layer_fn=clone_function)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/models.py", line 201, in _clone_functional_model
created_layers=created_layers))
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py", line 1285, in reconstruct_from_config
process_node(layer, node_data)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py", line 1233, in process_node
output_tensors = layer(input_tensors, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 952, in __call__
input_list)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 1091, in _functional_construction_call
inputs, input_masks, args, kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 822, in _keras_tensor_symbolic_call
return self._infer_output_signature(inputs, args, kwargs, input_masks)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 863, in _infer_output_signature
outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py", line 670, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
/usr/local/lib/python3.6/dist-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:258 call *
self.add_update(self.pruning_obj.weight_mask_op())
/usr/local/lib/python3.6/dist-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py:195 weight_mask_op *
return tf.group(self._weight_assign_objs())
/usr/local/lib/python3.6/dist-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py:168 update_var *
return tf_compat.assign(variable, reduced_value)
/usr/local/lib/python3.6/dist-packages/tensorflow_model_optimization/python/core/keras/compat.py:28 assign *
return ref.assign(value, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/mixed_precision/autocast_variable.py:237 assign **
name, read_value)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/mixed_precision/autocast_variable.py:209 _apply_assign_update
assign_op = update_fn(value, use_locking, name, False)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:882 assign
value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/profiler/trace.py:163 wrapped
return func(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:1509 convert_to_tensor
(dtype.name, value.dtype.name, value))
This is an issue for me as well, on TF 2.4.1 with tensorflow_model_optimization 0.5.
I've been trying to wrap my head around this and even tried setting various layer kwargs.
It seems the failure happens at the model-clone step, keras.models.clone_model():
ValueError Traceback (most recent call last)
<command--1> in <module>
12
13 with open(filename, "rb") as f:
---> 14 exec(f.read())
15
<string> in <module>
/databricks/python/lib/python3.8/site-packages/SCVS/cli.py in main()
8 def main():
9 args = parse_args()
---> 10 action_args(args)
11
12
/databricks/python/lib/python3.8/site-packages/SCVS/cli.py in action_args(args)
79 else:
80 # run function based on specified command
---> 81 args.func(args)
/databricks/python/lib/python3.8/site-packages/SCVS/cli.py in cmd_task(args)
61 # execute the task
62 from SCVS.app import main
---> 63 main(task_id, config_file, instructions)
64 # shutdown (need to flush logs before the python interpreter terminates)
65 logging.shutdown()
/databricks/python/lib/python3.8/site-packages/SCVS/app.py in main(task_id, config_file, instructions)
45 util.flush_logger()
46 time.sleep(10)
---> 47 raise e
48
49
/databricks/python/lib/python3.8/site-packages/SCVS/app.py in main(task_id, config_file, instructions)
38 logger.info(f"starting task {task_id} (class {task_cls.__name__}): {task.summary}")
39 try:
---> 40 task.start()
41 t1 = time.perf_counter()
42 logger.info(f"task {task_id} finished after {util.format_time_span(t1-t0)}")
/databricks/python/lib/python3.8/site-packages/SCVS/pipeline/task.py in start(self)
221 """
222 self.validate_input()
--> 223 self.run()
224 self.validate_output()
225
/databricks/python/lib/python3.8/site-packages/SCVS/pipeline/train_object_detection.py in run(self)
203 # record_summaries=True,
204 # )
--> 205 custom_train_loop(
206 pipeline_config_path=self.path_pipeline_config,
207 model_dir=self.path_model,
/databricks/python/lib/python3.8/site-packages/SCVS/ml/tf_training.py in custom_train_loop(pipeline_config_path, model_dir, config_override, train_steps, use_tpu, save_final_config, checkpoint_every_n, checkpoint_max_to_keep, record_summaries, mlflow_log_every_n, **kwargs)
293 #TODO(PdS): Investigate if we should prune other parts of the architecture: _box_predictor, etc
294
--> 295 detection_model._feature_extractor._efficientnet = tfmot.sparsity.keras.prune_low_magnitude(
296 to_prune=detection_model._feature_extractor._efficientnet,
297 pruning_schedule=pruning_schedule,
/databricks/python/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/prune.py in prune_low_magnitude(to_prune, pruning_schedule, block_size, block_pooling_type, **kwargs)
182 return _prune_list(to_prune, **params)
183 elif is_sequential_or_functional:
--> 184 return keras.models.clone_model(
185 to_prune, input_tensors=None, clone_function=_add_pruning_wrapper)
186 elif is_keras_layer:
/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/models.py in clone_model(model, input_tensors, clone_function)
428 model, input_tensors=input_tensors, layer_fn=clone_function)
429 else:
--> 430 return _clone_functional_model(
431 model, input_tensors=input_tensors, layer_fn=clone_function)
432
/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/models.py in _clone_functional_model(model, input_tensors, layer_fn)
198 # Reconstruct model from the config, using the cloned layers.
199 input_tensors, output_tensors, created_layers = (
--> 200 functional.reconstruct_from_config(model_configs,
201 created_layers=created_layers))
202 metrics_names = model.metrics_names
/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
1283 if layer in unprocessed_nodes:
1284 for node_data in unprocessed_nodes.pop(layer):
-> 1285 process_node(layer, node_data)
1286
1287 input_tensors = []
/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py in process_node(layer, node_data)
1231 input_tensors = (
1232 base_layer_utils.unnest_if_single_tensor(input_tensors))
-> 1233 output_tensors = layer(input_tensors, **kwargs)
1234
1235 # Update node index map.
/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
949 # >> model = tf.keras.Model(inputs, outputs)
950 if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
--> 951 return self._functional_construction_call(inputs, args, kwargs,
952 input_list)
953
/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _functional_construction_call(self, inputs, args, kwargs, input_list)
1088 layer=self, inputs=inputs, build_graph=True, training=training_value):
1089 # Check input assumptions set after layer building, e.g. input shape.
-> 1090 outputs = self._keras_tensor_symbolic_call(
1091 inputs, input_masks, args, kwargs)
1092
/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs)
820 return nest.map_structure(keras_tensor.KerasTensor, output_signature)
821 else:
--> 822 return self._infer_output_signature(inputs, args, kwargs, input_masks)
823
824 def _infer_output_signature(self, inputs, args, kwargs, input_masks):
/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _infer_output_signature(self, inputs, args, kwargs, input_masks)
861 # TODO(kaftan): do we maybe_build here, or have we already done it?
862 self._maybe_build(inputs)
--> 863 outputs = call_fn(inputs, *args, **kwargs)
864
865 self._handle_activity_regularization(inputs, outputs)
/databricks/python/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
668 except Exception as e: # pylint:disable=broad-except
669 if hasattr(e, 'ag_error_metadata'):
--> 670 raise e.ag_error_metadata.to_exception(e)
671 else:
672 raise
ValueError: in user code:
/databricks/python/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:258 call *
self.add_update(self.pruning_obj.weight_mask_op())
/databricks/python/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py:195 weight_mask_op *
return tf.group(self._weight_assign_objs())
/databricks/python/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py:190 _weight_assign_objs *
assign_objs.append(tf_compat.assign(weight, masked_weight))
/databricks/python/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/keras/compat.py:28 assign *
return ref.assign(value, name=name)
/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/mixed_precision/autocast_variable.py:236 assign **
return self._apply_assign_update(self._variable.assign, value, use_locking,
/databricks/python/lib/python3.8/site-packages/tensorflow/python/keras/mixed_precision/autocast_variable.py:209 _apply_assign_update
assign_op = update_fn(value, use_locking, name, False)
/databricks/python/lib/python3.8/site-packages/tensorflow/python/distribute/values.py:781 assign
return values_util.on_write_assign(
/databricks/python/lib/python3.8/site-packages/tensorflow/python/distribute/values_util.py:140 on_write_assign
return var._update( # pylint: disable=protected-access
/databricks/python/lib/python3.8/site-packages/tensorflow/python/distribute/values.py:940 _update
return self._update_cross_replica(update_fn, value, **kwargs)
/databricks/python/lib/python3.8/site-packages/tensorflow/python/distribute/values.py:893 _update_cross_replica
return self.distribute_strategy.extended.update(
/databricks/python/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2494 update
return self._update(var, fn, args, kwargs, group)
/databricks/python/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_strategy.py:710 _update
fn(v, *distribute_utils.select_replica_mirrored(i, args),
/databricks/python/lib/python3.8/site-packages/tensorflow/python/distribute/values_util.py:139 <lambda> **
assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
/databricks/python/lib/python3.8/site-packages/tensorflow/python/ops/resource_variable_ops.py:882 assign
value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
/databricks/python/lib/python3.8/site-packages/tensorflow/python/profiler/trace.py:163 wrapped
return func(*args, **kwargs)
/databricks/python/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:1507 convert_to_tensor
raise ValueError(
ValueError: Tensor conversion requested dtype float32 for Tensor with dtype float16: <tf.Tensor 'prune_low_magnitude_stem_conv2d/Mul:0' shape=(3, 3, 3, 48) dtype=float16>
I found that the layer returned by prune_low_magnitude ends up with a different dtype_policy from the original one. Here is a workaround: call _set_dtype_policy after prune_low_magnitude.
For example:
from tensorflow.keras import models, layers
import tensorflow_model_optimization as tfmo
d1 = layers.Dense(5, dtype='mixed_float16')
d2 = layers.Dense(5, dtype='mixed_float16')
print(d1.dtype_policy)
print(d2.dtype_policy)
d11 = tfmo.sparsity.keras.prune_low_magnitude(d1)
d22 = tfmo.sparsity.keras.prune_low_magnitude(d2)
print(d11.dtype_policy)
print(d22.dtype_policy)
d11._set_dtype_policy(d1.dtype_policy)
d22._set_dtype_policy(d2.dtype_policy)
print(d11.dtype_policy)
print(d22.dtype_policy)
# now it works fine
inp = layers.Input((10,))
tensor = d11(inp)
tensor = d22(tensor)
m = models.Model(inputs=inp, outputs=tensor)
m.compile(loss='mse')
I haven't tried the global policy, but it should behave the same way.