Ross Wightman

Results: 58 issues by Ross Wightman

I've been noticing some strange behaviour running OpenSim in a synchronous vectored environment. Initially, the 8-16 environments step evenly, with roughly equal CPU utilization and step times across the environments...

* Add gradient caching support via GradCache (https://github.com/luyug/GradCache/tree/main/src/grad_cache). In contrast to #65, this impl uses the GradCache implementation as-is, with separate image/text towers via wrappers *...
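As a rough illustration of the gradient-caching idea referenced above, here is a pure-Python toy (hypothetical names, not the GradCache API): embeddings for a large batch are computed chunk-by-chunk with no graph kept, the loss gradient w.r.t. those cached embeddings is computed once over the full batch, and a second chunked pass backprops the cached gradients into the encoder weights, so memory scales with chunk size rather than batch size. A scalar linear "encoder" with a squared-error loss stands in for the real towers.

```python
# Toy GradCache-style two-pass step for a scalar encoder z = w * x with
# loss L = sum_i (z_i - y_i)^2 (hypothetical sketch, not the library API).

def encode(w, xs):
    """Trivial 'encoder': one scalar embedding per example."""
    return [w * x for x in xs]

def grad_cache_step(w, xs, ys, chunk=2):
    # Pass 1: embeddings for the whole batch, computed chunk by chunk
    # (in the real setting, without retaining any autograd graph).
    zs = []
    for i in range(0, len(xs), chunk):
        zs.extend(encode(w, xs[i:i + chunk]))
    # The "gradient cache": dL/dz = 2 * (z - y) for the full batch.
    dzs = [2.0 * (z - y) for z, y in zip(zs, ys)]
    # Pass 2: re-encode each chunk and chain-rule the cached dL/dz into dL/dw.
    dw = 0.0
    for i in range(0, len(xs), chunk):
        for x, dz in zip(xs[i:i + chunk], dzs[i:i + chunk]):
            dw += dz * x  # dz/dw = x for z = w * x
    return dw

xs, ys, w = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0], 1.5
full_dw = sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys))
cached_dw = grad_cache_step(w, xs, ys, chunk=2)
```

The chunked two-pass gradient matches the full-batch gradient exactly, which is the property that lets the contrastive loss see the full batch while activations for only one chunk are live at a time.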

The branch associated with this PR is for development of `timm bits` -- a significant update and reorganization of timm's training scripts and associated modules. It works w/ TPUs (PyTorch...

The goal is to adapt the timm training and validation scripts to work well with PyTorch XLA on TPU/TPU Pods (maybe CPU/GPU), and with PyTorch on GPU. As part of this I will...

enhancement

Along with the updated training/validation components in #458 for TPU support, support use of DeepSpeed/ZeRO:
* https://pytorch.org/docs/master/distributed.optim.html#torch.distributed.optim.ZeroRedundancyOptimizer
* https://github.com/microsoft/DeepSpeed
It would be fairly easy to support w/ the current training code, however...
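For intuition on what `ZeroRedundancyOptimizer` buys, here is a minimal pure-Python sketch of the ZeRO stage-1 idea (a toy, not the PyTorch API): each rank owns optimizer state (here, SGD momentum) for only its contiguous shard of the parameters, updates that shard, and the updated shards together form the new parameter vector (standing in for the broadcast step).

```python
# Toy ZeRO-1 sharding: optimizer state is partitioned across ranks so each
# rank stores only 1/world_size of the momentum buffers (hypothetical code).

def shard(params, rank, world_size):
    """Return the parameter indices owned by `rank`."""
    n = len(params)
    per = (n + world_size - 1) // world_size
    return range(rank * per, min((rank + 1) * per, n))

def zero1_step(params, grads, momentum, world_size, lr=0.1, beta=0.9):
    """One SGD-with-momentum step where each rank only touches its shard."""
    new_params = list(params)
    for rank in range(world_size):
        for i in shard(params, rank, world_size):
            # Only the owning rank stores/updates momentum for index i.
            momentum[rank][i] = beta * momentum[rank].get(i, 0.0) + grads[i]
            new_params[i] = params[i] - lr * momentum[rank][i]
    return new_params

params = [1.0, 2.0, 3.0, 4.0]
grads = [0.5, 0.5, 0.5, 0.5]
momentum = [dict() for _ in range(2)]  # per-rank optimizer state
params = zero1_step(params, grads, momentum, world_size=2)
```

The result is identical to plain SGD with momentum; the saving is purely in per-rank optimizer-state memory, which is what makes it attractive for large models on TPU pods.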

enhancement

I'm working on a training script (https://github.com/rwightman/efficientnet-jax/blob/master/tf_linen_train.py) based on the Flax Linen ImageNet example (https://github.com/google/flax/blob/master/linen_examples/imagenet/imagenet_lib.py). It was working great on a system with 2 x Titan RTX. The same setup on 2...

bug
P0 (urgent)
NVIDIA GPU

Working from a modified ImageNet Linen example, I've added two state attributes for Polyak-averaged EMA values, like so:
```
@flax.struct.dataclass
class TrainState:
    step: int
    optimizer: flax.optim.Optimizer
    model_state: Any
    dynamic_scale:...
```
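For reference, the Polyak/EMA update that such extra state attributes track is just an exponential moving average of the parameters, applied after each optimizer step. A plain-Python stand-in (not Flax code; names are illustrative):

```python
# Polyak/EMA parameter averaging: ema = decay * ema + (1 - decay) * param.
# A small decay is used here so the convergence is easy to see by hand.

def ema_update(ema_params, new_params, decay=0.999):
    """Return the updated running average of the parameters."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, new_params)]

ema = [0.0, 0.0]
for _ in range(3):  # three "training steps" with params fixed at [1.0, 2.0]
    ema = ema_update(ema, [1.0, 2.0], decay=0.5)
# After k steps from zero, ema = (1 - decay**k) * param.
```

In the Flax setting, the same update would be applied pytree-leaf-wise to the model parameters and carried in the `TrainState`.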

Status: pull requests welcome
Priority: P2 - eventual

### Problem you have encountered: While testing my EfficientNet impl on TPU w/ bfloat16, training collapsed 2/3 of the way through my training schedule. The models were training reasonably on...
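One precision-related ingredient in issues like this: bfloat16 keeps float32's 8 exponent bits but only 7 mantissa bits, so small increments near 1.0 simply vanish. A quick sketch of round-to-bfloat16 via bit manipulation (an illustration of the format, not code from the issue):

```python
# Round a float to the nearest bfloat16 value by keeping the top 16 bits of
# its float32 representation, with round-to-nearest-even on the dropped bits.
import struct

def to_bf16(x):
    """Return the nearest bfloat16 value of x, as a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFFFFFF  # rne rounding
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

bumped = to_bf16(1.0 + 1e-3)  # a 1e-3 bump is below the bf16 spacing at 1.0
```

The spacing between adjacent bfloat16 values near 1.0 is 2^-7 ≈ 0.0078, so quantities like small weight updates or running statistics can stop moving entirely, one plausible contributor to late-schedule collapse.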

Re your MobileViT2: these two norms are not equivalent, and it would be misleading to call it LayerNorm2d, since GroupNorm w/ groups=1 normalizes over all of C/H/W rather than over channels only. 'LayerNorm2d' is already...
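The non-equivalence is easy to demonstrate numerically. For a (C, H, W) tensor, GroupNorm with one group computes a single mean/variance over all C·H·W values, while a channels-only "LayerNorm2d" normalizes over C independently at each spatial position. A pure-Python sketch (function names are illustrative, not any library's API):

```python
# Compare GroupNorm(groups=1) against channels-only LayerNorm2d on a
# small (C=2, H=2, W=2) nested-list "tensor" -- the outputs differ.
from statistics import fmean, pvariance

def groupnorm1(x, eps=1e-5):
    """One mean/variance over the whole (C, H, W) sample."""
    flat = [v for ch in x for row in ch for v in row]
    mu = fmean(flat)
    var = pvariance(flat, mu)
    return [[[(v - mu) / (var + eps) ** 0.5 for v in row] for row in ch]
            for ch in x]

def layernorm2d(x, eps=1e-5):
    """Normalize over channels only, separately at each (h, w) position."""
    C, H, W = len(x), len(x[0]), len(x[0][0])
    out = [[[0.0] * W for _ in range(H)] for _ in range(C)]
    for h in range(H):
        for w in range(W):
            col = [x[c][h][w] for c in range(C)]
            mu = fmean(col)
            var = pvariance(col, mu)
            for c in range(C):
                out[c][h][w] = (col[c] - mu) / (var + eps) ** 0.5
    return out

x = [[[1.0, 2.0], [3.0, 4.0]],      # channel 0
     [[10.0, 20.0], [30.0, 40.0]]]  # channel 1
a, b = groupnorm1(x), layernorm2d(x)
```

With per-channel scales this different, the two produce very different outputs: the GroupNorm version mixes spatial statistics into the normalization, the LayerNorm2d version does not.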