import trax takes 17 seconds
Description
Importing trax takes too long.
Environment information
OS: ubuntu 18.04
$ pip freeze | grep trax
trax==1.3.7
$ pip freeze | grep tensor
mesh-tensorflow==0.1.18
tensorboard==2.4.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.4.0
tensorflow-datasets==4.2.0
tensorflow-estimator==2.4.0
tensorflow-gpu==2.4.0
tensorflow-hub==0.11.0
tensorflow-metadata==0.26.0
tensorflow-probability==0.12.1
tensorflow-text==2.4.3
$ pip freeze | grep jax
jax==0.2.8
jaxlib==0.1.58
$ python -V
Python 3.6.9
For bugs: reproduction and error logs
Use the following simple program.
#!/usr/bin/env python3
import time
start = time.time()
import trax
end = time.time()
print('elapsed:', end - start)
Output:
elapsed: 17.436033487319946
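A minimal sketch for narrowing down where the 17 seconds go. It assumes the time is dominated by a few heavy dependencies; the package list below (jax, tensorflow, tensorflow_datasets) is a guess, not a measurement. Each module gets a cold import in a fresh interpreter, because a second import in the same process is served from sys.modules and looks instant. The helper name cold_import_times is hypothetical. On Python 3.7+, `python -X importtime -c "import trax"` prints a per-module breakdown directly, but that flag does not exist in the Python 3.6.9 used here.

```python
import subprocess
import sys

# Code run in each child interpreter: time only the import itself,
# excluding interpreter startup.
SNIPPET = (
    "import time, importlib; "
    "t = time.perf_counter(); "
    "importlib.import_module({name!r}); "
    "print(time.perf_counter() - t)"
)


def cold_import_times(module_names):
    """Hypothetical helper: return {name: seconds}, or None if the import fails."""
    results = {}
    for name in module_names:
        proc = subprocess.run(
            [sys.executable, "-c", SNIPPET.format(name=name)],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,  # swallow TF/JAX log noise
            universal_newlines=True,  # Python 3.6-compatible text mode
        )
        if proc.returncode != 0:
            results[name] = None  # e.g. ModuleNotFoundError in the child
        else:
            # The timing is the last line; a package might print on import.
            results[name] = float(proc.stdout.strip().splitlines()[-1])
    return results


if __name__ == "__main__":
    # Assumed candidate dependencies; adjust as needed.
    names = ["jax", "tensorflow", "tensorflow_datasets", "trax"]
    for name, secs in cold_import_times(names).items():
        print(name, "failed" if secs is None else "%.2f s" % secs)
```

Because each measurement runs in its own process, the numbers overlap (importing trax also imports jax and tensorflow), so they show which dependency dominates rather than adding up to the total.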
Error logs:
This is a well-known problem that occurs on basically all setups (local, Colab, GPU cluster). It is not a big issue when running long experiments, but it does make local debugging hard. I have tried debugging the import graph with a profiler, so far without success. It looks like even from trax import fastmath pulls in plenty of dependencies; here is the tree generated by the importlab library for the trax.fastmath.__init__ module:
importlab --tree __init__.py
Output:
https://gist.github.com/syzymon/3bb6f59063f918b4b62b77cdb223da72
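The importlab tree above is a static view of the imports. A complementary runtime check, sketched below under the assumption that counting modules is a useful proxy for import cost, records which entries appear in sys.modules when a module is imported and groups them by top-level package. The helper names modules_loaded_by and main are hypothetical; run it in a fresh interpreter, since modules that are already loaded are not counted.

```python
import importlib
import sys
from collections import Counter


def modules_loaded_by(module_name):
    """Hypothetical helper: import module_name and count the modules it
    newly adds to sys.modules, grouped by top-level package."""
    before = set(sys.modules)
    importlib.import_module(module_name)
    new = set(sys.modules) - before
    return Counter(name.split(".")[0] for name in new)


def main(argv):
    # Defaults to the module discussed in this issue.
    target = argv[1] if len(argv) > 1 else "trax.fastmath"
    try:
        counts = modules_loaded_by(target)
    except ImportError as exc:
        print("could not import", target, "-", exc)
        return
    print("total new modules:", sum(counts.values()))
    for package, n in counts.most_common(15):
        print("%6d  %s" % (n, package))


if __name__ == "__main__":
    main(sys.argv)
```

Usage: python count_imports.py trax.fastmath (the script name is arbitrary). A large count under tensorflow, for example, would suggest that fastmath transitively imports TensorFlow even when only the JAX backend is needed, though that is something to confirm, not a given.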
Same here, and worse: 25 s on my machine. It is really inconvenient for debugging.