probability
probability copied to clipboard
save_weights causes ValueError when DistributionLambda is used with JointDistributionSequential
Hello!
I am currently trying to use JointDistributionSequential to predict multiple distributions using a Mixture Density Network.
Minimal example:
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
from tensorflow_probability import bijectors as tfb
from functools import partial
# Network size and mixture layout.
neurons = 32
components = 2  # mixture components per distribution
no_dists = 20  # number of mixture distributions predicted jointly

# Factory for hidden layers: ELU activation, He-normal init, L2 weight penalty.
HiddenLayer = partial(
    tf.keras.layers.Dense,
    activation="elu",
    kernel_initializer="he_normal",
    kernel_regularizer=tf.keras.regularizers.l2(0.01),
)
# Factory for the parameter heads; keyword arguments given at call time
# (e.g. `activation`) override the defaults bound here.
OutputLayer = partial(
    tf.keras.layers.Dense,
    activation="linear",
    #kernel_regularizer=tf.keras.regularizers.l2(0.001)
)

inputs = tf.keras.layers.Input(shape=(1,))
h1 = HiddenLayer(neurons)(inputs)
# Integer division: the number of units must be an int, `neurons/2` is a float.
h2 = HiddenLayer(neurons // 2)(h1)

# Each head emits no_dists * components values, reshaped to one row of
# `components` parameters per distribution.
logits = OutputLayer(no_dists * components, name="logits")(h2)
logits_rshpd = tf.keras.layers.Reshape((no_dists, components))(logits)
locs = OutputLayer(no_dists * components, name="locs")(h2)
locs_rshpd = tf.keras.layers.Reshape((no_dists, components))(locs)
# softplus keeps the outputs positive so they can be used directly as scales
# (despite the layer being named "log_scales").
scales = OutputLayer(no_dists * components, activation='softplus', name="log_scales")(h2)
scales_rshpd = tf.keras.layers.Reshape((no_dists, components))(scales)
def joint(pvector):
    """Build one blockwise distribution out of ``no_dists`` Gaussian mixtures.

    Args:
        pvector: A ``(logits, locs, scales)`` triple. Each entry is indexed
            as ``[:, d]`` to pick the parameters of the d-th mixture.

    Returns:
        A ``tfd.Blockwise`` wrapping a ``tfd.JointDistributionSequential``
        over the individual ``tfd.MixtureSameFamily`` distributions.
    """
    logits, locs, scales = pvector
    per_dist_mixtures = [
        tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=logits[:, d]),
            components_distribution=tfd.Normal(
                loc=locs[:, d],
                scale=scales[:, d],
            ),
        )
        for d in range(no_dists)
    ]
    return tfd.Blockwise(tfd.JointDistributionSequential(per_dist_mixtures))
# Turn the three reshaped parameter tensors into a distribution-valued output
# by passing them, as a single tuple, through `joint`.
out_joint = tfp.layers.DistributionLambda(joint)(
    (logits_rshpd, locs_rshpd, scales_rshpd)
)

