AlexanderMath
Looks great! Minor comments: (1) a few suggestions w.r.t. naming and type hinting on the linear algebra code (likely not too useful) and (2) a potential performance issue with `at[i].set(.)` vs `jax.lax.dynamic_update_slice(.)` (could you...
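For concreteness, a tiny sketch of the two update styles being compared (the helper names are mine, not from the branch); whether `dynamic_update_slice` actually lowers to something cheaper on IPU is exactly the open question here:

```python
import jax
import jax.numpy as jnp

def update_row_set(M, i, row):
    # Scatter-style update; with a traced index this lowers to a general scatter.
    return M.at[i].set(row)

def update_row_dus(M, i, row):
    # Same result expressed as a dynamic slice update, which can lower to a
    # cheaper dynamic-update-slice op on some backends.
    return jax.lax.dynamic_update_slice(M, row[None, :], (i, 0))

M = jnp.zeros((8, 8))
row = jnp.arange(8.0)
assert jnp.allclose(update_row_set(M, 3, row), update_row_dus(M, 3, row))
```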
Initial implementation: 76M cycles. Aim for ~1M cycles. Code is on this branch: https://github.com/graphcore-research/pyscf-ipu/tree/hessenberg Note: the algorithm is almost identical to tesselate_ipu.linalg.qr; it just multiplies with another H from the...
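For whoever picks this up, a plain NumPy reference sketch of the Householder-based reduction (the function name and structure are mine, not the branch's); the point is the extra right-side multiplication by the same reflector, which is the only change relative to a QR sweep:

```python
import numpy as np

def hessenberg_tridiag(A):
    """Reduce a symmetric matrix to tridiagonal form with Householder reflectors.

    Plain NumPy reference sketch, not the IPU kernel.
    """
    A = np.array(A, dtype=np.float64, copy=True)
    n = A.shape[0]
    for k in range(n - 2):
        x = A[k + 1:, k]
        norm_x = np.linalg.norm(x)
        if norm_x == 0.0:
            continue  # column already reduced
        v = x.copy()
        v[0] += np.copysign(norm_x, x[0] if x[0] != 0 else 1.0)
        v /= np.linalg.norm(v)
        # H = I - 2 v v^T applied from the left (as in QR)...
        A[k + 1:, k:] -= 2.0 * np.outer(v, v @ A[k + 1:, k:])
        # ...and from the right (the "other H" that keeps the eigenvalues intact).
        A[:, k + 1:] -= 2.0 * np.outer(A[:, k + 1:] @ v, v)
    return A

# Quick check: orthogonal similarity preserves the spectrum.
M = np.random.randn(16, 16); M = (M + M.T) / 2
assert np.allclose(np.linalg.eigvalsh(hessenberg_tridiag(M)), np.linalg.eigvalsh(M))
```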
@balancap Do you have any pointers on the hard parts? @paolot-gc is looking at improving the above profile. I'll take a look at https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.eigh_tridiagonal.html and https://github.com/tensorflow/tensorflow/blob/v2.13.0/tensorflow/python/ops/linalg/linalg_impl.py#L1232-L1588 over the weekend.
Profile of a single iteration @balancap 
@balancap @paolot-gc **Context:** For `M.shape=(1024,1024)` with `M.T=M` we want `eigh(M)`. We use the classic `hessenberg(M)=tri_diagonal` reduction to turn the problem into `eigh(tri_diagonal)`. **Problem:** The literature claims `eigvals(tri_diagonal)` are easy and `eigvects(tri_diagonal)` are hard...
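For reference, a CPU-only sketch of the two-step plan using scipy for the reduction (not the IPU code, just the numerical recipe being targeted):

```python
import numpy as np
import scipy.linalg
import jax
jax.config.update("jax_enable_x64", True)  # float64 just for this reference check
import jax.scipy.linalg

# Symmetric M -> tridiagonal T via Hessenberg reduction, then eigenvalues of T.
np.random.seed(0)
M = np.random.randn(1024, 1024)
M = (M + M.T) / 2

T = scipy.linalg.hessenberg(M)          # tridiagonal since M is symmetric
d, e = np.diag(T), np.diag(T, k=1)      # main and first off-diagonal
w = jax.scipy.linalg.eigh_tridiagonal(d, e, eigvals_only=True)

# Note: jax.scipy.linalg.eigh_tridiagonal currently only supports
# eigvals_only=True, consistent with "eigvals easy, eigvects hard".
np.testing.assert_allclose(np.sort(np.asarray(w)), np.linalg.eigvalsh(M), atol=1e-6)
```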
> You can't make the eigenvalue gap arbitrarily large if lambda_i = lambda_{i+1}, so in practice you can't make it arbitrarily large if they are close to equal.

Agree.

>...
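To make the "close to equal" point concrete, a toy example (mine, not from the thread) of why clustered eigenvalues make eigenvectors ill-conditioned even though the eigenvalues themselves barely move:

```python
import numpy as np

eps, delta = 1e-12, 1e-10
A = np.diag([1.0, 1.0 + eps])                    # nearly degenerate pair
E = delta * np.array([[0.0, 1.0], [1.0, 0.0]])   # tiny symmetric perturbation
_, V0 = np.linalg.eigh(A)
_, V1 = np.linalg.eigh(A + E)
# Eigenvalues change by O(delta), but the eigenvectors rotate substantially
# because the perturbation is large relative to the gap eps.
print(np.abs(V0.T @ V1))  # far from the identity
```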
@awf The following code demonstrates the inequality.
```python
import pyscf
import numpy as np

mol = pyscf.gto.Mole(atom=[["C", (0,0,0)], ["C", (1,2,3)]])
mol.build()
ERI = mol.intor("int2e_sph")
N = mol.nao_nr()
for a in...
```
(postpone moving nanoDFT.py until Sep 22 to not break links in NeurIPS rebuttal)
Hypothesis: nanoDFT(mol, opt) needs access to mol while JAX traces nanoDFT. JAX allows this with `static_argnums`, which hashes the static arguments to decide when to recompile. Because mol/opt don't support hashing, we instead...
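To make the hypothesis concrete, a toy sketch (the `Mol` class and `nanodft_step` below are stand-ins, not the real pyscf/nanoDFT code) of why `static_argnums` needs hashable arguments, plus one common alternative; this is not necessarily the fix taken in this PR:

```python
import jax
import jax.numpy as jnp
from functools import partial

class Mol:
    """Stand-in for pyscf's mol/opt: unhashable, so static_argnums would fail."""
    __hash__ = None
    def __init__(self, nao):
        self.nao = nao

def nanodft_step(mol, density):
    # mol is only used for Python-level metadata during tracing.
    return jnp.trace(density) * mol.nao

mol = Mol(nao=4)

# jax.jit(nanodft_step, static_argnums=0)(mol, ...) -> error: non-hashable static argument.
# One common workaround: close over mol so only array arguments are traced.
step = jax.jit(partial(nanodft_step, mol))
print(step(jnp.eye(4)))
```

Registering mol/opt as pytrees (presumably the "jax tree stuff" mentioned below) is another route, where their array contents get traced rather than treated as static.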
Neat! Let's revise and make it into a PR :) **Q1.** What happens if we remove all the jax tree stuff? (I might just be misunderstanding, but I don't see why...