SequentialSamplingModels.jl

Unified Wald constructor

Open DominiqueMakowski opened this issue 1 year ago • 13 comments

New issue to track and discuss the streamlining of the Wald API (merging of Wald and Mixture Wald) (https://github.com/itsdfish/SequentialSamplingModels.jl/issues/83#issuecomment-2212388196)

DominiqueMakowski avatar Jul 07 '24 14:07 DominiqueMakowski

The benchmarks below show that the mixture model is about 6 times slower than the Wald model, due to sampling the drift rate and evaluating extra terms. On a development branch, I show that conditionally executing the Wald code when eta is zero achieves the desired speed-up without type instability problems.

My plan is to drop WaldMixture in favor of a general Wald model. I will not merge this until later; I would prefer to make breaking changes in bulk. In the meantime, you can use WaldMixture with the parametric constraint eta = 0 to recover the special case, as in the sketch below.
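For example (same parameter values as the benchmark below), pinning η to zero reduces the mixture to a plain Wald, at the cost of the mixture model's extra computation:

using SequentialSamplingModels

# η = 0 removes the across-trial drift-rate variability, so this is
# distributionally equivalent to Wald(; ν = 3.0, α = .5, τ = .130).
wald_special_case = WaldMixture(; ν = 3.0, η = 0.0, α = .5, τ = .130)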

using BenchmarkTools
using SequentialSamplingModels

wald_mixture = WaldMixture(;ν=3.0, η=.2, α=.5, τ=.130)
wald = Wald(;ν=3.0, α=.5, τ=.130)

@benchmark rand($wald_mixture, $1000)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  110.102 μs …  1.529 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     111.269 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   116.936 μs ± 20.719 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▇▂▂▁    ▁        ▁                                          ▁
  ███████▆▇██▇▇█▆▆▆██▇▇█▇▆▅▇▇█▇▆▆▆▆▅▇█▆▆▆▆▆▇▆▆▆▆▇▇▇▇▅▆▇▇▇▇▆▆▆▆ █
  110 μs        Histogram: log(frequency) by time       178 μs <

 Memory estimate: 7.94 KiB, allocs estimate: 1.

 
@benchmark rand($wald, $1000)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  19.509 μs … 361.751 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     20.727 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   21.309 μs ±   4.358 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

   ▄▆██▇▆▄▂                                                    ▂
  ██████████▇▇▆▆▆▆▆▆▆▄▅▅▅▃▅▇▇███▇▆▆▄▅▅▇▇█▇▇▇▇▇▇▇▆▅▄▃▁▄▃▁▃▃▅▃▄▄ █
  19.5 μs       Histogram: log(frequency) by time      35.7 μs <

 Memory estimate: 15.88 KiB, allocs estimate: 2.
 
rts = rand(wald_mixture, 1000)
@benchmark logpdf.($wald_mixture, $rts)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  114.675 μs …  1.301 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     114.859 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   119.654 μs ± 18.190 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █ ▂     ▁       ▁                                            ▁
  ███▇▅▆▆▇██▇▇▇▆▅▆██▇▇▇▆▆▆▇▇▆▇▇▆▅▆█▇▆▆▆▅▅▄▆▇▆▆▆▅▄▄▅▅▅▄▄▃▃▄▃▄▅▄ █
  115 μs        Histogram: log(frequency) by time       185 μs <

 Memory estimate: 7.94 KiB, allocs estimate: 1.


rts = rand(wald, 1000)
@benchmark logpdf.($wald, $rts)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  17.573 μs … 321.755 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     17.740 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   18.959 μs ±   5.142 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▅▂                       ▁   ▁▁                             ▁
  ████▅▅▆▆▅▅▅▅▄▄▅▅▄▄▄▅▅▅▃▁▃▆██████▇▅▅▅▅▅▆▆▇▇▇▆▅▅▆▅▅▄▄▄▄▆▅▆▇▇▆▅ █
  17.6 μs       Histogram: log(frequency) by time      38.1 μs <

 Memory estimate: 7.94 KiB, allocs estimate: 1.

itsdfish avatar Jul 07 '24 14:07 itsdfish

Two more semi-related things:

rand(ExGaussian(0.3, 0.2, 0.0))

  • ExGaussian with tau=0 currently errors, where one would expect it to reduce to a Normal(). Falling back on Normal() when tau=0 would remove the need for extra care in prior specification (where one currently has to exclude 0); see the sketch after this list.

rand(ExGaussian(0.3, 0.0, 0.1), 100)

  • When sigma=0, I would have expected it to return identical values (as LogNormal or Normal do), but it still has some variability. Maybe something to clarify, as it's a bit unexpected.
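A minimal sketch of the proposed fallback (assuming the standard mu/sigma/tau parameterization; this is not the package's actual implementation):

using Distributions

# Hypothetical fallback: when τ == 0 the exponential component vanishes,
# so sample from the Gaussian component alone; otherwise Normal + Exponential.
function rand_exgaussian(μ, σ, τ)
    τ == 0 && return rand(Normal(μ, σ))
    return rand(Normal(μ, σ)) + rand(Exponential(τ))
end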

DominiqueMakowski avatar Jul 07 '24 16:07 DominiqueMakowski

  • Falling back on Normal() if tau=0

That is a good idea. I'll make that change.

  • but it still has some variability. Maybe something to clarify as it's a bit unexpected

I think this behavior is to be expected. X ~ Normal(0.3, 0) + Exponential(0.1) simplifies to a shifted exponential with Var(X) = τ² = 0.01:

using SequentialSamplingModels
var(rand(ExGaussian(0.3, 0.0, 0.1), 10_000))
0.010246564388110698

itsdfish avatar Jul 07 '24 17:07 itsdfish

I think this behavior is to be expected

Indeed, that makes sense: the variability from the exponential component remains.

DominiqueMakowski avatar Jul 07 '24 17:07 DominiqueMakowski

I'll have to think about circumventing the error when tau = 0. Upon further thought, I don't think it's a good idea, because the exponential distribution is not defined when tau = 0, which, I think, implies the ExGaussian is not defined either. Along the same lines, the logpdf is not defined at tau = 0 because tau is used as a divisor. Right now I think the current implementation is correct.
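For reference, the standard ExGaussian density has $\tau$ as a divisor in both the leading factor and the exponent, so it degenerates at $\tau = 0$:

$$f(x; \mu, \sigma, \tau) = \frac{1}{\tau} \exp\!\left(\frac{\mu - x}{\tau} + \frac{\sigma^2}{2\tau^2}\right) \Phi\!\left(\frac{x - \mu}{\sigma} - \frac{\sigma}{\tau}\right)$$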

itsdfish avatar Jul 07 '24 17:07 itsdfish

which, I think, would imply the exguassian is not defined

Although one could argue that if the exponential component is "null", then only the Normal remains of Normal + Exp. If the occasion arises, we can check other implementations to see how they handle it (I tried testing with brms, but it doesn't give me access to the distribution constructor).

DominiqueMakowski avatar Jul 07 '24 18:07 DominiqueMakowski

I verified with gamlss in R:

> dexGAUS(1, mu = 1, sigma = 1, nu = 0, log = FALSE)
Error in dexGAUS(1, mu = 1, sigma = 1, nu = 0, log = FALSE) : 
  nu must be greater than 0  
 
> rexGAUS(1, mu = 1, sigma = 1, nu = 0, log = FALSE)
Error in rexGAUS(1, mu = 1, sigma = 1, nu = 0, log = FALSE) : 
  nu must be positive 

itsdfish avatar Jul 07 '24 19:07 itsdfish

For ExGaussian, the DomainError thrown when τ=0 means that sampling often fails, even when specifying τ on a log link and feeding exp(τ) to ExGaussian() (see also #93 and #81). My guess is that, due to numeric imprecision, Turing explores very large negative values, which get turned into 0 (instead of 0.00...) and make the whole thing error.
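That guess is consistent with Float64 underflow: exp() of a sufficiently large negative proposal is exactly 0.0, so excluding 0 in the prior doesn't help:

julia> exp(-745.0)  # smallest representable positive subnormal
5.0e-324

julia> exp(-746.0)  # underflows to exactly zero
0.0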

Would it make sense to return -Inf from the logpdf when τ=0 (mostly for convenience when used in Turing)? Otherwise it currently requires adding safeguards to the model and forcing the logprob to -Inf whenever exp(τ) == 0, along the lines of the sketch below.
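A minimal sketch of the requested behavior (a hypothetical wrapper assuming a field named τ; not the actual fix):

# Return -Inf instead of throwing when the exponential component degenerates.
safe_logpdf(d, rt) = d.τ == 0 ? -Inf : logpdf(d, rt)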

DominiqueMakowski avatar Aug 14 '24 12:08 DominiqueMakowski

@DominiqueMakowski, please add the proposed fix to your environment via add SequentialSamplingModels#inf_fix (in the Pkg REPL) and check whether it solves your problem. If so, I will merge it into main.

itsdfish avatar Aug 14 '24 15:08 itsdfish

I've run it a couple of times, and it seems to have fixed the issue!

DominiqueMakowski avatar Aug 14 '24 17:08 DominiqueMakowski

Awesome. I will merge the fix and release a new version shortly. Hopefully, it continues to work well with Pigeons.

itsdfish avatar Aug 14 '24 18:08 itsdfish

On a side note: It seems like Wald is particularly prone to errors in Turing under fairly normal conditions, with

ERROR: DomainError with 0.0:
InverseGaussian: the condition λ > zero(λ) is not satisfied.

This happens despite a link function on alpha that should prevent it from being 0. I'm not sure what the cause is, though.

Also, would it make sense (in terms of efficiency) to use the LocationScale() wrapper to create the distribution?

LocationScale(τ, 1, InverseGaussian(μ, λ))
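For reference, that wrapper shifts the InverseGaussian by τ with unit scale, so its logpdf is just the unshifted logpdf evaluated at rt - τ. A quick check (parameter values are illustrative only):

using Distributions

μ, λ, τ = 1/3, 1.0, 0.3
d = LocationScale(τ, 1, InverseGaussian(μ, λ))
logpdf(d, 0.8) ≈ logpdf(InverseGaussian(μ, λ), 0.8 - τ)  # true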

DominiqueMakowski avatar Dec 02 '24 10:12 DominiqueMakowski

Numerical errors can be frustrating. It sounds like there are two potential points of error: (1) your link function, and (2) the reparameterization from Wald parameters to InverseGaussian parameters. One thing you could do is set up print statements to see where the problem is occurring, and you could consider adding a check like x = max(x, eps()). It's worth noting that this will also mask errors other than numerical ones (e.g., a problem with your link function that produces large negative numbers). A sketch of that debugging approach is below.
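A hedged sketch of the idea (the Wald-to-InverseGaussian mapping mean = α/ν, shape = α² is the standard one; the function and its signature are hypothetical):

using Distributions

function wald_logpdf_debug(ν, α, τ, rt)
    μ, λ = α / ν, α^2        # standard Wald → InverseGaussian reparameterization
    @show ν α λ              # print statements to locate where 0.0 first appears
    λ = max(λ, eps())        # guard; note this can also mask upstream errors
    return logpdf(InverseGaussian(μ, λ), rt - τ)
end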

itsdfish avatar Dec 04 '24 16:12 itsdfish

I incorporated the mixture Wald into the Wald model. As shown below, there is a large performance penalty when $\eta = 0$ in the mixture model. I mitigated the performance penalty with some conditional logic (sketched below). The resulting increase in execution time is approximately 10% over the old Wald model, which seems tolerable to simplify the model.
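A minimal sketch of the conditional-logic idea (the field name η and the two helper functions are hypothetical, not the package's internals):

# Branch on η at runtime: the pure-Wald density is much cheaper, so use it
# whenever the across-trial drift-rate variability is exactly zero.
function unified_wald_logpdf(d, rt)
    return d.η == 0 ? pure_wald_logpdf(d, rt) : mixture_wald_logpdf(d, rt)
end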

using BenchmarkTools
using SequentialSamplingModels

wald_mixture = WaldMixture(; ν = 1, α = 1, τ = .3, η = 0)
wald = Wald(; ν = 1, α = 1, τ = .3)

data = rand(wald, 100)

@benchmark logpdf($wald, $data)

@benchmark logpdf($wald_mixture, $data)

Wald

BenchmarkTools.Trial: 10000 samples with 25 evaluations per sample.
 Range (min … max):  945.600 ns … 456.006 μs  ┊ GC (min … max): 0.00% … 99.53%
 Time  (median):     994.000 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.170 μs ±   4.629 μs  ┊ GC (mean ± σ):  5.68% ±  2.70%

  ▂▅▇█▆▄▃▁▁▁      ▁▁  ▁▅▅▄▃▅▅▄▃▃▂ ▁▃▃▁▁                         ▂
  ██████████▆▅▅▆▆████▇██████████████████▇▇▇▇▆▆▆▅▆▆▅▆▄▁▄▄▃▄▄▄▃▃▄ █
  946 ns        Histogram: log(frequency) by time       1.63 μs <

 Memory estimate: 928 bytes, allocs estimate: 2.

Wald Mixture

BenchmarkTools.Trial: 10000 samples with 9 evaluations per sample.
 Range (min … max):  2.024 μs …   4.970 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     2.278 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.236 μs ± 230.107 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▁▄█▇▅▃▂ ▂   ▅▇▇▆▄▂       ▂▂▁                                ▂
  ███████▇█▆▄▄███████▇▆▆▆▇████▇▇▇▆▆▆▅▅▅▄▅▅▃▄▃▄▃▅▆▄▄▄▄▅▅▄▅▄▅▄▅ █
  2.02 μs      Histogram: log(frequency) by time       3.2 μs <

 Memory estimate: 928 bytes, allocs estimate: 2.

Unified Wald

BenchmarkTools.Trial: 10000 samples with 10 evaluations per sample.
 Range (min … max):  1.056 μs …  18.549 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     1.297 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.263 μs ± 269.025 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

   ▂▅▂             ▇█                                          
  ▄███▅▃▃▃▂▂▂▂▂▁▂▂▃███▄▃▂▂▂▅▆▄▃▂▂▂▂▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂ ▃
  1.06 μs         Histogram: frequency by time        1.86 μs <

 Memory estimate: 928 bytes, allocs estimate: 2.

itsdfish avatar Jul 31 '25 11:07 itsdfish