Torchao import time

Open felipemello1 opened this issue 1 year ago • 3 comments

Hi folks, not a bug. In torchtune, importing the library takes ~7s. When I profile it, the majority comes from torchao imports.

Just a simple 'import torchao' takes ~4s:

import time
start = time.perf_counter()
import torchao
end = time.perf_counter()
print("time import: ", end - start)
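
The snippet above only times modules that are not already in sys.modules, so a fresh interpreter gives the cleanest number. Here is a minimal sketch of that; the helper name `time_import` is made up for illustration and is not part of torchao or torchtune. It runs the same perf_counter measurement in a subprocess, so it can be called for several modules without one import warming the cache for the next:

```python
import subprocess
import sys


def time_import(module: str) -> float:
    """Return the seconds taken to import `module` in a fresh interpreter.

    Hypothetical helper for illustration only.
    """
    code = (
        "import time; s = time.perf_counter(); "
        f"import {module}; "
        "print(time.perf_counter() - s)"
    )
    result = subprocess.run(
        [sys.executable, "-c", code],
        check=True,
        capture_output=True,
        text=True,
    )
    # The timing is the last line printed, even if the module prints on import.
    return float(result.stdout.strip().splitlines()[-1])
```

Since torchao imports torch, comparing `time_import("torch")` with `time_import("torchao")` gives a rough idea of how much torchao adds on top of torch.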

It's possible to do some profiling like this (sorted by cumulative time):

python -X importtime -c "import torchao" 2> torchao_import_times.txt
sort -k5,5nr torchao_import_times.txt > sorted_torchao_by_cumulative.txt

By self time:

sort -k3,3nr torchao_import_times.txt > sorted_torchao_by_self.txt
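
If the sort commands are awkward on your platform, the same ranking can be done in Python. This is a rough sketch that assumes the "import time: self | cumulative | name" line layout shown in the outputs below, with two spaces of indentation per nesting level. The names `ImportRecord`, `parse_importtime` and `top` are invented for this example:

```python
import re
from typing import NamedTuple


class ImportRecord(NamedTuple):
    self_us: int        # time spent in the module itself, in microseconds
    cumulative_us: int  # time including nested imports, in microseconds
    depth: int          # nesting level inferred from the indentation
    module: str


# Matches data lines only; the header row has no digits in the first column.
_LINE = re.compile(r"^import time:\s+(\d+)\s*\|\s+(\d+)\s*\|(\s*)(\S+)\s*$")


def parse_importtime(text: str) -> list[ImportRecord]:
    """Parse the stderr output of `python -X importtime`."""
    records = []
    for line in text.splitlines():
        match = _LINE.match(line)
        if match is None:
            continue  # header row or unrelated stderr output
        self_us, cumulative_us, indent, module = match.groups()
        # One separator space after "|", then two spaces per nesting level.
        depth = max(len(indent) - 1, 0) // 2
        records.append(
            ImportRecord(int(self_us), int(cumulative_us), depth, module)
        )
    return records


def top(records: list[ImportRecord], key: str = "self_us", n: int = 10):
    """Return the n slowest records by `self_us` or `cumulative_us`."""
    return sorted(records, key=lambda r: getattr(r, key), reverse=True)[:n]
```

For example, `top(parse_importtime(open("torchao_import_times.txt").read()), "cumulative_us")` should give the same ordering as the first sort command.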

Just wanted to share it here in case someone wants to take a look. Thanks!

Output sorted by self time:

import time:    763864 |     819415 |     torch._C
import time:    436553 |     436553 |                             torchao.float8.float8_utils
import time:    247885 |     249457 |           torch._prims
import time:    220180 |     999455 |     torchao.quantization.autoquant
import time:    150263 |     285311 |               torch._inductor.decomposition
import time:    111738 |     111738 |                         torch.ao.quantization.quantizer.quantizer
import time:    109062 |     478208 |     torch._meta_registrations
import time:     83299 |     238870 |                         torch._dynamo.codegen
import time:     57174 |      57174 |                           mpmath.ctx_iv
import time:     53617 |      73218 |         torch._refs
import time:     42450 |     291907 |         torch._decomp.decompositions
import time:     36143 |      36143 |                                                 torch.distributed.tensor._collective_utils
import time:     32693 |     212937 |           torch._dynamo.polyfills.loader
import time:     30891 |    1893773 |   torch
import time:     25798 |      25798 |                       torch._dynamo.source
import time:     22112 |      23282 |                 torchgen.model
import time:     20698 |      20698 |                     triton.language.standard
import time:     16428 |     373652 |                   torch.fx.experimental.symbolic_shapes
import time:     15278 |      15627 |           torch._inductor.config
import time:     15120 |      15120 |                                   _csv
import time:     14515 |      43410 |           torch.fx.experimental._constant_symnode
import time:     13942 |      15210 |                         torch.fx.node
import time:     13130 |      13130 |                 torch._inductor.inductor_prims
import time:     12562 |     817979 |             torch._dynamo.symbolic_convert
import time:     11626 |      11735 |                       triton._C.libtriton
import time:     11372 |      12810 |                   numpy._core._multiarray_umath
import time:     10966 |     401829 |               torch._dynamo.trace_rules
import time:      9922 |       9922 |                       networkx.utils.backends
import time:      9892 |      37945 |                                 torch._subclasses.fake_tensor
import time:      9745 |      38499 |                             torch.utils._pytree
import time:      9706 |     136029 |         torchao.kernel.intmm_triton
import time:      9687 |      79312 |               torch.nn.functional
import time:      9642 |       9642 |                           torch.onnx.symbolic_opset9
import time:      9526 |       9776 |                                   torch._subclasses.meta_utils
import time:      8763 |       8763 |                     triton.language.random
import time:      8699 |       8699 |                     torch._functorch._aot_autograd.schemas
import time:      8633 |       8633 |                                 torch._dynamo._trace_wrapped_higher_order_op
import time:      8586 |       8586 |                                         dill.session
import time:      8447 |     398127 |                 torch._dynamo.utils
import time:      7971 |      10009 |                                   sympy.integrals.transforms
import time:      7683 |       7683 |                                           torch.distributed.tensor._ops._pointwise_ops
import time:      7199 |     350781 |                     torch.utils._sympy.functions
import time:      7145 |       7361 |                             torch._guards
import time:      6978 |      65819 |                     triton.backends
import time:      6969 |     107833 |                   networkx
import time:      6952 |       6952 |                         torch.ao.quantization.backend_config.backend_config
import time:      6506 |       8114 |                           torch._dynamo.variables.torch_function
import time:      6481 |       6648 |           torch._refs.nn.functional
import time:      6459 |       7007 |               torch._dynamo.config
import time:      6405 |       6405 |           torch._refs.fft
import time:      6368 |       6368 |                               sympy.sets.handlers.intersection
import time:      6067 |       6283 |                                   torch._subclasses.fake_impls
import time:      5856 |       5856 |             torch.masked.maskedtensor.unary
import time:      5738 |      12546 |                                               sympy.functions.combinatorial.numbers
import time:      5604 |      18271 |                                 sympy.core.numbers
import time:      5450 |       5450 |                                     torch.distributed.fsdp.api
import time:      5442 |       5442 |                         torch.export.graph_signature
import time:      5193 |      15833 |                 torch._functorch._aot_autograd.runtime_wrappers
import time:      4764 |       6177 |         torch.profiler._memory_profiler

Output sorted by cumulative time:

import time: self [us] | cumulative | imported package
import time:      1740 |    4098456 | torchao
import time:       534 |    2193863 |   torchao.quantization
import time:     30891 |    1893773 |   torch
import time:       123 |    1186976 |     torchao.kernel
import time:       299 |    1186853 |       torchao.kernel.intmm
import time:       506 |    1046467 |         torch._dynamo
import time:    220180 |     999455 |     torchao.quantization.autoquant
import time:      1298 |     830580 |           torch._dynamo.convert_frame
import time:    763864 |     819415 |     torch._C
import time:     12562 |     817979 |             torch._dynamo.symbolic_convert
import time:       298 |     775419 |       torchao.dtypes
import time:      1153 |     772792 |         torchao.dtypes.affine_quantized_tensor_ops
import time:    109062 |     478208 |     torch._meta_registrations
import time:        42 |     471058 |           torchao.dtypes.floatx.float8_layout
import time:       325 |     471016 |             torchao.dtypes.floatx
import time:      1162 |     469622 |               torchao.dtypes.floatx.float8_layout
import time:        30 |     468460 |                 torchao.float8.inference
import time:       268 |     468431 |                   torchao.float8
import time:       528 |     463194 |                     torchao.float8.float8_linear_utils
import time:       448 |     462115 |                       torchao.float8.float8_linear
import time:       206 |     460604 |                         torchao.float8.distributed_utils
import time:      1337 |     437890 |                           torchao.float8.float8_tensor
import time:    436553 |     436553 |                             torchao.float8.float8_utils
import time:     10966 |     401829 |               torch._dynamo.trace_rules
import time:      1147 |     399274 |               torch._dynamo.exc
import time:      8447 |     398127 |                 torch._dynamo.utils
import time:       334 |     384203 |                 torch._dynamo.variables
import time:       876 |     383869 |                   torch._dynamo.variables.base
import time:     16428 |     373652 |                   torch.fx.experimental.symbolic_shapes
import time:       546 |     365670 |       torch._decomp
import time:      7199 |     350781 |                     torch.utils._sympy.functions
import time:      3333 |     344590 |                     torch._dynamo.variables.builder
import time:      1394 |     342723 |                       sympy
import time:     42450 |     291907 |         torch._decomp.decompositions
import time:       439 |     291042 |           torchao.dtypes.affine_quantized_tensor
import time:      3597 |     289228 |             torchao.quantization.quant_primitives
import time:    150263 |     285311 |               torch._inductor.decomposition
import time:    247885 |     249457 |           torch._prims
import time:       563 |     239433 |                       torch._dynamo.side_effects
import time:     83299 |     238870 |                         torch._dynamo.codegen
import time:     32693 |     212937 |           torch._dynamo.polyfills.loader
import time:      3344 |     206276 |     torch.functional
import time:        24 |     202417 |       torch.nn.functional
import time:       288 |     202393 |         torch.nn
import time:       768 |     186719 |           torch.nn.modules
import time:      1098 |     179754 |             torch._functorch.aot_autograd
import time:       435 |     148021 |                         sympy.polys
import time:       353 |     138673 |     torch.quantization
import time:       171 |     137354 |       torch.quantization.fake_quantize
import time:        28 |     137183 |         torch.ao.quantization.fake_quantize
import time:       588 |     137155 |           torch.ao.quantization
import time:      9706 |     136029 |         torchao.kernel.intmm_triton
import time:        26 |     135048 |                 torch.ao.quantization.fx._decomposed
import time:       238 |     135022 |                   torch.ao.quantization.fx
import time:      1581 |     130600 |                           torch._dynamo.variables.functions
import time:      2767 |     125378 |               torch._functorch.partitioners
import time:        24 |     124152 |                             torch.distributed.fsdp._fully_shard
import time:       368 |     124128 |                               torch.distributed.fsdp
import time:      2008 |     123693 |             torch.ao.quantization.pt2e._numeric_debugger
import time:       357 |     120122 |               torch.ao.quantization.pt2e.graph_utils
import time:       505 |     118673 |                 torch.export
import time:       771 |     112702 |                     torch.ao.quantization.fx.prepare
import time:       194 |     111932 |                       torch.ao.quantization.quantizer
import time:    111738 |     111738 |                         torch.ao.quantization.quantizer.quantizer
import time:       259 |     110492 |           triton
import time:       498 |     108531 |                 torch._functorch._activation_checkpointing.graph_info_provider
import time:      6969 |     107833 |                   networkx
import time:      2846 |      97056 |                                 torch.distributed.fsdp._flat_param
import time:       315 |      85810 |                                   torch.distributed.fsdp._fsdp_extensions
import time:      1922 |      85796 |             torch.nn.modules.module
import time:        31 |      84666 |                   torch.fx.passes.infra.pass_base
import time:       207 |      84636 |                     torch.fx.passes.infra
import time:       380 |      83535 |                       torch.fx.passes
import time:      2907 |      83486 |                       torch._dynamo.variables.torch
import time:       228 |      83183 |                           sympy.polys.polyfuncs
import time:       337 |      82955 |                             sympy.polys.specialpolys
import time:      2156 |      82613 |                         mpmath
import time:      1691 |      82388 |               torch.utils._python_dispatch
import time:       523 |      80528 |             torch.nn.modules.linear
import time:      9687 |      79312 |               torch.nn.functional
import time:      1201 |      78049 |                               sympy.polys.rings
import time:       173 |      75427 |                                 sympy.printing.defaults
import time:       612 |      75255 |                                   sympy.printing
import time:     53617 |      73218 |         torch._refs
import time:       412 |      70004 |                         torch.fx.passes.graph_drawer
import time:       278 |      69541 |                                     torch.distributed.fsdp._shard_utils
import time:       675 |      69482 |                           torch.fx.passes.shape_prop
import time:       204 |      69263 |                                       torch.distributed.tensor
import time:       411 |      69060 |                                         torch.distributed.tensor._ops
import time:       188 |      67813 |             triton.runtime
import time:       355 |      67626 |               triton.runtime.autotuner
import time:       909 |      66945 |                 triton.runtime.jit
import time:       900 |      66753 |                 torch._jit_internal
import time:       218 |      66036 |                   triton.runtime.driver
import time:      6978 |      65819 |                     triton.backends
import time:       762 |      60191 |                                     sympy.printing.pycode
import time:      1307 |      59438 |                   torch.distributed.rpc
import time:       501 |      59430 |                                       sympy.printing.codeprinter
import time:        23 |      58930 |                                         sympy.functions.elementary.complexes
import time:        22 |      58907 |                                           sympy.functions.elementary
import time:       641 |      58885 |                                             sympy.functions
import time:      1467 |      57545 |                     networkx.algorithms
import time:       380 |      57234 |                 torch.utils
import time:     57174 |      57174 |                           mpmath.ctx_iv
import time:      1157 |      55663 |                         torch.onnx.utils
import time:      3006 |      55551 |       numpy
import time:       232 |      54895 |                   torch.utils.data
import time:       188 |      54830 |                     torch.distributed.rpc.server_process_global_profiler
import time:       267 |      54643 |                       torch.autograd.profiler_legacy
import time:      1045 |      54408 |                     torch.utils.data.dataloader
import time:      2022 |      54376 |                         torch.autograd
import time:       278 |      51423 |                                           torch.distributed.tensor._ops._conv_ops
import time:       211 |      50727 |                           torch.onnx._internal.diagnostics
import time:       501 |      49081 |                             torch.onnx._internal.diagnostics._diagnostic
import time:       197 |      48581 |                               torch.onnx._internal.diagnostics.infra
import time:      4544 |      45907 |                                 torch.onnx._internal.diagnostics.infra._infra
import time:       378 |      45331 |     torch.nested
import time:       691 |      44954 |       torch.nested._internal.nested_tensor
import time:       318 |      44771 |                       triton.runtime.build
import time:       586 |      43970 |                         setuptools
import time:       230 |      43639 |         torch.nested._internal.nested_int
import time:     14515 |      43410 |           torch.fx.experimental._constant_symnode
import time:       999 |      43305 |                                             torch.distributed.tensor._dtensor_spec
import time:      2460 |      42306 |                                               torch.distributed.tensor.placement_types
import time:       145 |      42012 |             triton.compiler
import time:       522 |      41868 |               triton.compiler.compiler
import time:       258 |      41363 |                                   torch.onnx._internal.diagnostics.infra.formatter
import time:       625 |      41106 |                                     torch.onnx._internal.diagnostics.infra.sarif
import time:      1153 |      40580 |                 triton.compiler.code_generator
import time:       459 |      39132 |                   triton.language
import time:        30 |      39052 |                         sympy.core.cache
import time:       346 |      39022 |                           sympy.core
import time:       409 |      38907 |                           torch._vmap_internals
import time:        28 |      38576 |                             torch._subclasses.meta_utils
import time:       203 |      38549 |                               torch._subclasses
import time:      9745 |      38499 |                             torch.utils._pytree
import time:      3227 |      38278 |                     torch._dynamo.guards
import time:      9892 |      37945 |                                 torch._subclasses.fake_tensor
import time:     36143 |      36143 |                                                 torch.distributed.tensor._collective_utils
import time:       236 |      35946 |                       torch.utils.data.graph_settings
import time:        18 |      35517 |                         torch.utils.data.datapipes.iter.sharding
import time:        16 |      35500 |                           torch.utils.data.datapipes.iter
import time:       175 |      35484 |                             torch.utils.data.datapipes
import time:      2365 |      34309 |               torch._functorch._aot_autograd.autograd_cache
import time:       255 |      32348 |                           sympy.polys.partfrac
import time:       294 |      32093 |                             sympy.matrices
import time:       161 |      28896 |             torch.fx.experimental
import time:      1691 |      28754 |                               importlib.metadata
import time:       224 |      28735 |               torch.fx
import time:      1794 |      28666 |                             sympy.core.expr
import time:       445 |      27855 |         numpy.__config__
import time:        24 |      27410 |           numpy._core._multiarray_umath
import time:       723 |      27387 |             numpy._core
import time:       273 |      26109 |                   torch.export.decomp_utils
import time:       778 |      26105 |                           setuptools.dist
import time:      1567 |      25842 |                 torch.fx._symbolic_trace
import time:        36 |      25837 |                     torch._export.utils
import time:       952 |      25802 |                       torch._export
import time:     25798 |      25798 |                       torch._dynamo.source
import time:      1244 |      25464 | site
import time:       907 |      24864 |     torch.distributions
import time:     22112 |      23282 |                 torchgen.model
import time:       421 |      22871 |                             torch._dispatch.python
import time:       591 |      22721 |         numpy.lib
import time:       203 |      22618 |                               torch.utils.data.datapipes.dataframe
import time:       886 |      22509 |                           torch.distributed._tensor
import time:      1783 |      22340 |                               unittest.mock
import time:       743 |      21252 |   __editable___torchtune_0_0_0_finder
import time:       310 |      21224 |                   torch.fx._lazy_graph_module
import time:       886 |      20915 |                     torch.fx.graph_module
import time:      1856 |      20761 |                               sympy.core.mul
import time:     20698 |      20698 |                     triton.language.standard
import time:      2042 |      20645 |                           sympy.polys.polytools
import time:       881 |      20517 |                                 torch.utils.data.datapipes.dataframe.dataframes
import time:      3886 |      20029 |                       torch.fx.graph
import time:       965 |      19650 |                     torch.ao.quantization.fx.convert
import time:       565 |      19196 |                                   torch.utils.data.datapipes._decorator
import time:        52 |      18579 |                             torch.distributed.checkpoint.metadata
import time:       295 |      18528 |                               torch.distributed.checkpoint
import time:      5604 |      18271 |                                 sympy.core.numbers
import time:       408 |      18270 |                         torch._export.wrappers
import time:       393 |      18155 |                           mpmath.ctx_fp
import time:       262 |      17634 |                         sympy.geometry
import time:      1046 |      17413 |                                     torch.utils.data.datapipes.datapipe
import time:       187 |      17154 |     torch.nn.intrinsic
import time:       632 |      16896 |                             mpmath.ctx_base
import time:       747 |      16858 |                           torch._dynamo.variables.nn_module
import time:      1330 |      16812 |                               sympy.matrices.immutable
import time:        25 |      16503 |                           torch._higher_order_ops.strict_mode
import time:       532 |      16479 |                             torch._higher_order_ops
import time:      3976 |      16112 |                 torch._inductor.codecache
import time:        20 |      15956 |                                     torch.distributed._shard.sharded_tensor.api
import time:        16 |      15936 |                                       torch.distributed._shard.sharded_tensor
import time:       117 |      15921 |                                         torch.distributed._shard
import time:      5193 |      15833 |                 torch._functorch._aot_autograd.runtime_wrappers
import time:       279 |      15804 |                                           torch.distributed._shard.api
import time:       184 |      15751 |                         sympy.concrete
import time:       582 |      15702 |                                 csv
import time:     15278 |      15627 |           torch._inductor.config
import time:       493 |      15567 |                           sympy.concrete.products
import time:       395 |      15482 |                                 sympy.matrices.expressions
import time:       404 |      15448 |     torch._utils_internal
import time:      4500 |      15380 |                       torch.distributed
import time:     13942 |      15210 |                         torch.fx.node
import time:     15120 |      15120 |                                   _csv
import time:       602 |      14949 |           numpy.lib._index_tricks_impl
import time:       209 |      14817 |     torch.masked
import time:       434 |      14798 |                                       dill
import time:       580 |      14798 |     torch.hub
import time:       260 |      14729 |                             sympy.polys.constructor
import time:       827 |      14609 |       torch.masked._ops
import time:       375 |      14470 |                               sympy.polys.domains
import time:       389 |      14391 |                                             torch.distributed._shard.sharded_tensor
import time:      3002 |      14310 |                                 torch.distributed.fsdp.fully_sharded_data_parallel
import time:       833 |      14287 |                             sympy.concrete.summations
import time:      2792 |      14236 |               torch._inductor.output_code
import time:       672 |      14043 |               numpy._core.multiarray
import time:       420 |      14008 |                               sympy.matrices.dense
import time:       394 |      13735 |                                 asyncio
import time:        24 |      13621 |         torch.masked.maskedtensor.core
import time:       233 |      13597 |           torch.masked.maskedtensor
import time:       393 |      13372 |                 numpy._core.overrides
import time:        22 |      13188 |                               sympy.integrals.integrals
import time:       543 |      13186 |                             torch._dynamo.variables.lazy
import time:       190 |      13166 |                                 sympy.integrals
import time:     13130 |      13130 |                 torch._inductor.inductor_prims
import time:       974 |      13038 |                           _distutils_hack.override
import time:       294 |      12884 |                     networkx.utils
import time:     11372 |      12810 |                   numpy._core._multiarray_umath
import time:      1236 |      12644 |                               torch._dynamo.variables.tensor
import time:       155 |      12624 |             numpy.matrixlib
import time:      5738 |      12546 |                                               sympy.functions.combinatorial.numbers
import time:       356 |      12470 |               numpy.matrixlib.defmatrix
import time:       221 |      12396 |                                 torch.distributed.fsdp._fully_shard
import time:       201 |      12114 |                 numpy.linalg
import time:      3242 |      12102 |     torch.cuda
import time:      2196 |      12043 |     pathlib
import time:      2868 |      11991 |                                 sympy.matrices.matrixbase
import time:       733 |      11967 |                           sympy.geometry.point
import time:      1720 |      11779 |                   numpy.linalg._linalg
import time:     11626 |      11735 |                       triton._C.libtriton
import time:       152 |      11686 |       torch.nn.intrinsic.quantized
import time:      1421 |      11322 |                                   sympy.core.power
import time:       315 |      11322 |                     networkx.readwrite
import time:       335 |      11224 |                         sympy.solvers
import time:      1889 |      11094 |                             sympy.geometry.entity
import time:      1155 |      11042 |                                               torch.distributed._shard.sharded_tensor.api
import time:       152 |      10965 |         torch.nn.intrinsic.quantized.dynamic
import time:       255 |      10945 |                           sympy.polys.numberfields
import time:       148 |      10813 |           torch.nn.intrinsic.quantized.dynamic.modules
....

felipemello1 avatar Mar 24 '25 18:03 felipemello1

Thanks for reporting this, @felipemello1!

Looks like some of the culprits are torchao.float8.float8_utils (~437ms self time) and torchao.quantization.autoquant (~220ms self time, ~999ms cumulative).

I also noticed that the float8-related modules have very deeply nested imports for inference.
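
One general way to keep a deep import chain off the startup path is to defer a module until its first attribute access. The sketch below follows the lazy-import recipe from the standard `importlib` documentation. It is only an illustration of the technique, not what torchao does or what any fix for this issue looks like, and the early return for already-imported modules is my own addition:

```python
import importlib.util
import sys


def lazy_import(name: str):
    """Return a module whose body runs only on first attribute access."""
    if name in sys.modules:
        # Already imported: nothing left to defer.
        return sys.modules[name]
    spec = importlib.util.find_spec(name)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot find module {name!r}")
    loader = importlib.util.LazyLoader(spec.loader)
    spec.loader = loader
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    # With LazyLoader this only sets up the deferral; the real module
    # code executes the first time an attribute is looked up.
    loader.exec_module(module)
    return module
```

Whether this helps depends on whether the deferred module is actually needed at import time; code that touches it at module level would trigger the load immediately.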

@vkuzo and @jerryzh168 can we fix this?

supriyar avatar Mar 24 '25 18:03 supriyar

@msaroufim is working on this: https://github.com/pytorch/ao/pull/2153

jerryzh168 avatar May 01 '25 23:05 jerryzh168

Please keep expectations low, though: I'm having to make big, sweeping changes to the repo to deal with circular imports.

msaroufim avatar May 01 '25 23:05 msaroufim