WIP: Improve inference of unflatten
This is pretty minor but I noticed that there are a few cases where unflatten doesn't seem to infer.
Some examples on master with Julia 1.6:
v, unflatten = flatten(deferred(vec, rand(5,)))
@inferred unflatten(rand(5,)) # return type ParameterHandling.Deferred{typeof(vec), Tuple{Vector{Float64}}} does not match inferred return type ParameterHandling.Deferred{typeof(vec), _A} where _A
v, unflatten = flatten(rand(4,4))
@inferred unflatten(rand(16,)) # return type Matrix{Float64} does not match inferred return type Any
Adding a type-inference check to the test sets surfaces a few more failures. I'm aware of the gotcha listed in the docs; ironically, that one seems to work on master (Normal => normal).
This MR:
- Adds inference tests into the test suite
- Improves type inference for a few of the flatten functions that showed errors, via a revised 'oftype' function that carries a type annotation. I tend not to see these annotations used; is there a reason not to? It is a little surprising that it is needed, because this is the output of oftype, so it should already know the output type.
I came to this because I was seeing some inference issues as part of the CRTU tests using a program with deferred, so I put this MR together.
Codecov Report
Merging #46 (cf662a8) into master (ced60f6) will increase coverage by 0.09%. The diff coverage is 100.00%.
:exclamation: Current head cf662a8 differs from pull request most recent head 273f468. Consider uploading reports for the commit 273f468 to get more accurate results
@@ Coverage Diff @@
## master #46 +/- ##
==========================================
+ Coverage 96.55% 96.64% +0.09%
==========================================
Files 4 4
Lines 174 179 +5
==========================================
+ Hits 168 173 +5
Misses 6 6
| Impacted Files | Coverage Δ | |
|---|---|---|
| src/flatten.jl | 98.11% <100.00%> (+0.03%) | :arrow_up: |
| src/parameters.jl | 97.53% <100.00%> (ø) | |
| src/test_utils.jl | 93.18% <100.00%> (+0.68%) | :arrow_up: |
Thanks for opening this. I like the general thrust (getting type stability). I'm concerned about the things you've had to do to get it.
In particular, one of the things we're trying to do in https://github.com/invenia/ParameterHandling.jl/pull/39 is remove these kinds of casting operations, so that we can push dual numbers through unflatten. I'm kind of intrigued that it helps, to be honest. I'll have a play around locally to try and understand better what's going on.
I had meant to add in my description that this was speculative: I was a little suspicious of these changes (using an annotation) but couldn't see a way to do this without them. Regardless, I should have looked closer at the other MRs in play. If the longer-term intention is to remove these operations, then this is of course incompatible with that, and I'm happy if this gets closed.
I'm actually struggling to replicate your problem locally, @AlexRobson. Could you provide your versioninfo()?
julia> versioninfo()
Julia Version 1.6.2
Commit 1b93d53fc4 (2021-07-14 15:36 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin18.7.0)
  CPU: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)
Ah, interesting. Could you try running on 1.6.3? I'll see if I can get a copy of 1.6.2.
edit: I actually don't know how to get a copy of 1.6.2 anymore without building from source...
Hmm, I still see this on 1.6.3. ~~Is it possible my base environment is interfering? (I've been caught out by Revise loading up before.) I'll check that.~~ EDIT: Nope.
(@v1.6) pkg> activate .
Activating environment at `~/Projects/Squads/2021-08-FlowsSquad/testing/dev/ParameterHandling/Project.toml`
julia> using ParameterHandling
[ Info: Precompiling ParameterHandling [2412ca09-6db7-441c-8e3a-88d5709968c5]
julia> using ParameterHandling: positive, deferred, fixed, value, flatten
julia> v, unflatten = flatten(deferred(vec, rand(5,)))
<snip>
julia> v, unflatten = flatten(rand(4,4))
<snip>
julia> using Test
julia> v, unflatten = flatten(deferred(vec, rand(5,)))
<snip>
julia> @inferred unflatten(rand(5,)) # return type ParameterHandling.Deferred{typeof(vec), Tuple{Vector{Float64}}} does not match inferred return type ParameterHandling.Deferred{typeof(vec), _A} where _A
ERROR: return type ParameterHandling.Deferred{typeof(vec), Tuple{Vector{Float64}}} does not match inferred return type ParameterHandling.Deferred{typeof(vec), _A} where _A
julia> v, unflatten = flatten(rand(4,4))
<snip>
julia>
julia> @inferred unflatten(rand(16,)) # return type Matrix{Float64} does not match inferred return type Any
ERROR: return type Matrix{Float64} does not match inferred return type Any
shell> git status
On branch master
Your branch is up to date with 'origin/master'.
julia> versioninfo()
Julia Version 1.6.3
Hmm. So I think I've managed to track down the problem, and it's in the unflattener for Vector{<:Real}.
Could you confirm that changing the method of flatten for Vector{<:Real} to
function flatten(::Type{T}, x::Vector{R}) where {T<:Real,R<:Real}
    unflatten_to_Vector(v) = Vector{R}(v)
    return Vector{T}(x), unflatten_to_Vector
end
also solves the problem on your end?
Aha! Looks like it.
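For anyone wanting to check the suggestion outside the package, here is a minimal standalone sketch. It copies the body of the suggested method but renames it to `sketch_flatten` (a made-up name, so it runs without ParameterHandling loaded); it only demonstrates that the closure infers, not that the rest of the package does.

```julia
using Test

# Same body as the suggested method, renamed to `sketch_flatten` (hypothetical name)
# so that it does not depend on or clash with ParameterHandling.flatten.
function sketch_flatten(::Type{T}, x::Vector{R}) where {T<:Real,R<:Real}
    # R is a static parameter of the method, so the closure knows its target element type.
    unflatten_to_Vector(v) = Vector{R}(v)
    return Vector{T}(x), unflatten_to_Vector
end

x = rand(5)
v, unflatten = sketch_flatten(Float64, x)

# @inferred throws if the actual return type does not match the inferred one.
y = @inferred unflatten(v)
```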
This does fix the majority of the inference tests - thanks. The one that doesn't work is the one related to mvnormal.
@inferred flatten(deferred(mvnormal, fixed(randn(5)), deferred(pdiagmat, positive.(rand(5) .+ 1e-1)))) # Fine - what's already tested
v, unflatten = flatten(deferred(mvnormal, fixed(randn(5)), deferred(pdiagmat, positive.(rand(5) .+ 1e-1))));
@inferred unflatten(rand(10,)) # Fails
This does seem fixed with this change to Deferred though - is this one problematic? Otherwise, I can just set the test above to not check inference as the args are already there.
function flatten(::Type{T}, x::Deferred) where {T<:Real}
    v, unflatten = flatten(T, x.args)
    unflatten_Deferred(v_new::Vector{T}) = Deferred(x.f, unflatten(v_new))
    return v, unflatten_Deferred
end
->
function flatten(::Type{T}, x::D) where {T<:Real, D<:Deferred}
    v, unflatten = flatten(T, x.args)
    unflatten_Deferred(v_new::Vector{T}) = D(x.f, unflatten(v_new))
    return v, unflatten_Deferred
end
Hmm yeah, it'll have the same issue regarding pushing dual numbers through.
I'm a bit confused as to why this is failing to infer, because when I dig down a bit everything else seems to infer. I wonder whether we're hitting some kind of compiler heuristic and it's giving up trying to infer or something.
edit: I think this is beyond my skill set. Maybe we could ask around a bit? See if someone who knows the compiler better could assist?
> Hmm yeah, it'll have the same issue regarding pushing dual numbers through.
Hmm yeah this is very true. This is still just casting.
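To make the casting concern concrete, here is a toy sketch. `ToyDual` is a hypothetical stand-in for a dual-number type (it is not ForwardDiff's `Dual`), and `cast_unflatten` / `plain_unflatten` are illustrative names rather than anything in the package. The point is only that an unflattener which converts back to the original element type cannot accept a vector of some other `Real` subtype.

```julia
# `ToyDual` is a hypothetical stand-in for a dual number; it is not ForwardDiff.Dual.
struct ToyDual <: Real
    value::Float64
    partial::Float64
end

# An unflattener that casts back to the original element type...
cast_unflatten(v) = Vector{Float64}(v)
# ...versus one that leaves the element type of its input alone.
plain_unflatten(v) = collect(v)

duals = [ToyDual(1.0, 1.0), ToyDual(2.0, 0.0)]

# Keeps the element type, so ToyDual values pass through untouched.
passed_through = plain_unflatten(duals)

# No ToyDual -> Float64 conversion is defined, so the cast raises a MethodError.
cast_failed = try
    cast_unflatten(duals)
    false
catch err
    err isa MethodError
end
```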
> I'm a bit confused as to why this is failing to infer, because when I dig down a bit everything else seems to infer. I wonder whether we're hitting some kind of compiler heuristic and it's giving up trying to infer or something.
I've been playing around with this a bit, using descend and the REPL. For what it's worth, I think the point at which inference breaks is the output of the map in unflatten_to_Tuple(v::Vector{T}). Stepping through, it infers fine before that point and breaks after it. As I think you found, the component pieces each seem to infer fine on their own, but inference fails once they are put together. I haven't used descend much, but I think that is what I was seeing there too. This could be a general issue with chained deferred expressions, which may be a common use case.
> Maybe we could ask around a bit? See if someone who knows the compiler better could assist?
That makes sense. Definitely beyond me!
At this point, perhaps it's appropriate for me to tidy this up with just your suggested fix plus the inference tests, leave Deferred as is, and mark that case in the tests as not checking inference for the time being.
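One possible shape for that opt-out, as a rough sketch: the helper name `check_unflatten` and the `check_inferred` keyword are assumptions made for illustration, not the package's existing test utilities, and `unstable` is a deliberately type-unstable toy standing in for the Deferred case.

```julia
using Test

# Hypothetical helper; the name and keyword are illustrative, not the package's API.
function check_unflatten(unflatten, v; check_inferred::Bool=true)
    # Only wrap the call in @inferred when the caller asks for it.
    return check_inferred ? (@inferred unflatten(v)) : unflatten(v)
end

# A deliberately type-unstable function standing in for a case that does not infer.
unstable(v) = sum(v) > 0 ? v : nothing

# Runs without an inference check, so the unstable case does not error.
skipped = check_unflatten(unstable, [1.0, 2.0]; check_inferred=false)

# A type-stable function passes with the default inference check.
checked = check_unflatten(identity, [1.0, 2.0])
```

Keeping the check on by default means any new case is tested for inference unless it is explicitly exempted.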