refactor: construct `.eq_def` before `.eq_n`
I have a theory that we’d get fewer “failed to generate equational theorem” errors like the one in #5667 if we first generate the .unfold theorem (a bit easier, as there are no hypotheses to consider) and then use it to prove the .eq_n theorems that don't hold by rfl anyway.
This (stashed, partial) work explores that direction and shows that it does indeed fix #5667.
But of course plenty of other programs break. This brecOn business is not so easy!
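To make the intended order concrete at the user level, here is a small sketch. The function `fib` is a hypothetical example, not taken from this PR. `fib.eq_def` is the hypothesis-free unfolding theorem, and an equation lemma such as the `n + 2` case can be recovered from it by rewriting once and letting `rfl` reduce the resulting `match`. The real equation-lemma generator works at the meta level, so this only illustrates the shape of the derivation.

```lean
-- Hypothetical example function, defined by structural recursion on `Nat`.
def fib : Nat → Nat
  | 0 => 0
  | 1 => 1
  | n + 2 => fib n + fib (n + 1)

-- The unfolding theorem: `fib x = match x with ...`, with no hypotheses.
#check @fib.eq_def

-- An equation lemma derived from it: rewrite `fib (n + 2)` with the
-- unfolding theorem, then close the `match (n + 2) with ...` goal by reduction.
example (n : Nat) : fib (n + 2) = fib n + fib (n + 1) := by
  rw [fib.eq_def] <;> rfl
```

The `<;> rfl` form is used so that the proof also goes through if the `rfl` that `rw` tries on its own already closes the goal.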
Mathlib CI status:
- 💥 Mathlib branch lean-pr-testing-5669 build failed against this PR. (2024-10-10 17:18:37)
- 💥 Mathlib branch lean-pr-testing-5669 build failed against this PR. (2024-10-10 20:58:55)
- 💥 Mathlib branch lean-pr-testing-5669 build failed against this PR. (2025-02-06 16:00:24)
I gave this another shot, storing some of my observations here.
I figured it might be more robust to prove the equational theorems if we had an unfolding lemma for T.brecOn, like we do for WellFounded.fix. Here is a possible construction of such a lemma for Nat:
def Nat.below.mk (motive : Nat → Sort u) (f : (n : Nat) → motive n) (n : Nat) :
Nat.below (motive := motive) n :=
Nat.rec (motive := fun n => Nat.below (motive := motive) n)
PUnit.unit (fun n ih => PProd.mk (f n) ih) n
@[simp] theorem Nat.below.mk_eq1 :
Nat.below.mk motive f 0 = PUnit.unit := rfl
@[simp] theorem Nat.below.mk_eq2 :
Nat.below.mk motive f (n+1) = PProd.mk (f n) (Nat.below.mk motive f n) := rfl
protected def Nat.brecOn'.{u} {motive : Nat → Sort u}
(t : Nat) (F_1 : (t : Nat) → Nat.below (motive := motive) t → motive t) :
motive t ×' Nat.below (motive := motive) t :=
Nat.rec ⟨F_1 Nat.zero PUnit.unit, PUnit.unit⟩ (fun n n_ih => ⟨F_1 n.succ n_ih, n_ih⟩) t
-- @[simp] theorem Nat.brecOn'_eq1 :
-- Nat.brecOn' (motive := motive) 0 f = ⟨f 0 PUnit.unit, PUnit.unit⟩ := rfl
-- @[simp] theorem Nat.brecOn'_eq2 :
-- Nat.brecOn' (motive := motive) (n+1) f =
-- ⟨ f (n+1) (Nat.brecOn' (motive := motive) n f),
-- (Nat.brecOn' (motive := motive) n f) ⟩ := rfl
theorem Nat.brecOn'_eq (motive : Nat → Sort u) (f) (n : Nat) :
Nat.brecOn' (motive := motive) n f =
let f' := Nat.below.mk motive (fun n => (Nat.brecOn' (motive := motive) n f).1) n
⟨f n f', f'⟩ :=
Nat.rec rfl (fun n n_ih =>
congrArg (fun x => PProd.mk (f n.succ x) x)
(n_ih.trans (congrArg (fun x => PProd.mk x.1 _) n_ih.symm)
)) n
theorem Nat.brecOn_eq (motive : Nat → Sort u) (f) (n : Nat) :
Nat.brecOn (motive := motive) n f =
f n (Nat.below.mk motive (fun n => Nat.brecOn (motive := motive) n f) n) :=
congrArg (fun x => x.1) (Nat.brecOn'_eq motive f n)
The @[simp] lemmas are just to understand the proofs better when dsimp’ing.
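For comparison, this is the shape of the WellFounded.fix unfolding lemma referred to above, restated from Lean core's `WellFounded.fix_eq` (a sketch; the exact statement may differ slightly between Lean versions). Nat.brecOn_eq has the same form, except that Nat.below.mk packages the recursive calls where fix_eq passes a function over the smaller arguments.

```lean
universe u v

-- One step of unfolding `WellFounded.fix`: the functional `F` is applied to `x`
-- and to a closure that performs the recursive calls on smaller arguments.
example {α : Sort u} {C : α → Sort v} {r : α → α → Prop} (hwf : WellFounded r)
    (F : ∀ x, (∀ y, r y x → C y) → C x) (x : α) :
    WellFounded.fix hwf F x = F x (fun y _ => WellFounded.fix hwf F y) :=
  WellFounded.fix_eq hwf F x
```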
First big question: can we write a metaprogram that generates all these definitions and proofs, even for much more complex data types? (Nat.brecOn'_eq in particular looks hard; maybe there is an easier formulation.)
With that, it seems we can prove unfolding lemmas a bit more easily:
example {f : Nat → Option Nat} {t} :
-- only needed to instantiate, because unification does not do that
let F := fun t f_1 =>
match f t with
| some u => u
| none =>
(match (motive := (t : Nat) → Nat.below (motive := (fun t => Nat)) t → Nat) t with
| Nat.zero => fun x => Nat.zero
| Nat.succ t' => fun x => x.1)
f_1
replace f t =
match f t with
| some u => u
| none =>
match t with
| Nat.zero => Nat.zero
| Nat.succ t' => replace f t' := by
intro F
delta replace
apply (Nat.brecOn_eq (motive := (fun t => Nat)) F t).trans
· unfold F
split
· rfl
· split
· rfl
· rfl
Anyway, going that route is clearly not simple, so I'm shelving that idea here.
I just had another stab at this, wondering whether, after #10415, this approach (proving the unfolding lemma directly, using the existing machinery for proving the equational theorems) works better now. But it again fails, already on code like
def g (i j : Nat) : Nat :=
if i < 5 then 0 else
match j with
| Nat.zero => 1
| Nat.succ j => g i j
The problem is that in go3 the LHS has the .brecOn unfolded, so the bodies of the recursive calls are exposed. Steps like simpMatch? or simpTargetStar then change these bodies, and the final URefl fails. This is the goal we have to deal with here:
case h_2
i : Nat
h✝ : ¬i < 5
j_1 j_2 : Nat
⊢ (if i < 5 then 0
else
(match (motive := (j : Nat) → Nat.below j → Nat) j_2.succ with
| Nat.zero => fun x => 1
| j.succ => fun x => x.1)
(Nat.rec
⟨(fun j f =>
if i < 5 then 0
else
(match (motive := (j : Nat) → Nat.below j → Nat) j with
| Nat.zero => fun x => 1
| j.succ => fun x => x.1)
f)
Nat.zero PUnit.unit,
PUnit.unit⟩
(fun n n_ih =>
⟨(fun j f =>
if i < 5 then 0
else
(match (motive := (j : Nat) → Nat.below j → Nat) j with
| Nat.zero => fun x => 1
| j.succ => fun x => x.1)
f)
n.succ n_ih,
n_ih⟩)
j_2)) =
g i j_2
I wonder if I can have- or let-bind the recursive .rec call or its arguments after unfolding to avoid this, as a cheap approximation of the plan in https://github.com/leanprover/lean4/pull/5669#issuecomment-2639904558.
Ah, there is a much easier way that doesn’t require the .below.mk construction:
protected def Nat.brecOn'.{u} {motive : Nat → Sort u}
(t : Nat) (F_1 : (t : Nat) → Nat.below (motive := motive) t → motive t) :
motive t ×' Nat.below (motive := motive) t :=
Nat.rec ⟨F_1 Nat.zero PUnit.unit, PUnit.unit⟩ (fun n n_ih => ⟨F_1 n.succ n_ih, n_ih⟩) t
theorem Nat.brecOn_eq_brecOn' (motive : Nat → Sort u) (f) (n : Nat) :
Nat.brecOn (motive := motive) n f = (Nat.brecOn' (motive := motive) n f).1 := by rfl
theorem Nat.brecOn_eq (motive : Nat → Sort u) (f) (n : Nat) :
Nat.brecOn (motive := motive) n f =
f n (Nat.brecOn' (motive := motive) n f).2 := by cases n <;> rfl
I should turn this into an RFC. Or just do it.
Hmm, although that still exposes the f in Nat.brecOn' (motive := motive) n f, which we don't want to rewrite in. So maybe I need to let-bind it, so that simp does not touch it but URefl can still unfold it.
Here's another idea, inspired by how inductive predicate brecOn works now:
def Nat.brec.below.{u} {motive : Nat → Sort u} (F_1 : (t : Nat) → @Nat.below motive t → motive t) : ∀ t : Nat, @Nat.below motive t :=
@Nat.rec (@Nat.below motive) ⟨⟩ (fun _ ih => ⟨F_1 _ ih, ih⟩)
def Nat.brec.{u} {motive : Nat → Sort u} (F_1 : (t : Nat) → @Nat.below motive t → motive t) : ∀ t : Nat, motive t :=
fun t => F_1 t (@Nat.brec.below motive F_1 t)
def Nat.add2 (x y : Nat) : Nat :=
@Nat.brec (fun _ => Nat → Nat)
(fun x f x_2 =>
Nat.add.match_1 (fun _ x => @Nat.below (fun _ => Nat → Nat) x → Nat) x_2 x (fun a _ => a)
(fun a _ x => (x.1 a).succ) f)
y x
theorem Nat.add2.my_eq_def (x y : Nat) : x.add2 y = Nat.add.match_1 (fun _ _ => Nat) x y (fun a => a) fun a b => (a.add2 b).succ := by
delta Nat.add2
conv => lhs; delta brec; dsimp only
split
· rfl
· rfl
One nice side effect is that stuff like this works with that idea:
def testMe (n : Nat) (b : Bool) : Nat :=
if b then
3
else
match n with
| 0 => 27
| k + 1 => testMe k !b
set_option smartUnfolding false
example : testMe n true = 3 := by
fail_if_success rfl -- currently fails, but works with `F_1` on the outside
cases n <;> rfl
What's the idea here in prose? I'm too tired to see it from the code right now
And do the desired definitional equalities still hold, e.g.
theorem Nat.add2_eq1 (a b : Nat) : Nat.add2 (a.succ) b = (a.add2 b).succ := rfl
The idea is to create an application of F_1 at the outer level instead of merely burying it inside the recursor application. So it is basically similar to defining brecOn directly as f _ (brecOn' ...).
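A minimal sketch of that idea, assuming the Nat.below encoding used elsewhere in this thread (where Nat.below (n+1) unfolds to motive n ×' Nat.below n, as in recent Lean versions). The names Nat.brecOuter and Nat.brecOuter.below are hypothetical stand-ins for Nat.brec and Nat.brec.below above; both are marked noncomputable so the sketch does not rely on compiler support for Nat.rec. Because F is applied outermost, the one-step unfolding holds by rfl for an arbitrary t, whereas the brecOn'-based Nat.brecOn_eq needed `cases n <;> rfl`.

```lean
universe u

-- Hypothetical name: builds the whole `below` table by plain recursion,
-- applying `F` at every step (mirrors `Nat.brec.below` from the thread).
noncomputable def Nat.brecOuter.below {motive : Nat → Sort u}
    (F : (t : Nat) → Nat.below (motive := motive) t → motive t) :
    (t : Nat) → Nat.below (motive := motive) t :=
  @Nat.rec (@Nat.below motive) ⟨⟩ (fun _ ih => ⟨F _ ih, ih⟩)

-- Hypothetical name: `F` is applied at the outermost level instead of
-- being buried inside the `Nat.rec` application.
noncomputable def Nat.brecOuter {motive : Nat → Sort u}
    (F : (t : Nat) → Nat.below (motive := motive) t → motive t) (t : Nat) : motive t :=
  F t (Nat.brecOuter.below F t)

-- Consequently, the unfolding equation holds by `rfl` for any `t`, with no case split.
theorem Nat.brecOuter_eq {motive : Nat → Sort u}
    (F : (t : Nat) → Nat.below (motive := motive) t → motive t) (t : Nat) :
    Nat.brecOuter F t = F t (Nat.brecOuter.below F t) :=
  rfl
```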
And do the desired definitional equalities still hold, e.g.
theorem Nat.add2_eq1 (a b : Nat) : Nat.add2 (a.succ) b = (a.add2 b).succ := rfl
Yes (not this one, but the following):
theorem Nat.add2_eq1 (a : Nat) : Nat.add2 a .zero = a := rfl
theorem Nat.add2_eq2 (a b : Nat) : Nat.add2 a b.succ = (a.add2 b).succ := rfl
Both work. However, Nat.add2 seems to reduce quite a bit more slowly than Nat.add here, so maybe this is not a good idea...
Ah, right, wrong argument…
I’m surprised it's slower. Do you have an idea why? In the kernel or the elaborator? Are you comparing with “the” Nat.add (with reduce_nat support), or with a copy that is not special-cased anywhere?
I looked into it a bit more, and the results are quite surprising:
Test with results
import Lean
inductive MyNat where
| zero
| succ (n : MyNat)
open Lean Elab Meta Term Tactic
elab "share% " e:term : term => do
let e ← elabTerm e ‹_›
let env ← getEnv
let e := e.replace fun e => do
let .const fn _ := e.getAppFn | none
let info ← env.getProjectionFnInfo? fn
let args := e.getAppArgs
if h : info.numParams < args.size then
let some (.ctorInfo c) := env.find? info.ctorName | none
return mkAppN (.proj c.induct info.i args[info.numParams]) (args.extract (info.numParams + 1))
none
return ShareCommon.shareCommon' e
@[noinline, nospecialize]
def measureHeartbeats (x : MetaM α) : MetaM (α × Nat) := do
let t1 ← IO.getNumHeartbeats
let res ← x
let t2 ← IO.getNumHeartbeats
return (res, t2 - t1)
elab "test" : tactic => do
let goal ← getMainGoal
let mkApp3 (.const ``Eq us) α lhs rhs ← goal.getType | throwError "invalid goal"
modifyThe Lean.Meta.State fun state => { state with cache := {} }
let (b, t1) ← measureHeartbeats (isDefEq lhs rhs)
let (res, t2) ← measureHeartbeats (return Kernel.isDefEq (← getEnv) (← getLCtx) lhs rhs)
let .ok b' := res | throwError "kernel error"
unless b && b' do
throwError "not defeq"
Lean.logInfo m!"{t1} and {t2}"
goal.assign (mkApp2 (.const ``Eq.refl us) α lhs)
set_option linter.unusedVariables false
noncomputable abbrev MyNat.brec.below.{u} : ∀ {motive : MyNat → Sort u} (F_1 : (t : MyNat) → @MyNat.below motive t → motive t),
∀ t : MyNat, @MyNat.below motive t :=
@(share% fun {motive} F_1 => @MyNat.rec (@MyNat.below motive) ⟨⟩ (fun _ ih => ⟨F_1 _ ih, ih⟩))
noncomputable abbrev MyNat.brec.{u} : ∀ {motive : MyNat → Sort u} (F_1 : (t : MyNat) → @MyNat.below motive t → motive t), ∀ t : MyNat, motive t :=
@(share% fun {motive} F_1 => fun t => F_1 t (@MyNat.brec.below motive F_1 t))
noncomputable def MyNat.brec'.{u} : ∀ {motive : MyNat → Sort u} (F_1 : (t : MyNat) → @MyNat.below motive t → motive t), ∀ t : MyNat, motive t :=
@(share% fun {motive} F_1 => fun t => F_1 t (@MyNat.rec (@MyNat.below motive) ⟨⟩ (fun _ ih => ⟨F_1 _ ih, ih⟩) t))
-- basic variant with built-in support
def MyNat.add (x y : MyNat) : MyNat :=
match x, y with
| x, .zero => x
| x, .succ y => (x.add y).succ
-- sanity check, equivalent to `MyNat.add`
noncomputable def MyNat.add2 (x y : MyNat) : MyNat :=
share% @MyNat.brecOn (fun _ => MyNat → MyNat) y
(fun y f x =>
MyNat.add.match_1 (fun _ y => @MyNat.below (fun _ => MyNat → MyNat) y → MyNat) x y (fun a _ => a)
(fun a _ x => (x.1 a).succ) f)
x
-- using `brec`
noncomputable def MyNat.add3 (x y : MyNat) : MyNat :=
share% @MyNat.brec (fun _ => MyNat → MyNat)
(fun x f x_2 =>
MyNat.add.match_1 (fun _ x => @MyNat.below (fun _ => MyNat → MyNat) x → MyNat) x_2 x (fun a _ => a)
(fun a _ x => (x.1 a).succ) f)
y x
-- using `brec'`
noncomputable def MyNat.add4 (x y : MyNat) : MyNat :=
share% @MyNat.brec' (fun _ => MyNat → MyNat)
(fun x f x_2 =>
MyNat.add.match_1 (fun _ x => @MyNat.below (fun _ => MyNat → MyNat) x → MyNat) x_2 x (fun a _ => a)
(fun a _ x => (x.1 a).succ) f)
y x
set_option smartUnfolding false
/-- info: 1265 and 230 -/
#guard_msgs in example : MyNat.add x .zero = x := by test
/-- info: 1265 and 230 -/
#guard_msgs in example : MyNat.add2 x .zero = x := by test
/-- info: 1444 and 135 -/
#guard_msgs in example : MyNat.add3 x .zero = x := by test
/-- info: 1542 and 192 -/
#guard_msgs in example : MyNat.add4 x .zero = x := by test
/-- info: 2387 and 553 -/
#guard_msgs in example : MyNat.add x (.succ y) = .succ (MyNat.add x y) := by test
/-- info: 2380 and 553 -/
#guard_msgs in example : MyNat.add2 x (.succ y) = .succ (MyNat.add2 x y) := by test
/-- info: 6024 and 763 -/
#guard_msgs in example : MyNat.add3 x (.succ y) = .succ (MyNat.add3 x y) := by test
/-- info: 5014 and 620 -/
#guard_msgs in example : MyNat.add4 x (.succ y) = .succ (MyNat.add4 x y) := by test
/-- info: 2415 and 403 -/
#guard_msgs in example : MyNat.add .zero (.succ .zero) = .succ .zero := by test
/-- info: 2414 and 403 -/
#guard_msgs in example : MyNat.add2 .zero (.succ .zero) = .succ .zero := by test
/-- info: 2992 and 493 -/
#guard_msgs in example : MyNat.add3 .zero (.succ .zero) = .succ .zero := by test
/-- info: 2868 and 359 -/
#guard_msgs in example : MyNat.add4 .zero (.succ .zero) = .succ .zero := by test
/-- info: 3453 and 569 -/
#guard_msgs in example : MyNat.add .zero (.succ (.succ .zero)) = .succ (.succ .zero) := by test
/-- info: 3452 and 569 -/
#guard_msgs in example : MyNat.add2 .zero (.succ (.succ .zero)) = .succ (.succ .zero) := by test
/-- info: 4194 and 655 -/
#guard_msgs in example : MyNat.add3 .zero (.succ (.succ .zero)) = .succ (.succ .zero) := by test
/-- info: 4049 and 519 -/
#guard_msgs in example : MyNat.add4 .zero (.succ (.succ .zero)) = .succ (.succ .zero) := by test
/-- info: 4494 and 735 -/
#guard_msgs in example : MyNat.add .zero (.succ (.succ (.succ .zero))) = .succ (.succ (.succ .zero)) := by test
/-- info: 4495 and 735 -/
#guard_msgs in example : MyNat.add2 .zero (.succ (.succ (.succ .zero))) = .succ (.succ (.succ .zero)) := by test
/-- info: 5379 and 815 -/
#guard_msgs in example : MyNat.add3 .zero (.succ (.succ (.succ .zero))) = .succ (.succ (.succ .zero)) := by test
/-- info: 5235 and 679 -/
#guard_msgs in example : MyNat.add4 .zero (.succ (.succ (.succ .zero))) = .succ (.succ (.succ .zero)) := by test
/-- info: 5544 and 901 -/
#guard_msgs in example : MyNat.add .zero (.succ (.succ (.succ (.succ .zero)))) = .succ (.succ (.succ (.succ .zero))) := by test
/-- info: 5542 and 901 -/
#guard_msgs in example : MyNat.add2 .zero (.succ (.succ (.succ (.succ .zero)))) = .succ (.succ (.succ (.succ .zero))) := by test
/-- info: 6567 and 975 -/
#guard_msgs in example : MyNat.add3 .zero (.succ (.succ (.succ (.succ .zero)))) = .succ (.succ (.succ (.succ .zero))) := by test
/-- info: 6424 and 839 -/
#guard_msgs in example : MyNat.add4 .zero (.succ (.succ (.succ (.succ .zero)))) = .succ (.succ (.succ (.succ .zero))) := by test
/-- info: 9695 and 1565 -/
#guard_msgs in example : MyNat.add .zero (.succ (.succ (.succ (.succ (.succ (.succ (.succ (.succ .zero)))))))) =
.succ (.succ (.succ (.succ (.succ (.succ (.succ (.succ .zero))))))) := by test
/-- info: 9694 and 1565 -/
#guard_msgs in example : MyNat.add2 .zero (.succ (.succ (.succ (.succ (.succ (.succ (.succ (.succ .zero)))))))) =
.succ (.succ (.succ (.succ (.succ (.succ (.succ (.succ .zero))))))) := by test
/-- info: 11345 and 1615 -/
#guard_msgs in example : MyNat.add3 .zero (.succ (.succ (.succ (.succ (.succ (.succ (.succ (.succ .zero)))))))) =
.succ (.succ (.succ (.succ (.succ (.succ (.succ (.succ .zero))))))) := by test
/-- info: 11212 and 1479 -/
#guard_msgs in example : MyNat.add4 .zero (.succ (.succ (.succ (.succ (.succ (.succ (.succ (.succ .zero)))))))) =
.succ (.succ (.succ (.succ (.succ (.succ (.succ (.succ .zero))))))) := by test
It seems that the new brec works well when reducing closed terms (although only the kernel thinks so) and not so well for the generic equations (which, I guess, is to be expected). Also worth noting: using brec directly, as I stated above, is only efficient for .zero, whereas larger closed terms need the inlined version (brec') to be competitive.
Thanks for investigating!
You turn smartUnfolding off (set_option smartUnfolding false), but that makes your elaborator measurements unrepresentative. I guess that normally, with smartUnfolding on, it doesn’t matter which encoding is used, because the elaborator doesn’t see it anyway.
The kernel performance numbers vary somewhat unpredictably depending on whether there are many .succ calls or not…
Maybe it’s worth getting #10606 to work to the point where we can run !bench on lean or even on mathlib, and then try your variant as well.
Closing this in favor of #10606