Refactor crf_decode_forward

Open WindQAQ opened this issue 5 years ago • 4 comments

Describe the feature and the current behavior/state.

Currently, tfa.text.crf_decode_forward is designed as

  1. Create an RNN cell and pass the loop-invariant transition_params into its constructor.
  2. Create an RNN layer and return the output of calling it.
  3. To make transition_params serializable, it is converted to a list in get_config. https://github.com/tensorflow/addons/blob/master/tensorflow_addons/text/crf.py#L452-L457

There are three weird points:

  1. transition_params is loop-invariant, so we can pass it as a constants call argument to the RNN cell and RNN layer. https://www.tensorflow.org/api_docs/python/tf/keras/layers/RNN?version=nightly
  2. Two classes are instantiated but not visible to users. I don't see any similar usage in core TF. Almost all tf.keras.layers.* have a functional counterpart, and so does tf.keras.layers.RNN.
  3. As far as I know, it's uncommon to save a tensor as a list in a tf.keras.layers.Layer config. In fact, transition_params is a tf.Variable in tfa.layers.CRF, which is saved automatically when users invoke model.save or something similar.

We can adopt either of the following designs:


Use tf.keras.backend.rnn

def _crf_decode_forward_step(inputs, states, constants):
    # Something like `CrfDecodeForwardRnnCell.call`
    pass


def crf_decode_forward(
    inputs: TensorLike,
    state: TensorLike,
    transition_params: TensorLike,
    sequence_lengths: TensorLike,
):
    (_, outputs, states) = tf.keras.backend.rnn(_crf_decode_forward_step, ..., constants=[transition_params])
    return outputs, states

Pros:

  • CrfDecodeForwardRnnCell and RNN layer are not instantiated.
  • No need to worry about serialization.

Cons:

  • Cannot benefit from tf.keras autocast, which might be helpful under a mixed precision policy. However, I am not aware of any function in tf.keras that supports mixed precision. As a function rather than a layer subclass, crf_decode_forward should choose its output data type according to the input data type rather than the global precision policy.
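
To make the role of the loop-invariant constant concrete, here is a minimal pure-Python sketch of the recurrence such a step function would compute, for a single sequence and without TensorFlow. The names decode_forward_step and run_steps are hypothetical. run_steps only imitates the general shape of a step-function driver; it is not the tf.keras.backend.rnn API, and the real implementation works on batched tensors.

```python
from typing import Callable, List, Sequence, Tuple

Scores = Sequence[float]


def decode_forward_step(
    inputs: Scores,
    state: Scores,
    transition_params: Sequence[Scores],
) -> Tuple[List[int], List[float]]:
    """One Viterbi forward step for a single sequence (no batch dimension)."""
    num_tags = len(inputs)
    backpointers: List[int] = []
    new_state: List[float] = []
    for j in range(num_tags):
        # Score of arriving at tag j from each previous tag i.
        scores = [state[i] + transition_params[i][j] for i in range(num_tags)]
        # Ties resolve to the smallest index because max() keeps the first maximum.
        best_prev = max(range(num_tags), key=scores.__getitem__)
        backpointers.append(best_prev)
        new_state.append(inputs[j] + scores[best_prev])
    return backpointers, new_state


def run_steps(
    step: Callable,
    inputs: Sequence[Scores],
    initial_state: Scores,
    constants: Sequence,
) -> Tuple[List[List[int]], List[float]]:
    """Hypothetical driver: the same constants are handed to every step."""
    state = list(initial_state)
    outputs = []
    for x_t in inputs:
        out, state = step(x_t, state, *constants)
        outputs.append(out)
    return outputs, state


# transition_params[i][j] is the score of moving from tag i to tag j.
transition_params = [[0.0, 1.0], [2.0, 0.0]]
backpointers, final_state = run_steps(
    decode_forward_step,
    inputs=[[1.0, 0.0], [0.0, 1.0]],
    initial_state=[1.0, 0.0],
    constants=[transition_params],
)
```

Because transition_params never changes between steps, it travels as a constant next to the per-step inputs and the carried state, so nothing has to be stored on a cell object or serialized.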

Pass transition_params as a constants call argument to the RNN cell and layer

class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell):
    @typechecked
    def __init__(self, num_tags: int, **kwargs):
        super().__init__(**kwargs)
        self._num_tags = num_tags

    def call(self, inputs, state, constants):
        transition_params = tf.expand_dims(constants[0], 0)
        # Similar to current implementation
        return backpointers, new_state


def crf_decode_forward(
    inputs: TensorLike,
    state: TensorLike,
    transition_params: TensorLike,
    sequence_lengths: TensorLike,
):
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
    mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1])
    crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params.shape[0], dtype=inputs.dtype)
    crf_fwd_layer = tf.keras.layers.RNN(
        crf_fwd_cell, return_sequences=True, return_state=True, dtype=inputs.dtype
    )
    return crf_fwd_layer(inputs, state, constants=transition_params, mask=mask)

Pros:

  • Can, in theory, benefit from tf.keras autocast.
  • No need to worry about serialization.

Cons:

  • So far, mixed precision autocast only applies to the first argument, so we cannot expect state and constants to be cast correctly. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/engine/base_layer.py#L936-L947
  • CrfDecodeForwardRnnCell and RNN layer are instantiated but not visible to users.
  • Have to check if state is a KerasTensor and decide if constants should be passed as a KerasTensor; otherwise this line will fail.
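
The mask built in crf_decode_forward above marks which timesteps are valid for each sequence. As a rough pure-Python illustration of what tf.sequence_mask(sequence_lengths, maxlen) produces, assuming plain integer lengths (the helper name sequence_mask_sketch is a stand-in, not TensorFlow code):

```python
from typing import List, Sequence


def sequence_mask_sketch(lengths: Sequence[int], maxlen: int) -> List[List[bool]]:
    """Return mask[b][t] == True when timestep t lies inside sequence b."""
    return [[t < length for t in range(maxlen)] for length in lengths]


# Three sequences of lengths 1, 3 and 2, padded to 3 timesteps.
mask = sequence_mask_sketch([1, 3, 2], maxlen=3)
```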

cc @tensorflow/sig-addons-maintainers for visibility.

WindQAQ, Dec 07 '20 03:12

Using tf.keras.backend.rnn seems better, since its pros make it the preferable option.

chazuttu, Jan 04 '21 05:01

Hi @WindQAQ, thank you for the detailed report. Since this issue was written about 10 months ago, is there any update on the pros and cons of both choices? Could you revise those statements to bring them up to date? I am happy to refactor crf_decode_forward once I have enough up-to-date information.

howl-anderson, Sep 18 '21 02:09

@howl-anderson I don't think that he is still active in the project.

bhack, Sep 20 '21 10:09

@bhack Thank you! I will try to do this by myself.

howl-anderson, Sep 22 '21 06:09

TensorFlow Addons is transitioning to a minimal maintenance and release mode. New features will not be added to this repository. For more information, please see our public messaging on this decision: TensorFlow Addons Wind Down

Please consider sending feature requests / contributions to other repositories in the TF community with charters similar to TFA's: Keras, Keras-CV, Keras-NLP

seanpmorgan, Mar 01 '23 04:03