penzai
A JAX research toolkit for building, editing, and visualizing neural networks.
Sometimes `nmap`'ed computations don't all fit in memory at once and there are not enough devices to shard the computation over (this limitation is particularly salient when using penzai...
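For context on the memory pressure described here, a minimal plain-JAX sketch (not penzai-specific; the function and shapes are purely illustrative) of the difference between vectorizing a mapped computation and evaluating it sequentially:

```python
import jax
import jax.numpy as jnp

def per_example(x):
    # Some memory-hungry per-example computation (illustrative).
    return jnp.tanh(x @ x.T).sum()

xs = jnp.ones((64, 512, 512))

# Vectorized: all 64 per-example intermediates are materialized at once.
vmapped = jax.vmap(per_example)(xs)

# Sequential: one example at a time, trading speed for peak memory.
scanned = jax.lax.map(per_example, xs)

assert jnp.allclose(vmapped, scanned)
```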
It'd be great if penzai supported model quantization out of the box. I know this is a lot of work to implement, but right now the lack of quantization...
## Changes

This PR adds missing attributes to the lists of handled/ignored configuration attributes in the model conversion functions for:

- Llama models (`llama_from_huggingface_model`)
- Mistral models (`mistral_from_huggingface_model`)
- GPT-NeoX...
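For readers unfamiliar with why missing entries in these lists matter, here is a simplified sketch of the validation pattern suggested by the ValueError quoted in the related issues; the attribute names and helper below are hypothetical, not penzai's actual code:

```python
# Hypothetical illustration of the handled/ignored attribute check; the real
# conversion functions in penzai.models.transformer.variants are more involved.
HANDLED_ATTRIBUTES = {"hidden_size", "num_attention_heads", "num_hidden_layers"}
IGNORED_ATTRIBUTES = {"use_cache", "torch_dtype"}

def check_config_attributes(config_dict: dict) -> None:
    """Rejects configuration attributes that are neither handled nor ignored."""
    unsupported = {
        key: value
        for key, value in config_dict.items()
        if key not in HANDLED_ATTRIBUTES and key not in IGNORED_ATTRIBUTES
    }
    if unsupported:
        raise ValueError(
            "Conversion does not support these configuration attributes: "
            f"{unsupported}"
        )
```

An attribute missing from both sets makes otherwise fine checkpoints fail to convert, which is what this PR addresses by extending the lists.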
I'm confused about what the intended behavior is among penzai.nn, pz, and pz.nn. Here's an example of the confusing behavior: basically, when you import nn, you don't get everything in...
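To make the mismatch concrete, a small sketch (assuming a local penzai install) that compares what the two import paths expose; it only inspects the modules and does not assert which behavior is intended:

```python
from penzai import pz   # the curated `pz` alias namespace
import penzai.nn        # the underlying penzai.nn package

pz_nn_names = set(dir(pz.nn))
raw_nn_names = set(dir(penzai.nn))

# Names reachable one way but not the other are the source of the confusion.
print(sorted(pz_nn_names - raw_nn_names))
print(sorted(raw_nn_names - pz_nn_names))
```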
The following code outputs "Call 1 succeeded" and then hangs indefinitely:

```python
import dataclasses
import jax
import jax.numpy as jnp
from penzai import pz

@pz.pytree_dataclass
class Indexer(pz.Struct):
    index: int =...
```
When I run

```py
hf_model = transformers.LlamaForCausalLM.from_pretrained("Unbabel/TowerInstruct-7B-v0.2")
pz_model = penzai.models.transformer.variants.llama.llama_from_huggingface_model(hf_model)
```

the second line fails with

```sh
ValueError: Conversion of a LlamaForCausalLM does not support these configuration attributes: {'use_cache': False,...
```
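A possible diagnostic step before working around this (a sketch, assuming the converter trips on attributes whose values it does not recognize): compare the checkpoint's config against the defaults of its config class to see which attributes are non-default:

```python
import transformers

config = transformers.AutoConfig.from_pretrained("Unbabel/TowerInstruct-7B-v0.2")

# to_diff_dict() lists only the attributes that differ from the config class
# defaults, which is a good first guess at what the converter rejects.
print(config.to_diff_dict())
```

If the rejected attributes look harmless (e.g. `use_cache`), resetting them on `hf_model.config` before calling `llama_from_huggingface_model` may be enough, though that depends on how the converter's check is implemented.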
## Bug Description

When attempting to convert a HuggingFace model to a Penzai model using `[llama/mistral/gpt_neox]_from_huggingface_model`, the conversion fails with a ValueError when the model configuration contains certain attributes that...
I am trying to create a simple linear layer as follows,

```
from penzai import pz
import jax

embed_axis = "embed_axis"
head_axis = "head_axis"
num_heads = 4
embed_size = 10...
```
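Since the snippet above is cut off, here is a hedged sketch of one way to express a linear map over a named axis with penzai's named-array API (`pz.nx`), rather than a full `pz.nn` layer; the axis names and sizes come from the snippet, everything else is illustrative:

```python
import jax
import jax.numpy as jnp
from penzai import pz

embed_axis = "embed_axis"
head_axis = "head_axis"
num_heads = 4
embed_size = 10

key = jax.random.PRNGKey(0)

# A weight with two named axes, and an input sharing the "embed_axis" name.
w = pz.nx.wrap(jax.random.normal(key, (embed_size, num_heads))).tag(embed_axis, head_axis)
x = pz.nx.wrap(jnp.ones((embed_size,))).tag(embed_axis)

# Named arrays broadcast by axis name, so x * w aligns on "embed_axis".
# To contract over it, untag it back to a positional axis and nmap jnp.sum.
y = pz.nx.nmap(jnp.sum)((x * w).untag(embed_axis))
print(y)  # a named array with only the "head_axis" axis remaining
```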