ExplainableAI.jl

Test GPU support

Open • adrhill opened this issue 3 years ago • 1 comment

GPU support is currently untested. In principle, GPU tests could run on the JuliaGPU Buildkite CI.
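A GPU test that such a CI machine could run might look roughly like the sketch below. This is not an existing test. The small random `Chain` is a hypothetical fixture standing in for a real model (the `Dense(in => out)` syntax assumes Flux ≥ v0.13). The sketch also assumes that the `Explanation` returned by `analyze` stores the relevance array in its `val` field. The GPU branch only runs when `CUDA.functional()` reports a usable device, so the same file should still pass on CPU-only machines.

```julia
using Test
using CUDA
using ExplainableAI
using Flux

# Hypothetical test fixture: a small random model and a batch of 3 inputs
model = Chain(Dense(10 => 5, relu), Dense(5 => 2))
input = rand(Float32, 10, 3)

@testset "LRP on GPU" begin
    if CUDA.functional()
        expl_cpu = analyze(input, LRP(model))
        # Both the input and the model parameters are moved to the GPU
        expl_gpu = analyze(gpu(input), LRP(gpu(model)))
        # Assumes the relevance is stored in the `val` field of the explanation
        @test expl_gpu.val isa CuArray
        # CPU and GPU relevances should agree up to floating-point error
        @test isapprox(Array(expl_gpu.val), expl_cpu.val; rtol=1e-4)
    else
        @info "CUDA not functional, skipping GPU tests"
    end
end
```

Comparing against the CPU result checks both that the GPU code path runs and that it produces the same explanation.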

Locally, a first test of GPU support can be run by modifying the README example to cast the input to a GPU array. It should be possible to run the following code in a fresh temporary environment:

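```julia
using CUDA
using ExplainableAI
using Flux
using MLDatasets
using Downloads: download
using BSON: @load

model_url = "https://github.com/adrhill/ExplainableAI.jl/raw/master/docs/src/model.bson"
path = joinpath(@__DIR__, "model.bson")
!isfile(path) && download(model_url, path)
@load path model  # load from the same path the file was downloaded to

model = strip_softmax(model)
x, _ = MNIST.testdata(Float32, 10)  # 10th MNIST test image (older MLDatasets `testdata` API)
input = reshape(x, 28, 28, 1, :)    # WHCN layout expected by Flux

analyzer = LRP(model)  # CPU analyzer

# For a GPU run, both the input and the model parameters have to live on the GPU
input_gpu = gpu(input)          # cast input to GPU array
analyzer_gpu = LRP(gpu(model))  # analyzer wrapping a GPU copy of the model
expl = analyze(input_gpu, analyzer_gpu)
```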

adrhill, May 20 '22 14:05

It would also be interesting to compare CPU and GPU performance using BenchmarkTools.jl:

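```julia
using BenchmarkTools

@benchmark CUDA.@sync analyze($input_gpu, $(LRP(gpu(model))))  # GPU model; @sync waits for kernels to finish
@benchmark analyze($input, $analyzer)
```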
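The two benchmarks can also be reduced to a single speedup figure by comparing their median times with BenchmarkTools' `ratio`. The sketch below is self-contained. It replaces the downloaded MNIST model with a hypothetical small random convolutional `Chain` that has the same 28×28×1 input shape, and it needs a functional CUDA device to run.

```julia
using BenchmarkTools
using CUDA
using ExplainableAI
using Flux
using Statistics: median

# Hypothetical stand-in for the MNIST model: 28×28×1 input, 10 output classes
model = Chain(Conv((3, 3), 1 => 4, relu), Flux.flatten, Dense(26 * 26 * 4 => 10))
input = rand(Float32, 28, 28, 1, 10)  # batch of 10 random "images"

analyzer_cpu = LRP(model)
analyzer_gpu = LRP(gpu(model))  # analyzer wrapping a GPU copy of the model
input_gpu = gpu(input)

t_cpu = @benchmark analyze($input, $analyzer_cpu)
t_gpu = @benchmark CUDA.@sync analyze($input_gpu, $analyzer_gpu)

# Ratio of median GPU time to median CPU time: a time ratio below 1 means the GPU was faster
r = ratio(median(t_gpu), median(t_cpu))
println(r)
```

Using medians rather than means keeps the comparison less sensitive to outliers such as the first, compilation-heavy GPU calls.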

adrhill, May 20 '22 14:05