ExplainableAI.jl
Test GPU support
GPU support is currently untested. In theory, GPU tests could be run on CI using the JuliaGPU Buildkite CI.
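As a rough sketch of how such CI tests could be gated, GPU-specific tests could be guarded by CUDA.functional() so that they only run on Buildkite agents with a working CUDA device. The small model below is a placeholder, not the package's actual test suite, and it assumes the returned Explanation exposes its attributions as expl.val:

using Test
using CUDA
using Flux
using ExplainableAI

if CUDA.functional()
    @testset "GPU smoke test" begin
        # Placeholder MLP; real tests would reuse the package's test models
        model = Chain(Dense(784, 32, relu), Dense(32, 10)) |> gpu
        input = cu(rand(Float32, 784, 1))
        expl  = analyze(input, LRP(model))
        @test expl.val isa CuArray
    end
else
    @info "Skipping GPU tests: no functional CUDA device available."
end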
Locally, a first test of GPU support can be run by modifying the README example to cast the input to a GPU array.
It should be possible to run the following code in a fresh temporary environment:
using CUDA
using ExplainableAI
using Flux
using MLDatasets
using Downloads: download
using BSON: @load

# Download the pretrained model from the ExplainableAI.jl docs if it is not already present
model_url = "https://github.com/adrhill/ExplainableAI.jl/raw/master/docs/src/model.bson"
path = joinpath(@__DIR__, "model.bson")
!isfile(path) && download(model_url, path)
@load path model
model = strip_softmax(model)

# Load the first 10 MNIST test images and reshape them to WHCN format
x, _ = MNIST.testdata(Float32, 10)
input = reshape(x, 28, 28, 1, :)
input_gpu = gpu(input) # cast input to GPU array
# Note: the model itself may also have to be moved to the GPU, e.g. `model = gpu(model)`

analyzer = LRP(model)
expl = analyze(input_gpu, analyzer)
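If the call succeeds, a quick sanity check, assuming the returned Explanation stores its attributions in the val field, would be to verify that the result lives on the GPU and to move it back to the host for heatmapping:

expl.val isa CuArray     # should be true if the analysis actually ran on the GPU
expl_cpu = cpu(expl.val) # copy attributions back to host memory, e.g. for visualization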
It would also be interesting to compare CPU and GPU performance, e.g. using BenchmarkTools:

using BenchmarkTools
@benchmark analyze($input_gpu, $analyzer) # GPU
@benchmark analyze($input, $analyzer)     # CPU baseline
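Note that GPU kernel launches are asynchronous, so the GPU timing above may be misleading. Assuming the whole analysis runs on the GPU, a more meaningful measurement would synchronize the device:

@benchmark CUDA.@sync analyze($input_gpu, $analyzer)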