AlexanderMath

41 comments by AlexanderMath

Looks great! Minor comments: (1) some notes on naming and hints on the linear algebra (likely not too useful) and (2) a potential performance issue with `at[i].set(.)` vs `jax.lax.dynamic_update_slice(.)` (could you...
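For context, a minimal sketch of the two ways to write one row of a matrix inside a jitted function. Both are out-of-place updates that give the same result; the function names and shapes are illustrative only. `.at[i].set(...)` is typically lowered to an XLA scatter, while `jax.lax.dynamic_update_slice` maps to XLA's DynamicUpdateSlice, so which one is faster on a given backend (e.g. the IPU) is something to profile rather than assume.

```python
import jax
import jax.numpy as jnp

@jax.jit
def set_row_at(M, i, row):
    # Indexed update: functional equivalent of M[i] = row
    return M.at[i].set(row)

@jax.jit
def set_row_dus(M, i, row):
    # Same update via dynamic_update_slice: the update must have the same rank
    # as M, so the row is reshaped to (1, n) and the start index is (i, 0).
    return jax.lax.dynamic_update_slice(M, row[None, :], (i, 0))

M = jnp.zeros((4, 4))
row = jnp.arange(4.0)
i = 2
A = set_row_at(M, i, row)
B = set_row_dus(M, i, row)
```

One behavioral difference worth knowing: `dynamic_update_slice` clamps out-of-range start indices so the update always fits inside `M`.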

The initial implementation takes 76M cycles; the target is around 1M cycles. Code is on this branch: https://github.com/graphcore-research/pyscf-ipu/tree/hessenberg Note: the algorithm is almost identical to `tessellate_ipu.linalg.qr`; it just multiplies with another H from the...
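As a plain NumPy reference (not the IPU code on the branch), the textbook Householder reduction to Hessenberg form applies each reflector on both sides of the matrix, whereas Householder QR applies it only from the left. A minimal sketch; `householder_vector` and `hessenberg_householder` are hypothetical helper names:

```python
import numpy as np

def householder_vector(x):
    """Unit v such that (I - 2 v v^T) x is a multiple of e1."""
    v = np.array(x, dtype=float)
    alpha = np.linalg.norm(v)
    sign = 1.0 if v[0] >= 0 else -1.0
    v[0] += sign * alpha  # sign choice avoids cancellation
    norm_v = np.linalg.norm(v)
    if norm_v == 0.0:
        return v  # x is zero: the zero vector makes the update a no-op
    return v / norm_v

def hessenberg_householder(A):
    """Return (H, Q) with A = Q H Q^T and H upper Hessenberg.

    Each reflector P_k is applied from the left AND the right (a similarity
    transform), so H keeps the eigenvalues of A. For symmetric A, H is
    tridiagonal.
    """
    H = np.array(A, dtype=float)
    n = H.shape[0]
    Q = np.eye(n)
    for k in range(n - 2):
        v = householder_vector(H[k + 1:, k])
        # Left multiply by P_k: only rows k+1.. change
        H[k + 1:, :] -= 2.0 * np.outer(v, v @ H[k + 1:, :])
        # Right multiply by P_k: only columns k+1.. change
        H[:, k + 1:] -= 2.0 * np.outer(H[:, k + 1:] @ v, v)
        # Accumulate Q = P_0 P_1 ... P_{n-3}
        Q[:, k + 1:] -= 2.0 * np.outer(Q[:, k + 1:] @ v, v)
    return H, Q
```

Dropping the right multiplication (and applying reflectors down the main diagonal instead of the subdiagonal) turns this loop back into Householder QR.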

@balancap Do you have any pointers on the hard parts? @paolot-gc is looking at improving the above profile. I'll take a look at https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.eigh_tridiagonal.html and https://github.com/tensorflow/tensorflow/blob/v2.13.0/tensorflow/python/ops/linalg/linalg_impl.py#L1232-L1588 over the weekend.
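For reference, `jax.scipy.linalg.eigh_tridiagonal` takes the main diagonal `d` and the off-diagonal `e`. As far as we know, JAX currently implements only `eigvals_only=True` (requesting eigenvectors raises `NotImplementedError`). A minimal check against a dense NumPy solve, on an arbitrary test matrix:

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import eigh_tridiagonal

n = 6
d = jnp.arange(1.0, n + 1.0)   # main diagonal
e = 0.5 * jnp.ones(n - 1)      # off-diagonal (symmetric, so used above and below)
w = eigh_tridiagonal(d, e, eigvals_only=True)

# Dense reference: build the same tridiagonal matrix and solve with NumPy
T = np.diag(np.asarray(d)) + np.diag(np.asarray(e), 1) + np.diag(np.asarray(e), -1)
w_ref = np.linalg.eigvalsh(T)
```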

@balancap Profile of a single iteration:

![image](https://github.com/graphcore-research/pyscf-ipu/assets/8614529/5ed451cf-836a-452d-abc4-7509d04251e2)

@balancap @paolot-gc **Context:** For `M.shape=(1024,1024)` with `M.T=M` we want `eigh(M)`. We use the classic Hessenberg reduction `hessenberg(M)=tri_diagonal` (for symmetric `M` the Hessenberg form is tridiagonal) to turn the problem into `eigh(tri_diagonal)`. **Problem:** The literature claims that computing `eigvals(tri_diagonal)` is easy and computing `eigvecs(tri_diagonal)` is hard...
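The pipeline can be sketched on the CPU with SciPy (an illustration only; the IPU code is separate), using a small random symmetric matrix in place of the 1024×1024 one:

```python
import numpy as np
from scipy.linalg import hessenberg, eigh_tridiagonal

rng = np.random.default_rng(0)
n = 8
X = rng.standard_normal((n, n))
M = (X + X.T) / 2  # symmetric test matrix

# Step 1: Hessenberg reduction M = Q T Q^T; T is tridiagonal since M is symmetric
T, Q = hessenberg(M, calc_q=True)

# Step 2: eigendecomposition of the tridiagonal matrix
d = np.diag(T)      # main diagonal
e = np.diag(T, 1)   # off-diagonal
w, V_T = eigh_tridiagonal(d, e)

# Step 3: map eigenvectors back to the original basis
V = Q @ V_T
```

Step 2 is where the eigenvector cost lives; Steps 1 and 3 are dense matrix work.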

> You can't make the eigenvalue gap arbitrarily large if lambda_i = lambda_{i+1}, so in practice you can't make it arbitrarily large if they are close to equal.

Agree.

>...
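Why a small gap is a problem can be seen on a 2×2 example (a standard perturbation-theory illustration, not taken from the thread). The same tiny off-diagonal perturbation `delta` barely moves an eigenvector when the gap is large, but rotates it by a large angle when the gap is comparable to `delta`. `eigvec_rotation` is a hypothetical helper:

```python
import numpy as np

def eigvec_rotation(gap, delta):
    """Angle between the lowest eigenvector of diag(1, 1+gap), which is e1,
    and the lowest eigenvector of [[1, delta], [delta, 1+gap]]."""
    A = np.array([[1.0, delta], [delta, 1.0 + gap]])
    _, V = np.linalg.eigh(A)            # eigenvalues in ascending order
    v = V[:, 0]                         # eigenvector of the smallest eigenvalue
    return np.arccos(min(1.0, abs(v[0])))  # angle to e1, independent of sign

delta = 1e-8
large_gap = eigvec_rotation(gap=1.0, delta=delta)   # ~1e-8 rad
small_gap = eigvec_rotation(gap=1e-8, delta=delta)  # ~0.55 rad
```

For this matrix the rotation angle θ satisfies tan 2θ = 2δ/gap, so the eigenvector error scales roughly like perturbation/gap.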

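@awf The following code demonstrates the inequality.

```python
import pyscf
import numpy as np
mol = pyscf.gto.Mole(atom=[["C", (0,0,0)], ["C", (1,2,3)]])
mol.build()
# Two-electron repulsion integrals in the spherical AO basis, shape (N, N, N, N)
ERI = mol.intor("int2e_sph")
N = mol.nao_nr()  # number of atomic orbitals
for a in...
```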

(Postpone moving nanoDFT.py until Sep 22 so as not to break links in the NeurIPS rebuttal.)

Hypothesis: `nanoDFT(mol, opt)` needs access to `mol` while JAX traces `nanoDFT`. JAX allows this with `static_argnums`, which hashes the static arguments to decide whether a cached compilation can be reused or a recompilation is needed. Because `mol`/`opt` don't support hashing, we instead...
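One common workaround for passing an unhashable object as a static argument (not necessarily what was done here) is to wrap it in a small class that defines `__hash__` and `__eq__`. A minimal sketch; `HashableWrapper` and the dict config are hypothetical stand-ins for `mol`/`opt`:

```python
import jax
import jax.numpy as jnp

class HashableWrapper:
    """Hypothetical wrapper: hashes an unhashable object by identity so it
    can be passed to jax.jit as a static argument."""
    def __init__(self, obj):
        self.obj = obj

    def __hash__(self):
        return id(self.obj)

    def __eq__(self, other):
        return isinstance(other, HashableWrapper) and self.obj is other.obj

trace_count = 0

def f(x, wrapped_cfg):
    global trace_count
    trace_count += 1             # Python side effect: runs only while tracing
    cfg = wrapped_cfg.obj        # plain Python object, usable at trace time
    return x * cfg["scale"]

f_jit = jax.jit(f, static_argnums=1)

cfg = {"scale": 3.0}             # dicts are unhashable, so wrap before passing
w = HashableWrapper(cfg)
y1 = f_jit(jnp.ones(3), w)       # traces and compiles
y2 = f_jit(jnp.ones(3), w)       # cache hit: no retrace
```

The trade-off of identity-based hashing: mutating `cfg` in place will not trigger a retrace, so the wrapped object should be treated as immutable.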

Neat! Let's revise and make it into a PR :) **Q1.** What happens if we remove all the jax tree stuff? (I might just be misunderstanding, but don't see why...