ChainRulesCore.jl
AD-backend agnostic system defining custom forward- and reverse-mode rules. This is the lightweight core that lets you define rules for functions in your packages without depending on any part...
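A minimal sketch of what such rules look like, assuming ChainRulesCore is installed. `square` is a hypothetical user function, and both rule bodies are written by hand for illustration:

```julia
using ChainRulesCore

# Hypothetical user function that we want custom derivative rules for.
square(x::Real) = x^2

# Forward-mode rule: takes a tuple of tangents (one for the function itself,
# one per argument) and returns (primal output, output tangent).
function ChainRulesCore.frule((_, Δx), ::typeof(square), x::Real)
    return square(x), 2x * Δx
end

# Reverse-mode rule: returns the primal output and a pullback mapping the
# output cotangent to cotangents for (function, arguments...).
function ChainRulesCore.rrule(::typeof(square), x::Real)
    square_pullback(ȳ) = (NoTangent(), 2x * ȳ)
    return square(x), square_pullback
end

y, dy = frule((NoTangent(), 1.0), square, 3.0)  # primal and directional derivative
y2, pb = rrule(square, 3.0)
_, xbar = pb(1.0)                                # cotangent for x
```

Any AD backend that understands ChainRulesCore can pick up these methods without the package defining them depending on that backend.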
Currently, structured zero tangents are allowed to pass through during projection for (I believe) all number types. This was found while trying to write an `rrule` that passes https://github.com/FluxML/Zygote.jl/blob/v0.6.43/test/features.jl#L528 (itself...
This PR proposes using `add!!` to add thunks, which should be safe as long as the result of `unthunk`ing is always an array we are free to mutate. It also adds...
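The safety argument can be illustrated with a sketch, assuming a thunk whose body allocates a fresh array: `unthunk` then returns memory that nothing else references, so accumulating into it in place cannot overwrite another gradient. The variable names are hypothetical, and this shows only the reasoning, not the PR's code:

```julia
using ChainRulesCore

grad_a = [1.0, 2.0]

# A thunk whose body allocates a new array each time it is evaluated.
t = @thunk(grad_a .* 2)

# unthunk runs the body; the result is a freshly allocated array, so it is
# a safe destination for in-place accumulation.
acc = unthunk(t)
acc = ChainRulesCore.add!!(acc, [10.0, 10.0])  # may mutate `acc` in place

# grad_a is untouched because `acc` never aliased it.
```

If the thunk body instead returned an existing array (for example `@thunk(grad_a)`), the same `add!!` call could silently modify `grad_a`, which is why the "free to mutate" condition matters.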
Based on discussion in https://github.com/JuliaDiff/Diffractor.jl/pull/54 and on Slack. This is an initial implementation of chunked forward- and reverse-mode AD.
The example below attempts to project a `Unitful.Quantity` differential onto a sparse array of `Float64`:

```julia
using SparseArrays
using ChainRulesCore
using Unitful
using UnitfulChainRules # defines projection of quantities onto...
```
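For context, projecting a plain `Float64` gradient onto a sparse array can be sketched as below: `ProjectTo` keeps only the stored entries of the primal. The Unitful/UnitfulChainRules part of the example is not reproduced, since the original code is truncated; the values here are made up for illustration:

```julia
using SparseArrays
using ChainRulesCore

# Primal: a sparse vector with stored entries at indices 1 and 3.
x = sparse([1.0, 0.0, 2.0])

# A dense gradient, as an AD backend might produce it.
dx = [10.0, 20.0, 30.0]

# Projection keeps the sparsity pattern of `x`, so index 2 is dropped.
dx_proj = ProjectTo(x)(dx)
```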
This is a general issue, but for a specific incarnation, see https://github.com/JuliaDiff/ChainRules.jl/blob/8073c7c4638bdd46f4e822d2ab72423c051c5e4b/src/rulesets/Base/array.jl#L40:

```julia
function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
    vect_pullback(ȳ) = (NoTangent(), NTuple{N}(ȳ)...)
    return Base.vect(X...), vect_pullback
end
```

This rule...
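To see what the quoted rule's pullback returns, here is a sketch that copies its body onto a hypothetical wrapper `my_vect`, so as not to redefine the rule ChainRules.jl already ships for `Base.vect`:

```julia
using ChainRulesCore

# Hypothetical wrapper with the same signature as the quoted rule.
my_vect(X::Vararg{T, N}) where {T, N} = Base.vect(X...)

function ChainRulesCore.rrule(::typeof(my_vect), X::Vararg{T, N}) where {T, N}
    # Split the array cotangent back into one cotangent per argument.
    my_vect_pullback(ȳ) = (NoTangent(), NTuple{N}(ȳ)...)
    return my_vect(X...), my_vect_pullback
end

y, pb = rrule(my_vect, 1.0, 2.0, 3.0)
grads = pb([10.0, 20.0, 30.0])  # NoTangent for the function, then one per argument
```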
Hello! I have a minimal example here about Flux and ChainRulesCore:

```julia
using Flux, ChainRulesCore
import ChainRulesCore.rrule

function f(a::Float32, b::Float32)
    return a * b
end

function rrule(::typeof(f), a::Float32, b::Float32)
    println("rrule...
```
From #90, it seems that @YingboMa wants `frule` to be callable on a `Vector` of sensitivities for the same primal value, and to get back a vector of sensitivities...
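Without dedicated support, the requested behaviour can be approximated by calling `frule` once per tangent at the same primal point. A minimal sketch, with a hypothetical `cube` function and a hand-written rule:

```julia
using ChainRulesCore

# Hypothetical function with a hand-written forward rule.
cube(x::Real) = x^3
function ChainRulesCore.frule((_, Δx), ::typeof(cube), x::Real)
    return cube(x), 3x^2 * Δx
end

x = 2.0
tangents = [1.0, 0.5, -1.0]  # several sensitivities for the same primal

# Naive batching: one frule call per tangent.
results = [frule((NoTangent(), Δx), cube, x) for Δx in tangents]
y = first(results[1])
dys = [last(r) for r in results]
```

The obvious cost of this workaround is that the primal `cube(x)` is recomputed for every tangent.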
This is one way not to make a `Tangent` with only zero types in it:

```julia
julia> ProjectTo(Ref(1))(Ref(1)) # ok
Tangent{Base.RefValue{Int64}}(x = 1.0,)

julia> ProjectTo(Ref(1))(Ref(NoTangent())) # could collapse to NoTangent()...
```
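One way the suggested collapse could be expressed is a post-processing step that returns a single zero when every field of a `Tangent` is already a zero type. `collapse_zeros` is a hypothetical helper; it relies on `ChainRulesCore.backing`, an internal (unexported) accessor, and is not how ChainRulesCore implements projection:

```julia
using ChainRulesCore

# Hypothetical helper: collapse a Tangent whose fields are all zero types.
function collapse_zeros(t::Tangent)
    fields = ChainRulesCore.backing(t)  # NamedTuple (or Tuple) of field tangents
    all(f -> f isa NoTangent, fields) && return NoTangent()
    all(f -> f isa AbstractZero, fields) && return ZeroTangent()
    return t
end
collapse_zeros(dx) = dx  # anything that is not a Tangent passes through

t_zero = Tangent{Base.RefValue{Int}}(; x = NoTangent())
t_real = Tangent{Base.RefValue{Int}}(; x = 1.0)
```

Returning `ZeroTangent()` rather than `NoTangent()` when any field is a `ZeroTangent` keeps the distinction between "zero derivative" and "no derivative".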
This adds a function to check whether a given type is non-differentiable. The purpose is to let you test whether to take the trivial path for some rule. It goes...
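The idea can be sketched with a hand-rolled trait. `is_non_differentiable`, the types listed in it, and `passthrough` are assumptions chosen for illustration, not the function the PR adds:

```julia
using ChainRulesCore

# Hypothetical trait: types assumed to have no meaningful tangent.
is_non_differentiable(::Type) = false
is_non_differentiable(::Type{<:Union{Bool, Char, Symbol, AbstractString}}) = true

# Hypothetical generic function: identity on any value.
passthrough(x) = x

function ChainRulesCore.rrule(::typeof(passthrough), x)
    if is_non_differentiable(typeof(x))
        # Trivial path: nothing to propagate.
        return x, _ -> (NoTangent(), NoTangent())
    end
    return x, ȳ -> (NoTangent(), ȳ)
end

_, pb_str = rrule(passthrough, "abc")  # takes the trivial path
_, pb_num = rrule(passthrough, 3.0)    # propagates the cotangent
```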