[Optimization][Operator] Implement and enable Conv2d-Reshape-Add-ReLU fusion
This PR introduces an operator fusion for the `conv2d` -> `reshape` -> `add` -> `relu` sequence commonly found in deep learning models (e.g., the convolution + bias + activation pattern in PyTorch). The optimization improves performance by reducing kernel-launch overhead and memory traffic.
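For illustration only (not code from this PR), the following NumPy sketch shows the four-op sequence the pattern matches, and a numerically equivalent fused computation in which the bias add and ReLU are folded into the convolution epilogue. The naive `conv2d` loop stands in for the real DNNL kernel.

```python
import numpy as np

def conv2d(data, weight):
    # Naive NCHW conv2d, stride 1, no padding (illustration only).
    n, c_in, h, w = data.shape
    c_out, _, kh, kw = weight.shape
    out = np.zeros((n, c_out, h - kh + 1, w - kw + 1), dtype=data.dtype)
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            patch = data[:, :, i:i + kh, j:j + kw]  # (n, c_in, kh, kw)
            out[:, :, i, j] = np.tensordot(patch, weight, axes=([1, 2, 3], [1, 2, 3]))
    return out

def unfused(data, weight, bias):
    # The four separate ops the pattern matches: conv2d -> reshape -> add -> relu.
    conv_out = conv2d(data, weight)
    bias4d = bias.reshape(1, -1, 1, 1)       # reshape applied to the bias
    return np.maximum(conv_out + bias4d, 0.0)  # add, then relu

def fused(data, weight, bias):
    # A fused kernel applies bias and ReLU before writing results back,
    # so the intermediate conv_out / bias_add tensors never hit memory.
    n, c_in, h, w = data.shape
    c_out, _, kh, kw = weight.shape
    out = np.empty((n, c_out, h - kh + 1, w - kw + 1), dtype=data.dtype)
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            patch = data[:, :, i:i + kh, j:j + kw]
            acc = np.tensordot(patch, weight, axes=([1, 2, 3], [1, 2, 3])) + bias
            out[:, :, i, j] = np.maximum(acc, 0.0)  # ReLU in the epilogue
    return out
```

Both functions compute the same values; the fused variant simply avoids materializing the intermediate tensors.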
Specific Benefits:

- **Performance Improvement:**
  - **Reduced Kernel Launch Overhead:** Previously, `conv2d`, `reshape`, `add`, and `relu` each required a separate kernel call. Fusing these four operations into a single, unified DNNL kernel (e.g., `dnnl_fused_conv2d_bias_relu`) significantly reduces this overhead; see `src/runtime/contrib/dnnl/dnnl.cc:154-158`, where all four operations are handled by a single `execute` call.
  - **Decreased Memory Bandwidth Consumption:** Intermediate results of the individual operations (e.g., `conv_out`, `bias_add`) traditionally required costly memory write-backs and reads. Fusion allows these intermediate values to stay in registers or cache, cutting unnecessary memory accesses and thus reducing memory bandwidth usage and overall execution time.
- **Increased Efficiency:**
  - **Leveraging Compiler Optimizations:** By utilizing TVM's `FuseOpsByPattern` and `MergeCompositeFunctions` passes, this change generates a composite operation optimized for specific backends (like DNNL). Common patterns from frontends like PyTorch are thereby automatically recognized within the TVM graph and mapped to high-performance fused kernels provided by libraries such as DNNL.
  - **Simplified IR Module:** The Intermediate Representation (IR) becomes less complex as multiple operation nodes are condensed into a single composite node, which simplifies subsequent optimization and code-generation stages.
How Fusion Works:
This fusion is achieved through a two-stage transformation within the TVM Relax framework:
1. **Pattern Recognition and Composite Function Creation (`FuseConv2dReshapeAddRelu` pass):**
   - The `FuseConv2dReshapeAddRelu` class, registered as a `tvm.transform.module_pass`, transforms the `IRModule`.
   - The `_conv2d_reshape_add_relu_pattern()` helper function defines the specific sequence `conv2d` -> `reshape` (applied to the bias) -> `add` -> `relu` using TVM's Declarative Pattern Language (DPL), matching the input tensors (`data`, `weight`, `bias`, `shape`) with `wildcard()` and the operation sequence with `is_op()`.
   - The `relax.transform.FuseOpsByPattern` pass identifies this pattern in the input `IRModule`. Upon detection, the operation sequence is encapsulated into a new Relax function carrying the `{"Composite": "dnnl.conv2d_reshape_add_relu", "Primitive": True}` attributes, marking it as a logical "composite" unit.
2. **Composite Function Merging and Codegen Attribute Assignment (`MergeCompositeFunctions` pass):**
   - Following the `FuseConv2dReshapeAddRelu` pass, the `MergeCompositeFunctions` pass is applied via `tvm.ir.transform.Sequential`.
   - This pass identifies functions marked with the `Composite` attribute and transforms them into external functions bearing the `{"Codegen": "dnnl"}` attribute, which indicates that the composite operation should be offloaded to a specific TVM backend, such as DNNL.
   - Consequently, during graph execution, the fused function with the `Codegen` attribute is mapped to and executed by a single optimized DNNL kernel, for instance `dnnl_fused_conv2d_bias_relu` (defined in `src/runtime/contrib/dnnl/dnnl.cc:199-207`).
Key Achievement:
This implementation successfully enables the fusion of the conv2d + reshape + add + relu pattern. This ensures that common convolution + bias + activation patterns originating from frontends like PyTorch are now fully optimized and executed as a single, highly efficient DNNL kernel within TVM.
How to Test:
To verify this fusion, run the specific test case directly:

```
python tests/python/relax/test_conv2d_reshape_add_relu.py
```