# Assemble the mixture density network and print its layer summary.
gmm_model = tf.keras.Model(inputs, out_joint, name="mdn")
gmm_model.summary()
Training with the TensorFlow Keras API works as expected, but when I use keras.callbacks.ModelCheckpoint or gmm_model.save_weights('test') to save the weights I get the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-3-6285a5a969b4> in <module>
----> 1 gmm_model.save_weights('test')
~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/keras/engine/network.py in save_weights(self, filepath, overwrite, save_format)
1165 'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
1166 % (optimizer,))
-> 1167 self._trackable_saver.save(filepath, session=session)
1168 # Record this checkpoint so it's visible from tf.train.latest_checkpoint.
1169 checkpoint_management.update_checkpoint_state_internal(
~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/util.py in save(self, file_prefix, checkpoint_number, session)
1185 file_io.recursive_create_dir(os.path.dirname(file_prefix))
1186 save_path, new_feed_additions = self._save_cached_when_graph_building(
-> 1187 file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
1188 if new_feed_additions:
1189 feed_dict.update(new_feed_additions)
~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/util.py in _save_cached_when_graph_building(self, file_prefix, object_graph_tensor)
1125 (named_saveable_objects, graph_proto,
1126 feed_additions) = self._gather_saveables(
-> 1127 object_graph_tensor=object_graph_tensor)
1128 if (self._last_save_object_graph != graph_proto
1129 # When executing eagerly, we need to re-create SaveableObjects each time
~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/util.py in _gather_saveables(self, object_graph_tensor)
1093 """Wraps _serialize_object_graph to include the object graph proto."""
1094 (named_saveable_objects, graph_proto,
-> 1095 feed_additions) = self._graph_view.serialize_object_graph()
1096 if object_graph_tensor is None:
1097 with ops.device("/cpu:0"):
~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/graph_view.py in serialize_object_graph(self)
377 ValueError: If there are invalid characters in an optimizer's slot names.
378 """
--> 379 trackable_objects, path_to_root = self._breadth_first_traversal()
380 return self._serialize_gathered_objects(
381 trackable_objects, path_to_root)
~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/graph_view.py in _breadth_first_traversal(self)
197 % (current_trackable,))
198 bfs_sorted.append(current_trackable)
--> 199 for name, dependency in self.list_dependencies(current_trackable):
200 if dependency not in path_to_root:
201 path_to_root[dependency] = (
~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/graph_view.py in list_dependencies(self, obj)
157 # pylint: disable=protected-access
158 obj._maybe_initialize_trackable()
--> 159 return obj._checkpoint_dependencies
160 # pylint: enable=protected-access
161
~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/data_structures.py in __getattribute__(self, name)
740 # in particular seems to look up properties on the wrapped object instead
741 # of the wrapper without this logic.
--> 742 return object.__getattribute__(self, name)
743 else:
744 return super(_DictWrapper, self).__getattribute__(name)
~/miniconda3/envs/tfgpu/lib/python3.7/site-packages/tensorflow/python/training/tracking/data_structures.py in _checkpoint_dependencies(self)
781 "mutable data structure.\n\nIf you don't need this dictionary "
782 "checkpointed, wrap it in a non-trackable "
--> 783 "object; it will be subsequently ignored." % (self,))
784 if self._self_external_modification:
785 raise ValueError(
ValueError: Unable to save the object {140228692728528: ListWrapper([<tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_1' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_2' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_3' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_4' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_5' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_6' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_7' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_8' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_9' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_10' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_11' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_12' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_13' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_14' batch_shape=[?] 
event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_15' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_16' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_17' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_18' batch_shape=[?] event_shape=[] dtype=float32>, <tfp.distributions.MixtureSameFamily 'distribution_lambda_MixtureSameFamily_19' batch_shape=[?] event_shape=[] dtype=float32>])} (a dictionary wrapper constructed automatically on attribute assignment). The wrapped dictionary contains a non-string key which maps to a trackable object or mutable data structure.
If you don't need this dictionary checkpointed, wrap it in a non-trackable object; it will be subsequently ignored.
What am I doing wrong here?
Thank you very much for your help!
Did you solve it? 🤔
Did you solve it? thinking
No. I decided to implement it as a loss function instead.
Custom Loss Function:
import tensorflow as tf
from tensorflow.keras.losses import Loss
from thesis.distributions import MixedNormal
from thesis.distributions import MixedLogNormal
class MixtedDensityLoss(Loss):
    """Negative log-likelihood of the targets under a mixture density.

    The density is built from the model output by a ``MixedNormal`` or, when
    ``log_normal`` is true, a ``MixedLogNormal`` instance.
    """

    def __init__(
            self,
            log_normal=False,
            **kwargs):
        """Select the density builder and forward ``kwargs`` to ``Loss``.

        Args:
            log_normal: Use ``MixedLogNormal`` instead of ``MixedNormal``.
            **kwargs: Passed unchanged to the ``Loss`` constructor.
        """
        self.mixed_density = MixedLogNormal() if log_normal else MixedNormal()
        super().__init__(**kwargs)

    def call(self, y, pvector):
        """Return ``-log_prob(y)`` under the distribution built from ``pvector``.

        Args:
            y: Target values; size-1 dimensions are removed with ``tf.squeeze``
                before evaluating the log-probability.
            pvector: Model output handed to the density builder.
        """
        density = self.mixed_density(pvector)
        return -density.log_prob(tf.squeeze(y))
Gaussian Mixture Model:
class MixedNormal():
    """Callable that turns a parameter tensor into a joint Gaussian mixture.

    Calling an instance with ``pvector`` is equivalent to
    ``gen_mixture(pvector)``.
    """

    def __call__(self, pvector):
        """Build and return the blockwise mixture distribution for ``pvector``."""
        return self.gen_mixture(pvector)

    def slice_parameter_vectors(self, pvector):
        """Return an unpacked list of parameter vectors.

        Args:
            pvector: Array-like indexed as ``pvector[:, d, p]``, where ``d``
                runs over ``pvector.shape[1]`` distributions and ``p`` over
                the three parameter kinds (logits, locs, log-scales).

        Returns:
            A list with one entry per distribution; each entry is a list of
            the three slices ``pvector[:, d, 0..2]``.
        """
        num_dist = pvector.shape[1]
        return [
            [pvector[:, d, p] for p in range(3)]
            for d in range(num_dist)
        ]

    def gen_mixture(self, out):
        """Build a blockwise joint distribution of Gaussian mixtures.

        Args:
            out: Parameter tensor accepted by ``slice_parameter_vectors``.

        Returns:
            A ``tfd.Blockwise`` over a ``tfd.JointDistributionSequential``
            named ``'joint_mixtures'`` holding one ``tfd.MixtureSameFamily``
            per distribution.
        """
        mixtures = []
        for logits, locs, log_scales in self.slice_parameter_vectors(out):
            # softplus maps each raw value to a positive scale independently.
            # softmax would instead normalise the scales to sum to one along
            # the last axis, coupling them to each other.
            scales = tf.math.softplus(log_scales)
            mixtures.append(
                tfd.MixtureSameFamily(
                    mixture_distribution=tfd.Categorical(logits=logits),
                    components_distribution=tfd.Normal(
                        loc=locs,
                        scale=scales))
            )
        joint = tfd.JointDistributionSequential(
            mixtures, name='joint_mixtures')
        return tfd.Blockwise(joint)
I'm having the same issue. Minimal example here: #1681