trax
How to make predictions on a trained model
Description
I used the predefined model `trax.models.rnn.LSTMSeq2SeqAttn()` to train on an NLP dataset that had been converted to vectors. Training ran fine, and I have the saved model and weights (`model.pkl.gz` and the other files the run generated), but I can't find a way to make predictions on the test set.
I followed the code from the deeplearning.ai NLP Specialization on Coursera (C3W3 Assignment, Part 4) to make predictions, but without success. Here is the code:
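```python
import numpy as np
import trax

# NER, data_generator, test_sentences, test_labels and vocab are defined
# earlier in the course notebook (C3W3); they are not part of trax.
model = NER()
model.init(trax.shapes.ShapeDtype((1, 1), dtype=np.int32))
# Load the pretrained weights
model.init_from_file('model.pkl.gz', weights_only=True)

# Take the whole test set as a single batch
x, y = next(data_generator(len(test_sentences), test_sentences, test_labels, vocab['<PAD>']))
print("input shapes", x.shape, y.shape)

# Sample prediction
tmp_pred = model(x)
print(type(tmp_pred))
print(f"tmp_pred has shape: {tmp_pred.shape}")
```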
I left out the `ShapeDtype` initialization because I'm not sure what shape and dtype to pass to it.
My data is a list of lists of integers, and the output is either 0 or 1. Please let me know if you need any further information.
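A minimal sketch, assuming the test set is a list of lists of integer token ids as described above: pad it into one NumPy `int32` array of shape `(batch_size, max_len)`. The helper `pad_batch` and its default pad id of 0 are hypothetical; `vocab['<PAD>']` from the course code would be the natural choice. The array's `.shape` and `.dtype` are the kind of values `trax.shapes.ShapeDtype` describes, though a model with more than one input (as a seq2seq model presumably is, taking source and target tokens) would need one `ShapeDtype` per input.

```python
import numpy as np


def pad_batch(sequences, pad_id=0):
    """Pad variable-length integer sequences into one int32 array.

    Hypothetical helper: pad_id=0 is an assumption; use the id your
    vocabulary reserves for padding (e.g. vocab['<PAD>']).
    """
    max_len = max(len(seq) for seq in sequences)
    # Start from an array filled with the pad id, then copy each sequence in.
    batch = np.full((len(sequences), max_len), pad_id, dtype=np.int32)
    for row, seq in enumerate(sequences):
        batch[row, :len(seq)] = seq
    return batch


# Toy data shaped like the question's input: a list of lists of integers.
test_data = [[5, 12, 7], [3, 9], [8, 1, 4, 2]]
x = pad_batch(test_data)

print(x.shape, x.dtype)  # (3, 4) int32
# x.shape and x.dtype are what a ShapeDtype for this input would describe,
# e.g. trax.shapes.ShapeDtype(x.shape, dtype=np.int32) (trax not imported here).
```

Padding every row to the longest sequence keeps the sketch simple; grouping sequences of similar length would waste less padding on large test sets.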
...
Environment information
OS: Google Colab
```
$ pip freeze | grep trax
trax==1.3.9
$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow @ file:///tensorflow-2.5.0-cp37-cp37m-linux_x86_64.whl
tensorflow-datasets==4.0.1
tensorflow-estimator==2.5.0
tensorflow-gcs-config==2.5.0
tensorflow-hub==0.12.0
tensorflow-metadata==1.2.0
tensorflow-probability==0.13.0
tensorflow-text==2.5.0
$ pip freeze | grep jax
jax==0.2.17
jaxlib @ https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.69+cuda110-cp37-none-manylinux2010_x86_64.whl
$ python -V
Python 3.7.11
```
For bugs: reproduction and error logs
```
# Steps to reproduce:
...
```

```
# Error logs:
...
AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
    586     if not self.has_backward:
--> 587       outputs = self.forward(x)
    588       s = self.state

27 frames

AttributeError: 'int' object has no attribute 'shape'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
AttributeError: 'int' object has no attribute 'shape'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
AttributeError: 'int' object has no attribute 'shape'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
AttributeError: 'int' object has no attribute 'shape'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/trax/shapes.py in signature(obj)
    100     return {k: signature(v) for (k, v) in obj.items()}
    101   else:
--> 102     return ShapeDtype(obj.shape, obj.dtype)
    103
    104

AttributeError: 'int' object has no attribute 'shape'
```
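The last frame shows `trax.shapes.signature` reading `obj.shape` on whatever it receives. The sketch below is a simplified, standalone re-creation of that kind of recursion, not trax's source: the name `toy_signature` is hypothetical, and it assumes lists and tuples are walked element by element in the same way the dict branch visible in the traceback is. It illustrates one plausible reading of the error: if the batch is still a nested Python list, the recursion descends until it reaches a bare `int`, which has no `.shape`, whereas a NumPy array is handled in a single step.

```python
import numpy as np


def toy_signature(obj):
    """Simplified stand-in for a shape/dtype signature function.

    Assumption: lists/tuples are handled element by element, like the
    dict branch visible in the traceback; anything else must expose
    .shape and .dtype attributes.
    """
    if isinstance(obj, (list, tuple)):
        return tuple(toy_signature(item) for item in obj)
    if isinstance(obj, dict):
        return {k: toy_signature(v) for k, v in obj.items()}
    # Arrays end up here; a plain Python int raises AttributeError.
    return (obj.shape, obj.dtype)


nested = [[5, 12, 7], [3, 9, 0]]  # plain Python list of lists of ints

try:
    toy_signature(nested)
except AttributeError as err:
    print(err)  # 'int' object has no attribute 'shape'

# The same data as a single array yields one (shape, dtype) pair.
print(toy_signature(np.asarray(nested, dtype=np.int32)))
```

If this reading matches what happens here, the first thing to check is that the input passed to the model is a padded NumPy array (for example built with `np.asarray(..., dtype=np.int32)`) rather than a nested list.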