
'AttentionWrapper' object has no attribute 'zero_state'

Open · Neel125 opened this issue 6 years ago • 4 comments

def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                        source_sequence_length):
	"""Build a RNN cell with attention mechanism that can be used by decoder."""
	# No Attention
	if not self.has_attention:
		return super(AttentionModel, self)._build_decoder_cell(
			hparams, encoder_outputs, encoder_state, source_sequence_length)
	elif hparams["attention_architecture"] != "standard":
		raise ValueError(
			"Unknown attention architecture %s" % hparams["attention_architecture"])

	num_units = hparams["num_units"]
	num_layers = self.num_decoder_layers
	num_residual_layers = self.num_decoder_residual_layers
	infer_mode = hparams["infer_mode"]

	dtype = tf.float32

	# Ensure memory is batch-major
	if self.time_major:
		memory = tf.transpose(encoder_outputs, [1, 0, 2])
	else:
		memory = encoder_outputs

	if (self.mode == tf.estimator.ModeKeys.PREDICT and
			infer_mode == "beam_search"):
		memory, source_sequence_length, encoder_state, batch_size = (
			self._prepare_beam_search_decoder_inputs(
				hparams["beam_width"], memory, source_sequence_length,
				encoder_state))
	else:
		batch_size = self.batch_size

	# Attention
	attention_mechanism = self.attention_mechanism_fn(
		hparams["attention"], num_units, memory, source_sequence_length, self.mode)

	cell = model_helper.create_rnn_cell(
		unit_type=hparams["unit_type"],
		num_units=num_units,
		num_layers=num_layers,
		num_residual_layers=num_residual_layers,
		forget_bias=hparams["forget_bias"],
		dropout=hparams["dropout"],
		num_gpus=self.num_gpus,
		mode=self.mode,
		single_cell_fn=self.single_cell_fn)

	# Only generate alignment in greedy INFER mode.
	alignment_history = (self.mode == tf.estimator.ModeKeys.PREDICT and
	                     infer_mode != "beam_search")
	cell = tfa.seq2seq.AttentionWrapper(
		cell,
		attention_mechanism,
		attention_layer_size=num_units,
		alignment_history=alignment_history,
		output_attention=hparams["output_attention"],
		name="attention")

	# TODO(thangluong): do we need num_layers, num_gpus?
	device = tf.device(model_helper.get_device_str(num_layers-1, self.num_gpus))

	cell = tf.nn.rnn_cell.DeviceWrapper(cell,
	                                    device)
	cell = tf.nn.rnn_cell.DropoutWrapper(cell, input_keep_prob=0.8)
	if hparams["pass_hidden_state"]:
		decoder_initial_state = cell.zero_state(batch_size=batch_size*hparams["beam_width"], dtype=dtype).clone(
			cell_state=encoder_state)
	else:
		decoder_initial_state = cell.zero_state(batch_size=batch_size*hparams["beam_width"], dtype=dtype)

	return cell, decoder_initial_state

Error:
  File "/home/ml-ai4/Neel-dev023/ChatBot/nmt-chatbot/nmt/nmt/attention_model.py", line 144, in _build_decoder_cell
    decoder_initial_state = cell.zero_state(batch_size=batch_size*hparams["beam_width"], dtype=dtype).clone(
  File "/home/ml-ai4/Neel-dev023/ChatBot/nmt-chatbot/venv/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_wrapper_impl.py", line 199, in zero_state
    return self.cell.zero_state(batch_size, dtype)
  File "/home/ml-ai4/Neel-dev023/ChatBot/nmt-chatbot/venv/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_wrapper_impl.py", line 431, in zero_state
    return self.cell.zero_state(batch_size, dtype)
AttributeError: 'AttentionWrapper' object has no attribute 'zero_state'

Neel125 · Jan 08 '20
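Editor's note: the error is not specific to the nmt code. tfa.seq2seq.AttentionWrapper is a Keras-style cell that exposes get_initial_state rather than the TF 1.x zero_state, and the tf.nn.rnn_cell wrappers in the traceback delegate zero_state down to it. A minimal, hypothetical repro (the LuongAttention mechanism and the sizes are arbitrary choices; assumes TF 2.x with tensorflow_addons installed):

import tensorflow as tf
import tensorflow_addons as tfa

# Batch-major encoder outputs to attend over: [batch, time, units]
memory = tf.random.normal([4, 7, 16])
attention_mechanism = tfa.seq2seq.LuongAttention(units=16, memory=memory)
cell = tfa.seq2seq.AttentionWrapper(
    tf.keras.layers.LSTMCell(16), attention_mechanism, attention_layer_size=16)

# Works: returns an AttentionWrapperState of zeroed tensors.
state = cell.get_initial_state(batch_size=4, dtype=tf.float32)

# Raises AttributeError: 'AttentionWrapper' object has no attribute 'zero_state'
cell.zero_state(batch_size=4, dtype=tf.float32)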

Facing the same issue while trying to use tensorflow_addons with TF 2.x. (screenshot attached)

Abonia1 · May 15 '20

Facing the same issue with TensorFlow 2.x. (screenshot attached)

princebaretto99 · May 16 '20

Got the solution: just replace zero_state with get_initial_state, because get_initial_state returns an AttentionWrapperState tuple containing zeroed-out tensors, the same as zero_state did (see the sketch after this comment).

princebaretto99 · May 16 '20
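Editor's note: a minimal sketch of the suggested replacement, reusing the variable names and hparams from the snippet in the issue (this is not the exact code from the screenshots, and whether the TF 1.x DeviceWrapper/DropoutWrapper added in the issue's code also forward get_initial_state is not verified here; the call is shown on the tfa.seq2seq.AttentionWrapper itself):

# cell is the tfa.seq2seq.AttentionWrapper built above.
# get_initial_state returns an AttentionWrapperState of zeroed tensors,
# equivalent to what zero_state produced in TF 1.x.
# (batch_size * hparams["beam_width"] is kept exactly as in the issue's code)
initial_state = cell.get_initial_state(
    batch_size=batch_size * hparams["beam_width"], dtype=dtype)
if hparams["pass_hidden_state"]:
    # Replace the zeroed cell state with the encoder's final state.
    decoder_initial_state = initial_state.clone(cell_state=encoder_state)
else:
    decoder_initial_state = initial_state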

Hello @princebaretto99, I had already found this solution the same day I encountered the issue, but I'm really sorry I forgot to update it here on GitHub. The zero_state issue is resolved by using get_initial_state. Thank you for your solution. (screenshot attached)

Abonia1 · May 17 '20