PyTorch-extension-Convolution
An example of a C++ extension for PyTorch.
Convolution implemented as a CUDA extension for PyTorch
The source code follows PyTorch's own (inefficient) implementation here. See here for the accompanying tutorial.
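For reference, this is the operation any convolution extension has to reproduce. Following the convention documented for `torch.nn.Conv2d` (with `groups = 1`), PyTorch's "convolution" is a cross-correlation: the kernel is not flipped. The README does not say whether this extension supports stride $s$, padding $p$ or dilation $d$, so treat them as general parameters here. For input $x$ of shape $(N, C_{\text{in}}, H, W)$, weight $w$ of shape $(C_{\text{out}}, C_{\text{in}}, k_H, k_W)$ and bias $b$, with $x$ taken as zero outside its bounds (zero padding):

$$
y[n, o, i, j] = b[o] + \sum_{c=0}^{C_{\text{in}}-1} \sum_{u=0}^{k_H-1} \sum_{v=0}^{k_W-1} w[o, c, u, v]\; x[n,\, c,\, s_H i + d_H u - p_H,\, s_W j + d_W v - p_W]
$$

$$
H_{\text{out}} = \left\lfloor \frac{H + 2p_H - d_H (k_H - 1) - 1}{s_H} \right\rfloor + 1, \qquad
W_{\text{out}} = \left\lfloor \frac{W + 2p_W - d_W (k_W - 1) - 1}{s_W} \right\rfloor + 1
$$

For example, a $3 \times 3$ input with a $2 \times 2$ kernel, $s = 1$, $p = 0$, $d = 1$ gives $H_{\text{out}} = \lfloor (3 + 0 - 1 - 1)/1 \rfloor + 1 = 2$.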
- Build the CUDA extension by going into the `conv_cuda/` folder and running `python setup.py install`.
- JIT-compile the CUDA extension by going into the `conv_cuda/` folder and running `python jit.py`, which compiles the extension just in time and loads it.
- Check the result of the convolution by running `python test.py`.
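The README does not show what `test.py` compares against or which algorithm the CUDA kernels use. As a hedged illustration, assuming the check compares the extension's output with a reference convolution, the pure-Python sketch below computes the same 2-D convolution (cross-correlation, PyTorch convention) in two ways: a direct loop, and "im2col" followed by a matrix product. im2col is a common GPU strategy, but it is an assumption that this repository uses it. The helpers `conv2d_direct`, `im2col` and `conv2d_im2col` are hypothetical and not part of the repository. The sketch handles a single image, with no bias, dilation 1 and groups 1.

```python
# Hypothetical pure-Python sketch (not the repository's code): two ways to
# compute a PyTorch-style 2-D convolution (cross-correlation) for ONE image,
# assuming stride + zero padding, dilation = 1, groups = 1, no bias.

def out_size(n, k, stride, pad):
    # Output length along one spatial axis (dilation fixed at 1).
    return (n + 2 * pad - k) // stride + 1

def pixel(x, c, r, q):
    # Read x[c][r][q]; positions outside the image act as zero padding.
    if 0 <= r < len(x[c]) and 0 <= q < len(x[c][0]):
        return x[c][r][q]
    return 0.0

def conv2d_direct(x, w, stride=1, pad=0):
    # x: [C_in][H][W], w: [C_out][C_in][kH][kW] -> [C_out][H_out][W_out]
    c_in, h, wd = len(x), len(x[0]), len(x[0][0])
    c_out, kh, kw = len(w), len(w[0][0]), len(w[0][0][0])
    ho, wo = out_size(h, kh, stride, pad), out_size(wd, kw, stride, pad)
    out = [[[0.0] * wo for _ in range(ho)] for _ in range(c_out)]
    for co in range(c_out):
        for i in range(ho):
            for j in range(wo):
                acc = 0.0
                # Sum over input channels and kernel taps, kernel not flipped.
                for c in range(c_in):
                    for u in range(kh):
                        for v in range(kw):
                            acc += w[co][c][u][v] * pixel(
                                x, c, stride * i + u - pad, stride * j + v - pad)
                out[co][i][j] = acc
    return out

def im2col(x, kh, kw, stride=1, pad=0):
    # Unfold each receptive field into a column: the result has
    # C_in*kH*kW rows (ordered c, u, v) and H_out*W_out columns.
    c_in, h, wd = len(x), len(x[0]), len(x[0][0])
    ho, wo = out_size(h, kh, stride, pad), out_size(wd, kw, stride, pad)
    cols = []
    for c in range(c_in):
        for u in range(kh):
            for v in range(kw):
                cols.append([pixel(x, c, stride * i + u - pad, stride * j + v - pad)
                             for i in range(ho) for j in range(wo)])
    return cols, ho, wo

def conv2d_im2col(x, w, stride=1, pad=0):
    # Convolution as one matrix product: W_flat [C_out x K] @ cols [K x L].
    kh, kw = len(w[0][0]), len(w[0][0][0])
    cols, ho, wo = im2col(x, kh, kw, stride, pad)
    out = []
    for co in range(len(w)):
        # Flatten this filter in the same (c, u, v) order that im2col uses.
        flat = [w[co][c][u][v] for c in range(len(w[co]))
                for u in range(kh) for v in range(kw)]
        row = [sum(f * col[l] for f, col in zip(flat, cols))
               for l in range(ho * wo)]
        # Reshape the length-L row back into an H_out x W_out map.
        out.append([row[i * wo:(i + 1) * wo] for i in range(ho)])
    return out

# Usage: one 3x3 channel, one 2x2 filter, stride 1, no padding.
x = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]
w = [[[[1.0, 0.0], [0.0, -1.0]]]]
print(conv2d_direct(x, w))  # [[[-4.0, -4.0], [-4.0, -4.0]]]
print(conv2d_im2col(x, w))  # [[[-4.0, -4.0], [-4.0, -4.0]]]
```

The direct loop mirrors the formula term by term. The im2col version trades memory (every input pixel is copied up to kH*kW times) for a single large matrix multiply, which is the part GPUs execute efficiently.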