Specialize `sample` for sparse weights
This PR adds a specialized `sample` method for sparse weights, as well as tests. It reduces the time complexity of a single draw from O(n) to O(n_nonzero).
This is useful for e.g. top-p sampling, where one might have on the order of 100k tokens to sample from, but only a few with nonzero probability.
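For reference, the idea behind the specialization looks roughly like the sketch below. This is illustrative only: the function name and dispatch signature are my assumptions, not the exact code in this PR. Instead of scanning all `n` entries during the cumulative-sum inversion, it walks only the stored nonzeros:

```julia
using Random, SparseArrays, StatsBase

# Illustrative sketch: cumulative-sum inversion over the stored nonzeros only.
# `sample_sparse_sketch` is a made-up name; the PR adds this as a `sample` method.
function sample_sparse_sketch(rng::AbstractRNG, wv::Weights{<:Real,<:Real,<:SparseVector})
    t = rand(rng) * sum(wv)                     # target cumulative mass in [0, sum)
    inds = SparseArrays.nonzeroinds(wv.values)  # stored indices (non-public, see note below)
    vals = SparseArrays.nonzeros(wv.values)     # stored values
    cw = zero(t)
    for (j, w) in enumerate(vals)
        cw += w
        cw >= t && return inds[j]               # first index whose cumulative weight covers t
    end
    return inds[end]                            # guard against floating-point rounding
end
```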
Benchmarks across different sizes and densities
Results
The table below compares three ways of drawing a sample: the dense baseline (`sample` on a dense copy of the weights), the generic `AbstractWeights` method invoked on the sparse weights, and the new sparse method. The two speedup columns give the sparse method's speedup over the dense baseline and over the generic method, respectively.
Dense vs. sparse vs. generic sampling (median times):

| size  | density | dense_time | generic_time | sparse_time | speedup vs dense | speedup vs generic |
|------:|--------:|-----------:|-------------:|------------:|-----------------:|-------------------:|
| 10    | 0.10    | 8.200 ns   | 16.717 ns    | 20.900 ns   | 0.4x             | 0.8x               |
| 10    | 0.25    | 12.312 ns  | 29.648 ns    | 25.726 ns   | 0.5x             | 1.2x               |
| 10    | 0.50    | 15.000 ns  | 31.426 ns    | 31.000 ns   | 0.5x             | 1.0x               |
| 10    | 1.00    | 17.918 ns  | 46.821 ns    | 33.835 ns   | 0.5x             | 1.4x               |
| 100   | 0.01    | 53.799 ns  | 133.853 ns   | 21.421 ns   | 2.5x             | 6.2x               |
| 100   | 0.10    | 44.803 ns  | 235.024 ns   | 34.303 ns   | 1.3x             | 6.9x               |
| 100   | 0.25    | 54.095 ns  | 380.565 ns   | 40.302 ns   | 1.3x             | 9.4x               |
| 100   | 0.50    | 52.775 ns  | 454.237 ns   | 50.655 ns   | 1.0x             | 9.0x               |
| 100   | 1.00    | 51.160 ns  | 553.982 ns   | 69.706 ns   | 0.7x             | 7.9x               |
| 1000  | 0.01    | 376.093 ns | 2.613 μs     | 34.102 ns   | 11.0x            | 76.6x              |
| 1000  | 0.10    | 405.793 ns | 6.025 μs     | 70.871 ns   | 5.7x             | 85.0x              |
| 1000  | 0.25    | 393.353 ns | 7.775 μs     | 128.072 ns  | 3.1x             | 60.7x              |
| 1000  | 0.50    | 383.527 ns | 8.743 μs     | 219.973 ns  | 1.7x             | 39.7x              |
| 1000  | 1.00    | 384.167 ns | 8.444 μs     | 398.155 ns  | 1.0x             | 21.2x              |
| 10000 | 0.01    | 3.533 μs   | 44.500 μs    | 69.706 ns   | 50.7x            | 638.4x             |
| 10000 | 0.10    | 3.778 μs   | 88.300 μs    | 403.333 ns  | 9.4x             | 218.9x             |
| 10000 | 0.25    | 3.689 μs   | 121.850 μs   | 940.230 ns  | 3.9x             | 129.6x             |
| 10000 | 0.50    | 3.720 μs   | 152.863 μs   | 1.880 μs    | 2.0x             | 81.3x              |
| 10000 | 1.00    | 3.744 μs   | 131.150 μs   | 3.750 μs    | 1.0x             | 35.0x              |
Benchmark setup
```julia
using BenchmarkTools
using Printf
using Random
using SparseArrays
using StatsBase

function benchmark_sparse_sample(; sizes=[10, 100, 1000, 10_000],
                                   densities=[0.01, 0.1, 0.25, 0.5, 1.0])
    println("Dense vs Sparse vs Generic sampling:")
    println("size\tdensity\t\tdense_time\tgeneric_time\tsparse_time\tspeedup_dense\tspeedup_generic")
    println("-" ^ 100)
    for n in sizes
        for density in densities
            n * density < 1 && continue  # skip size/density pairs with no nonzeros
            nnonzero = round(Int, n * density)
            # Build a normalized random sparse weight vector with `nnonzero` stored entries.
            indices = sort!(sample(1:n, nnonzero, replace=false))
            values = rand(nnonzero)
            values ./= sum(values)
            sparse_vector = sparsevec(indices, values, n)
            sparse_weights = Weights(sparse_vector)
            dense_weights = Weights(collect(sparse_vector))
            dense = @benchmark sample($dense_weights)
            sparse = @benchmark sample($sparse_weights)
            # `invoke` forces the generic AbstractWeights method on the sparse weights.
            generic = @benchmark invoke(sample, Tuple{AbstractRNG, AbstractWeights},
                                        $(Random.default_rng()), $sparse_weights)
            dense_time = median(dense).time
            generic_time = median(generic).time
            sparse_time = median(sparse).time
            speedup_dense = dense_time / sparse_time
            speedup_generic = generic_time / sparse_time
            @printf("%-8d%-16.2f%-16s%-16s%-16s%.1fx\t\t%.1fx\n",
                    n, density, BenchmarkTools.prettytime(dense_time),
                    BenchmarkTools.prettytime(generic_time),
                    BenchmarkTools.prettytime(sparse_time), speedup_dense, speedup_generic)
        end
        println()
    end
end
```
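The results table above was produced by calling `benchmark_sparse_sample()` with the default sizes and densities.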
Note: For small vector lengths (~10) and low densities (~0.2), the performance differences become noisy and less meaningful. The generic method can sometimes be faster in these cases, since it has less overhead when it happens to reach the target probability mass early in the vector. However, at these sizes the absolute timing differences are a few nanoseconds, and sparse storage is of little benefit anyway.
Note: The implementation uses `SparseArrays.nonzeroinds`, which is not public API.
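For context on that note, here is a small illustration (my own, not from the PR) of the difference between the internal accessor and the alternatives; as far as I know, the public `findnz` returns freshly allocated copies rather than the internal buffers:

```julia
using SparseArrays

v = sparsevec([2, 5, 9], [0.3, 0.5, 0.2], 10)

SparseArrays.nonzeroinds(v)  # [2, 5, 9] -- internal index vector, not public API
SparseArrays.nonzeros(v)     # [0.3, 0.5, 0.2] -- stored values (documented)
findnz(v)                    # ([2, 5, 9], [0.3, 0.5, 0.2]) -- public, but allocates copies
```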