[FEA] LinearCombinationSilu epilogue
I modified the epilogue function in Example 17 from LinearCombinationRelu to LinearCombinationSilu, like this:
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
    ElementOutput,                                         // Data type of output matrix.
    128 / cutlass::sizeof_bits<ElementOutput>::value,      // The number of elements per vectorized
                                                           // memory access. This becomes the vector width of
                                                           // math instructions in the epilogue too.
    ElementAccumulator,                                    // Data type of accumulator
    ElementComputeEpilogue,                                // Data type for alpha in linear combination
    cutlass::epilogue::thread::ScaleType::NoBetaScaling>;  // alpha X C + per channel bias
and reported an error: error: no instance of constructor "cutlass::conv::kernel::ImplicitGemmConvolution<Mma_, Epilogue_, ThreadblockSwizzle_, ConvOperator, ConvProblemSize_>::Arguments::Arguments [with Mma_=cutlass::conv::threadblock::ImplicitGemmMultistage<ThreadblockShape, cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized<cutlass::MatrixShape<128, 32>, ElementInputA, LayoutInputA, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<ElementInputA, 8, 16>>, cutlass::transform::threadblock::RegularTileAccessIterator<cutlass::MatrixShape<128, 32>, ElementInputA, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<16, 32>, 0, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, 16>, cutlass::arch::CacheOperation::Always, cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<cutlass::MatrixShape<32, 128>, ElementInputB, LayoutInputB, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, cutlass::AlignedArray<ElementInputA, 8, 16>>, cutlass::transform::threadblock::RegularTileAccessIterator<cutlass::MatrixShape<32, 128>, ElementInputB, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<16, 32>, 1, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<32, 128>, 128, cutlass::PitchLinearShape<4, 8>, 8>, 16>, cutlass::arch::CacheOperation::Global, cutlass::gemm::threadblock::MmaPolicy<cutlass::gemm::warp::MmaTensorOp<WarpShape, ElementInputA, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<16, 32>, ElementInputB, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<16, 32>, ElementAccumulator, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 8, 16>, 32, cutlass::half_t, 
cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, float, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>, cutlass::MatrixShape<1, 1>>, 1, false, __nv_bool>, cutlass::MatrixShape<0, 0>, cutlass::MatrixShape<0, 0>, 1>, 4, __nv_bool>, Epilogue_=cutlass::epilogue::threadblock::Epilogue<ThreadblockShape, cutlass::gemm::warp::MmaTensorOp<WarpShape, ElementInputA, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<16, 32>, ElementInputB, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<16, 32>, ElementAccumulator, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 8, 16>, 32, cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, float, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>, cutlass::MatrixShape<1, 1>>, 1, false, __nv_bool>, 1, cutlass::epilogue::threadblock::PredicatedTileIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 2, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, 128, 4, 32>, ElementOutput, false>, cutlass::epilogue::warp::FragmentIteratorTensorOp<WarpShape, cutlass::gemm::GemmShape<16, 8, 16>, float, cutlass::Array<float, 4, true>, cutlass::layout::RowMajor>, cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape, cutlass::gemm::GemmShape<16, 8, 16>, float, cutlass::layout::RowMajor>, cutlass::epilogue::threadblock::SharedLoadIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 2, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, 128, 4, 32>::CompactedThreadMap, float, 16>, EpilogueOp, cutlass::MatrixShape<0, 8>, 2, 0>, ThreadblockSwizzle_=SwizzleThreadBlock, ConvOperator=cutlass::conv::Operator::kFprop, ConvProblemSize_=cutlass::conv::Conv2dProblemSize]" matches the argument list
            argument types are: (cutlass::conv::Conv2dProblemSize, cutlass::TensorRef<ElementInputA, LayoutInputA>, cutlass::TensorRef<ElementInputA, LayoutInputA>, {...}, cutlass::TensorRef<ElementOutput, LayoutOutput>, {...})
1 error detected in the compilation of "/****/cutlass/examples/17_fprop_per_channel_bias/fprop_per_channel_bias_swish.cu".
make[2]: *** [examples/17_fprop_per_channel_bias/CMakeFiles/17_fprop_per_channel_bias.dir/build.make:76: examples/17_fprop_per_channel_bias/CMakeFiles/17_fprop_per_channel_bias.dir/fprop_per_channel_bias_swish.cu.o] Error 1
make[1]: *** [CMakeFiles/Makefile2:3608: examples/17_fprop_per_channel_bias/CMakeFiles/17_fprop_per_channel_bias.dir/all] Error 2
make: *** [Makefile:166: all] Error 2
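For reference, the per-element math this fused epilogue is expected to perform can be written as a small host-side function. This is a minimal, hypothetical C++ sketch, not CUTLASS code: it assumes that with ScaleType::NoBetaScaling the functor computes D = silu(alpha * accum + C), where C is the per-channel bias broadcast across rows (matching the "alpha X C + per channel bias" comment in Example 17). The names silu and epilogue_silu_bias are illustrative only. It also spells out where the vector width comes from: a 128-bit access divided by 16 bits per half element gives 8 elements.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Elements per vectorized access for a 16-bit output type (e.g. half):
// 128-bit access / 16 bits per element = 8 elements.
constexpr int kAccessBits = 128;
constexpr int kOutputBits = 16;
constexpr int kElementsPerAccess = kAccessBits / kOutputBits;
static_assert(kElementsPerAccess == 8, "half output -> 8 elements per access");

// SiLU (a.k.a. swish): z * sigmoid(z) = z / (1 + exp(-z)).
float silu(float z) { return z / (1.0f + std::exp(-z)); }

// Hypothetical host reference of the fused epilogue, assuming NoBetaScaling means
//   D[m][k] = silu(alpha * accum[m][k] + bias[k])
// accum is row-major M x K (K = output channels); bias has K entries and is
// broadcast over all M rows (stride 0 along M).
std::vector<float> epilogue_silu_bias(const std::vector<float>& accum, int M, int K,
                                      const std::vector<float>& bias, float alpha) {
  std::vector<float> d(static_cast<std::size_t>(M) * K);
  for (int m = 0; m < M; ++m) {
    for (int k = 0; k < K; ++k) {
      std::size_t i = static_cast<std::size_t>(m) * K + k;
      d[i] = silu(alpha * accum[i] + bias[k]);
    }
  }
  return d;
}
```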
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
I have a question about how to choose the values of these three types (ThreadblockShape, WarpShape, InstructionShape) and how they relate to each other.
LinearCombinationRelu has a default value for beta; LinearCombinationSilu does not. I can add one quickly.
To work around it, you can change this line (https://github.com/NVIDIA/cutlass/blob/master/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu#L196) to {alpha, ElementComputeEpilogue(0)}, which passes beta explicitly.
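Why {alpha} fails while {alpha, ElementComputeEpilogue(0)} compiles comes down to ordinary C++ overload resolution on the epilogue's parameter struct. The sketch below uses two hypothetical, simplified Params structs (not the actual CUTLASS definitions) to show the difference: without a default for beta, a one-argument initializer matches no constructor; once beta has a default, it does.

```cpp
#include <type_traits>

// Hypothetical, simplified stand-ins for an epilogue functor's Params struct.

// Like the Silu epilogue before the fix: beta has no default value.
struct ParamsNoDefaultBeta {
  float alpha, beta;
  ParamsNoDefaultBeta() : alpha(1.0f), beta(0.0f) {}
  ParamsNoDefaultBeta(float alpha_, float beta_) : alpha(alpha_), beta(beta_) {}
};

// Like the Relu epilogue: beta defaults to zero, so a lone alpha is accepted.
struct ParamsDefaultBeta {
  float alpha, beta;
  ParamsDefaultBeta() : alpha(1.0f), beta(0.0f) {}
  ParamsDefaultBeta(float alpha_, float beta_ = 0.0f) : alpha(alpha_), beta(beta_) {}
};

// Initializing from {alpha} needs a constructor callable with a single float.
static_assert(!std::is_constructible<ParamsNoDefaultBeta, float>::value,
              "no constructor takes alpha alone when beta has no default");
static_assert(std::is_constructible<ParamsDefaultBeta, float>::value,
              "alpha alone works once beta has a default");

// The workaround: pass beta explicitly so the two-argument constructor matches.
inline ParamsNoDefaultBeta make_params_workaround(float alpha) {
  return ParamsNoDefaultBeta{alpha, 0.0f};
}
```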
ThreadblockShape and WarpShape are, as their names suggest, the tile sizes of the threadblock and the warp, respectively. InstructionShape is the shape of a single Tensor Core instruction.
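A small compile-time sketch of how the three shapes relate, using the values from the question. It assumes the usual CUTLASS constraints (each ThreadblockShape dimension is divisible by the matching WarpShape dimension, and each WarpShape dimension by the matching InstructionShape dimension); the Shape template here is a hypothetical stand-in for cutlass::gemm::GemmShape. The number of warps per threadblock follows from ThreadblockShape / WarpShape, and the number of Tensor Core instructions a warp issues to cover its tile follows from WarpShape / InstructionShape.

```cpp
// Hypothetical stand-in for cutlass::gemm::GemmShape<M, N, K>.
template <int M_, int N_, int K_>
struct Shape {
  static constexpr int kM = M_, kN = N_, kK = K_;
};

// Values from the question.
using ThreadblockShape = Shape<64, 64, 64>;
using WarpShape        = Shape<32, 64, 64>;
using InstructionShape = Shape<16, 8, 16>;  // fp16 mma with shape m16n8k16

// Each level must tile the level above it exactly.
template <class Outer, class Inner>
constexpr bool divides() {
  return Outer::kM % Inner::kM == 0 && Outer::kN % Inner::kN == 0 &&
         Outer::kK % Inner::kK == 0;
}
static_assert(divides<ThreadblockShape, WarpShape>(), "warp tile must divide threadblock tile");
static_assert(divides<WarpShape, InstructionShape>(), "instruction must divide warp tile");

// Warps per threadblock = (64/32) * (64/64) * (64/64) = 2, i.e. 64 threads.
constexpr int kWarps = (ThreadblockShape::kM / WarpShape::kM) *
                       (ThreadblockShape::kN / WarpShape::kN) *
                       (ThreadblockShape::kK / WarpShape::kK);
constexpr int kThreads = kWarps * 32;

// MMA instructions per warp to cover its warp tile:
// (32/16) * (64/8) * (64/16) = 2 * 8 * 4 = 64.
constexpr int kMmasPerWarpTile = (WarpShape::kM / InstructionShape::kM) *
                                 (WarpShape::kN / InstructionShape::kN) *
                                 (WarpShape::kK / InstructionShape::kK);
```

Larger threadblock tiles give more reuse of loaded data per threadblock but fewer threadblocks for small problems, which is one reason the profiler's generated kernels are a good place to look for shapes that work well together.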
You can check the kernels generated by the CUTLASS profiler (https://github.com/NVIDIA/cutlass/blob/master/media/docs/profiler.md) for plausible tile sizes. Our previous GTC talks also explain CUTLASS's tiling strategies.
@hwu36 It works, thanks. I have another question: for the configuration --Activation=f16:nhwc --Filter=f16:nhwc --Output=f16 --accumulator-type=f16, the input and output channel counts must be multiples of 8, and when they are not, padding is needed, right? In other words, to use Tensor Cores in CUTLASS with half precision, do the input and output channels need to be 8-aligned in NHWC format?
For the configuration --Activation=f16:nhwc --Filter=f16:nhwc --Output=f16 --accumulator-type=f16, does setting accumulator-type=f16 cause overflow?
Regarding whether Tensor Core kernels for half require 8-aligned input and output channels in NHWC format:
We support small-alignment convolutions: look for kernels with align4, align2, or align1 in their names. You don't need to pad if you use these small-alignment kernels, although they are usually slower than 128-bit-aligned kernels. We also provide padding kernels in 2.9 (https://github.com/NVIDIA/cutlass/blob/master/tools/util/include/cutlass/util/device_nhwc_padding.h).
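As a rough sketch of this trade-off: for 16-bit NHWC tensors, one 128-bit access holds 8 elements, so a channel count that is a multiple of 8 can use align8 kernels, while other counts fall back to the largest of 4, 2, or 1 that divides the channel count, or can be zero-padded up to the next multiple of 8. The helpers max_alignment and padded_channels are hypothetical illustrations, not CUTLASS functions; whether a kernel with a given alignment exists for your problem is best confirmed in the profiler's kernel list.

```cpp
// Elements per 128-bit access for a 16-bit element type (half).
constexpr int kMaxAlignmentHalf = 128 / 16;  // 8

// Largest alignment (in elements) that evenly divides the channel count.
constexpr int max_alignment(int channels) {
  return channels % kMaxAlignmentHalf == 0 ? kMaxAlignmentHalf
       : channels % 4 == 0                 ? 4
       : channels % 2 == 0                 ? 2
                                           : 1;
}

// Channel count after zero-padding up to the next multiple of 8.
constexpr int padded_channels(int channels) {
  return (channels + kMaxAlignmentHalf - 1) / kMaxAlignmentHalf * kMaxAlignmentHalf;
}
```

For example, a 3-channel (RGB) input would need align1 kernels, or padding to 8 channels to use the faster 128-bit-aligned kernels.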
Regarding whether accumulator-type=f16 causes overflow:
The accuracy will be worse if you use fp16 accumulation.
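Both effects can be seen with a host-side emulation of an fp16 accumulator. The sketch below rounds the running sum to 11 significant bits after every addition. It does not model subnormals, it assumes the products themselves are exact, and its helper names are hypothetical. Once the sum reaches 2048, adding 1.0 no longer changes it, which is the accuracy loss. Because the largest finite fp16 value is 65504, a large enough sum also overflows to infinity. An fp32 accumulator handles both cases.

```cpp
#include <cmath>
#include <limits>

// Largest finite IEEE fp16 value.
constexpr float kFp16Max = 65504.0f;

// Round a float to the nearest fp16-representable value (round-to-nearest-even).
// Simplification: fp16 subnormals are not modeled; magnitudes above kFp16Max become inf.
float round_to_fp16(float x) {
  if (x == 0.0f || !std::isfinite(x)) return x;
  int exp = 0;
  float m = std::frexp(x, &exp);                // x = m * 2^exp, 0.5 <= |m| < 1
  float q = std::nearbyint(std::ldexp(m, 11));  // keep 11 significant bits
  float r = std::ldexp(q, exp - 11);
  if (std::fabs(r) > kFp16Max) return std::copysign(std::numeric_limits<float>::infinity(), x);
  return r;
}

// Sum n copies of `value`, rounding the running sum to fp16 after every add.
float accumulate_fp16(float value, int n) {
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) acc = round_to_fp16(acc + value);
  return acc;
}

// Same sum with an fp32 accumulator.
float accumulate_fp32(float value, int n) {
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) acc += value;
  return acc;
}
```

Summing 4096 ones gives 4096 in fp32 but stalls at 2048 with the emulated fp16 accumulator.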
I added a default value for beta in https://github.com/NVIDIA/cutlass/commit/e49f690fd7969015343a2b5d72549848e760eb65
Hi @hwu36, I still have a question about SiLU. In Example 17, with the same configuration except for EpilogueOp, I only replaced LinearCombination with LinearCombinationSilu, and the run time increased from 0.03372 ms to 0.064853 ms, which is not what I expected. Fusing SiLU into the conv without additional I/O should theoretically take about the same time as the conv alone. What can I do to improve conv performance with SiLU?
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
    ElementOutput,                                          // Data type of output matrix.
    // 128 / cutlass::sizeof_bits<ElementOutput>::value,    // The number of elements per vectorized
    8,                                                      // memory access. This becomes the vector width of
                                                            // math instructions in the epilogue too.
    ElementAccumulator,                                     // Data type of accumulator
    ElementComputeEpilogue,                                 // Data type for alpha in linear combination
    cutlass::epilogue::thread::ScaleType::NoBetaScaling>;   // alpha X C + per channel bias
// using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
//     cutlass::half_t,
//     8,
//     cutlass::half_t,
//     cutlass::half_t,
//     cutlass::epilogue::thread::ScaleType::NoBetaScaling
// >;
and my problem size is:
cutlass::conv::Conv2dProblemSize problem_size(
{1, 270, 480, 16}, // activation
{16, 1, 1, 16}, // filter
{0, 0, 0, 0}, // padding
{1, 1}, // striding
{1, 1}, // dilation
cutlass::conv::Mode::kCrossCorrelation, // mode (convolution or cross-correlation)
1 // split-k slices
);
Your channel count is small, and so is your filter size, so not much time is spent in the conv itself. SiLU is an expensive operation, so it can take most of the time.
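To see why the conv part is cheap here, the problem can be mapped to its implicit GEMM dimensions (for fprop, GEMM_M = N*P*Q, GEMM_N = K, GEMM_K = C*R*S). The sketch below does that arithmetic for the problem size above. The struct and function names are hypothetical, and the byte count covers only fp16 activation and output traffic.

```cpp
#include <cstdint>

// Hypothetical helper: implicit-GEMM view of a 2D forward convolution (NHWC).
struct ImplicitGemmDims {
  std::int64_t P, Q;             // output height and width
  std::int64_t gemm_m, gemm_n, gemm_k;
  std::int64_t flops;            // 2 * M * N * K (one multiply + one add per MAC)
  std::int64_t output_elements;  // each needs one SiLU in the epilogue
};

constexpr std::int64_t out_extent(std::int64_t in, std::int64_t filter, std::int64_t pad,
                                  std::int64_t stride, std::int64_t dilation) {
  return (in + 2 * pad - dilation * (filter - 1) - 1) / stride + 1;
}

ImplicitGemmDims fprop_dims(std::int64_t N, std::int64_t H, std::int64_t W, std::int64_t C,
                            std::int64_t K, std::int64_t R, std::int64_t S,
                            std::int64_t pad, std::int64_t stride, std::int64_t dilation) {
  ImplicitGemmDims d{};
  d.P = out_extent(H, R, pad, stride, dilation);
  d.Q = out_extent(W, S, pad, stride, dilation);
  d.gemm_m = N * d.P * d.Q;
  d.gemm_n = K;
  d.gemm_k = C * R * S;
  d.flops = 2 * d.gemm_m * d.gemm_n * d.gemm_k;
  d.output_elements = d.gemm_m * d.gemm_n;
  return d;
}

// Bytes for fp16 (2-byte) activation reads and output writes; the tiny filter
// and bias are ignored.
std::int64_t fp16_io_bytes(std::int64_t N, std::int64_t H, std::int64_t W, std::int64_t C,
                           const ImplicitGemmDims& d) {
  return 2 * (N * H * W * C + d.output_elements);
}
```

For {1, 270, 480, 16} with a 1x1 filter and 16 output channels, this gives GEMM_M = 129600, GEMM_N = GEMM_K = 16, about 66 MFLOP, and about 8.3 MB of I/O. Each output element gets only 16 multiply-adds, so adding an exponential and a division per output is a large relative cost.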
What you can try:
- Use this piece of code (https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/epilogue/thread/activation.h#L135-L138) to compute sigmoid. It may have some impact on accuracy.
- Use FewChannels to compute the conv. It is a new feature in 2.9; the old implementation wastes a lot of compute when the channel count is small. Check this test: https://github.com/NVIDIA/cutlass/blob/master/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu
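One common way to make sigmoid cheaper on GPUs is to rewrite it in terms of tanh, sigmoid(x) = 0.5 * (1 + tanh(x / 2)), so that a fast approximate tanh instruction can replace an exponential plus a division. Whether the linked lines use exactly this identity is an assumption here. The host-side sketch below only checks that the identity and the direct formula agree in fp32. std::tanh stands in for the fast approximation, which is where most of any accuracy loss would come from.

```cpp
#include <cmath>

// Direct form: one exp and one division per element.
float sigmoid_exp(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// tanh form: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x)).
// On a GPU the tanh could be a fast approximate instruction instead.
float sigmoid_tanh(float x) { return 0.5f * (1.0f + std::tanh(0.5f * x)); }

// SiLU built on either sigmoid.
float silu_exp(float x) { return x * sigmoid_exp(x); }
float silu_tanh(float x) { return x * sigmoid_tanh(x); }
```

Note that for large negative x the tanh form computes 1 + tanh(x / 2) with cancellation, so its relative error grows there even with an exact tanh.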
@hwu36 Hi, does CUTLASS support depthwise convolution? I haven't seen depthwise mentioned in the documentation, examples, or header files. I set group=inchannel=outchannel in the conv problem size in Example 17, and the result was the same as with group=1.
You can set the channel number as 1 just like what you did, but it is not the most efficient implementation. We don't have an efficient depthwise conv in our GitHub repository.
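In a depthwise convolution (groups = C = K), each output channel is computed only from the matching input channel with its own R x S filter, instead of summing over all input channels. This is also what running C independent single-channel convolutions computes. The naive host-side reference below illustrates that definition for NHWC activations. The [C][R][S] filter layout, the function name, and the omission of dilation are assumptions for illustration, not CUTLASS's layout or API.

```cpp
#include <cstddef>
#include <vector>

// Naive depthwise 2D cross-correlation, NHWC activations and output.
// Filter layout assumed to be [C][R][S]: one R x S filter per channel.
// Symmetric zero padding `pad`, stride `stride`, no dilation.
std::vector<float> depthwise_conv2d_nhwc(const std::vector<float>& x, int N, int H, int W, int C,
                                         const std::vector<float>& filter, int R, int S,
                                         int pad, int stride, int& P, int& Q) {
  P = (H + 2 * pad - R) / stride + 1;
  Q = (W + 2 * pad - S) / stride + 1;
  std::vector<float> y(static_cast<std::size_t>(N) * P * Q * C, 0.0f);
  for (int n = 0; n < N; ++n)
    for (int p = 0; p < P; ++p)
      for (int q = 0; q < Q; ++q)
        for (int c = 0; c < C; ++c) {
          float acc = 0.0f;
          // Only input channel c contributes to output channel c.
          for (int r = 0; r < R; ++r)
            for (int s = 0; s < S; ++s) {
              int h = p * stride - pad + r;
              int w = q * stride - pad + s;
              if (h < 0 || h >= H || w < 0 || w >= W) continue;  // zero padding
              acc += x[((static_cast<std::size_t>(n) * H + h) * W + w) * C + c] *
                     filter[(static_cast<std::size_t>(c) * R + r) * S + s];
            }
          y[((static_cast<std::size_t>(n) * P + p) * Q + q) * C + c] = acc;
        }
  return y;
}
```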
@hwu36 Does “You can set the channel number as 1 just like what you did” mean setting group=inchannel=outchannel? Do you plan to implement an efficient depthwise conv in the future?
Someone implemented it based on CUTLASS. Check the most-forked repository in https://github.com/NVIDIA/cutlass/network/members; they built it.
This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.
Depthwise conv is supported in 2.10. We will keep improving it.