agents icon indicating copy to clipboard operation
agents copied to clipboard

CategoricalDqnAgent Distributional Training Loss Function error

Open jacklu333333 opened this issue 3 years ago • 0 comments

Hi, I am using CategoricalDqnAgent with multiple GPUs. This is my code block:

with strategy.scope():

    collect_op_episode.run()

    experience, buffer_info = tf_agents.utils.eager_utils.get_next(iterator)
    train_loss = agent.train(experience)

And I got the following error:

RuntimeError                              Traceback (most recent call last)
File <timed exec>:30

File ~/.local/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

File ~/.local/lib/python3.10/site-packages/tf_agents/agents/tf_agent.py:330, in TFAgent.train(self, experience, weights, **kwargs)
    325   raise RuntimeError(
    326       "Cannot find _train_fn.  Did %s.__init__ call super?"
    327       % type(self).__name__)
    329 if self._enable_functions:
--> 330   loss_info = self._train_fn(
    331       experience=experience, weights=weights, **kwargs)
    332 else:
    333   loss_info = self._train(experience=experience, weights=weights, **kwargs)

File ~/.local/lib/python3.10/site-packages/tf_agents/utils/common.py:188, in function_in_tf1.<locals>.maybe_wrap.<locals>.with_check_resource_vars(*fn_args, **fn_kwargs)
    184 check_tf1_allowed()
    185 if has_eager_been_enabled():
    186   # We're either in eager mode or in tf.function mode (no in-between); so
    187   # autodep-like behavior is already expected of fn.
--> 188   return fn(*fn_args, **fn_kwargs)
    189 if not resource_variables_enabled():
    190   raise RuntimeError(MISSING_RESOURCE_VARIABLES_ERROR)

File ~/.local/lib/python3.10/site-packages/tf_agents/agents/dqn/dqn_agent.py:393, in DqnAgent._train(self, experience, weights)
    391 def _train(self, experience, weights):
    392   with tf.GradientTape() as tape:
--> 393     loss_info = self._loss(
    394         experience,
    395         td_errors_loss_fn=self._td_errors_loss_fn,
    396         gamma=self._gamma,
    397         reward_scale_factor=self._reward_scale_factor,
    398         weights=weights,
    399         training=True)
    400   tf.debugging.check_numerics(loss_info.loss, 'Loss is inf or nan')
    401   variables_to_train = self._q_network.trainable_weights

File ~/.local/lib/python3.10/site-packages/tf_agents/agents/categorical_dqn/categorical_dqn_agent.py:436, in CategoricalDqnAgent._loss(self, experience, td_errors_loss_fn, gamma, reward_scale_factor, weights, training)
    431 else:
    432   critic_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
    433       labels=target_distribution,
    434       logits=chosen_action_logits)
--> 436 agg_loss = common.aggregate_losses(
    437     per_example_loss=critic_loss,
    438     regularization_loss=self._q_network.losses)
    439 total_loss = agg_loss.total_loss
    441 dict_losses = {'critic_loss': agg_loss.weighted,
    442                'reg_loss': agg_loss.regularization,
    443                'total_loss': total_loss}

File ~/.local/lib/python3.10/site-packages/tf_agents/utils/common.py:1376, in aggregate_losses(per_example_loss, sample_weight, global_batch_size, regularization_loss)
   1372     per_example_loss = tf.reduce_mean(per_example_loss, range(1, loss_rank))
   1374   global_batch_size = global_batch_size and tf.cast(global_batch_size,
   1375                                                     per_example_loss.dtype)
-> 1376   weighted_loss = tf.nn.compute_average_loss(
   1377       per_example_loss,
   1378       global_batch_size=global_batch_size)
   1379   total_loss = weighted_loss
   1380 # Add scaled regularization losses.

RuntimeError: You are calling `compute_average_loss` in cross replica context, while it was expected to be called in replica context.

It seems that the loss computed by the CategoricalDqnAgent doesn't support multi-GPU training under a tf.distribute strategy?

Best Regards, Jack Lu

jacklu333333 avatar Jan 07 '23 15:01 jacklu333333