Refactor crf_decode_forward
Describe the feature and the current behavior/state.
Currently, `tfa.text.crf_decode_forward` is designed as:
- Create an RNN cell and pass the loop-invariant `transition_params` into its constructor.
- Create an RNN layer and return its call output.
- To make `transition_params` serializable, it is converted to a list in `get_config` (sketched below): https://github.com/tensorflow/addons/blob/master/tensorflow_addons/text/crf.py#L452-L457
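For reference, a condensed sketch of the current design (simplified from the linked `crf.py`, not the verbatim implementation):

```python
import tensorflow as tf


class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell):
    def __init__(self, transition_params, **kwargs):
        super().__init__(**kwargs)
        # The loop-invariant tensor is baked into the cell at construction time.
        self._transition_params = tf.expand_dims(transition_params, 0)
        self._num_tags = transition_params.shape[0]

    def get_config(self):
        # The tensor has to be converted to a plain Python list so that the
        # returned config is JSON-serializable.
        config = {
            "transition_params": tf.squeeze(self._transition_params, 0)
            .numpy()
            .tolist()
        }
        base_config = super().get_config()
        return {**base_config, **config}
```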
There are three weird points:
- `transition_params` is loop-invariant, so we could pass it as a `constants` call argument into the RNN cell and RNN layer: https://www.tensorflow.org/api_docs/python/tf/keras/layers/RNN?version=nightly
- Two classes are instantiated but not visible to users. I don't see any similar usage in core TF. Almost all `tf.keras.layers.*` have functional counterparts, and so does `tf.keras.layers.RNN`.
- As far as I'm concerned, it's uncommon to save a tensor to a list in a `tf.keras.layers.Layer`. Actually, `transition_params` is a `tf.Variable` in `tfa.layers.CRF`, which can be saved automatically when users invoke `model.save` or the like (see the sketch below).
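To illustrate the last point, a minimal sketch (the `TinyCRF` name and shapes are made up for illustration): a variable held by a layer is tracked and checkpointed with no `get_config` plumbing at all.

```python
import tensorflow as tf


class TinyCRF(tf.keras.layers.Layer):
    def __init__(self, num_tags, **kwargs):
        super().__init__(**kwargs)
        # A tf.Variable created via add_weight is tracked automatically.
        self.transition_params = self.add_weight(
            name="transition_params", shape=(num_tags, num_tags)
        )


layer = TinyCRF(num_tags=5)
ckpt = tf.train.Checkpoint(layer=layer)
path = ckpt.save("/tmp/tiny_crf")  # transition_params is saved with no extra code
```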
We could adopt either of the following two patterns.
**Use `tf.keras.backend.rnn`**
```python
def _crf_decode_forward_step(inputs, states, constants):
    # Something like `CrfDecodeForwardRnnCell.call`
    pass


def crf_decode_forward(
    inputs: TensorLike,
    state: TensorLike,
    transition_params: TensorLike,
    sequence_lengths: TensorLike,
):
    (_, outputs, states) = tf.keras.backend.rnn(
        _crf_decode_forward_step, ..., constants=[transition_params]
    )
    return outputs, states
```
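As a sanity check on this pattern, here is a minimal, runnable example of `tf.keras.backend.rnn` with a loop-invariant constant (the step logic is a stand-in, not CRF decoding; note that the backend appends `constants` to the `states` list it hands to the step function):

```python
import tensorflow as tf


def step(inputs, states_and_constants):
    # tf.keras.backend.rnn calls the step function with the states followed
    # by the constants, so we unpack both from one list.
    state = states_and_constants[0]   # recurrent state
    weight = states_and_constants[1]  # loop-invariant constant
    new_state = state + tf.matmul(inputs, weight)
    return new_state, [new_state]


batch, time, dim = 2, 5, 3
inputs = tf.random.normal([batch, time, dim])
init_state = tf.zeros([batch, dim])
weight = tf.eye(dim)  # stands in for transition_params

last_output, outputs, new_states = tf.keras.backend.rnn(
    step, inputs, [init_state], constants=[weight]
)
print(outputs.shape)  # (2, 5, 3): one output per time step
```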
Pros:
- `CrfDecodeForwardRnnCell` and the `RNN` layer are not instantiated.
- No need to worry about serialization.
Cons:
- Cannot enjoy `tf.keras` autocast, which might be helpful for a mixed precision policy. However, I do not see any function in `tf.keras` that supports mixed precision: `crf_decode_forward`, as a function rather than a layer subclass, should output a dtype that follows the input dtype rather than the global precision policy (see the snippet below).
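A small illustration of that dtype behavior (assuming TF 2.4+ for `tf.keras.mixed_precision.set_global_policy`):

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

# A plain function is untouched by the policy: only layer __call__ autocasts.
x = tf.ones([2, 3], dtype=tf.float32)
y = tf.reduce_logsumexp(x, axis=-1)
print(y.dtype)  # float32, regardless of the global policy

tf.keras.mixed_precision.set_global_policy("float32")  # restore the default
```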
**Pass `transition_params` as a `constants` call argument to the RNN cell and layer**
```python
class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell):
    @typechecked
    def __init__(self, num_tags: int, **kwargs):
        super().__init__(**kwargs)
        self._num_tags = num_tags

    def call(self, inputs, state, constants):
        transition_params = tf.expand_dims(constants[0], 0)
        # Similar to the current implementation
        return backpointers, new_state


def crf_decode_forward(
    inputs: TensorLike,
    state: TensorLike,
    transition_params: TensorLike,
    sequence_lengths: TensorLike,
):
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
    mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1])
    crf_fwd_cell = CrfDecodeForwardRnnCell(
        transition_params.shape[0], dtype=inputs.dtype
    )
    crf_fwd_layer = tf.keras.layers.RNN(
        crf_fwd_cell, return_sequences=True, return_state=True, dtype=inputs.dtype
    )
    return crf_fwd_layer(inputs, state, constants=transition_params, mask=mask)
```
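To verify the `constants` plumbing this pattern relies on, a minimal runnable cell (the `AddConstantCell` name and arithmetic are made up; the point is that `tf.keras.layers.RNN` forwards `constants` to any cell whose `call` declares that argument):

```python
import tensorflow as tf


class AddConstantCell(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = units  # required by tf.keras.layers.RNN

    def call(self, inputs, states, constants):
        # `constants` arrives as a list of loop-invariant tensors.
        bias = constants[0]
        new_state = states[0] + inputs + bias
        return new_state, [new_state]


batch, time, units = 2, 4, 3
inputs = tf.zeros([batch, time, units])
bias = tf.ones([batch, units])

layer = tf.keras.layers.RNN(AddConstantCell(units), return_sequences=True)
outputs = layer(inputs, constants=bias)
print(outputs.shape)  # (2, 4, 3)
```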
Pros:
- Can theoretically enjoy `tf.keras` autocast.
- No need to worry about serialization.
Cons:
- So far, mixed precision autocast only applies to the first argument, so we cannot expect `state` and `constants` to be cast correctly: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/engine/base_layer.py#L936-L947
- `CrfDecodeForwardRnnCell` and the `RNN` layer are instantiated but not visible to users.
- Have to check whether `state` is a `KerasTensor` and decide whether `constants` should be passed as a `KerasTensor`; otherwise the `crf_fwd_layer(...)` call above will fail.
cc @tensorflow/sig-addons-maintainers for visibility.
In my view, `tf.keras.backend.rnn` is the better option given the pros and cons above.
Hi @WindQAQ, thank you for your detailed report. Since this issue was written about 10 months ago, is there any update on the pros and cons of the two choices? Could you bring those statements up to date? I am happy to refactor `crf_decode_forward` once I have enough up-to-date information.
@howl-anderson I don't think he is still active in the project.
@bhack Thank you! I will try to do this myself.
TensorFlow Addons is transitioning to a minimal maintenance and release mode. New features will not be added to this repository. For more information, please see our public messaging on this decision: TensorFlow Addons Wind Down.
Please consider sending feature requests / contributions to other repositories in the TF community with similar charters to TFA: Keras, Keras-CV, Keras-NLP.