
A JAX research toolkit for building, editing, and visualizing neural networks.

Results: 19 penzai issues

Autoplay sliders would be really fun.

feature-request

Sometimes `nmap`'ed computations don't all fit in memory at once and there are not enough devices to shard the computation over (this limitation is particularly salient when using penzai...

feature-request
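As background for this request, one common workaround in plain JAX (not penzai's `nmap`, just a sketch of the underlying trade-off) is to map sequentially with `jax.lax.map` instead of vectorizing with `jax.vmap`, which keeps peak memory proportional to a single slice:

```python
import jax
import jax.numpy as jnp

def heavy_fn(x):
    # Stand-in for a per-example computation that is expensive in memory.
    return jnp.sum(x * x)

xs = jnp.arange(12.0).reshape(6, 2)

# vmap materializes all mapped work at once; lax.map runs it one slice
# at a time, trading speed for a much smaller peak memory footprint.
vectorized = jax.vmap(heavy_fn)(xs)
sequential = jax.lax.map(heavy_fn, xs)

assert jnp.allclose(vectorized, sequential)
```

The results are identical; only the execution schedule (and thus the memory high-water mark) differs.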

It'd be great if penzai would support model quantization out of the box. I know this is a lot of work to implement, but right now the lack of quantization...

feature-request
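For context on what "quantization out of the box" would involve, here is a minimal NumPy sketch of symmetric per-tensor int8 weight quantization (this is not penzai API, just an illustration of the core idea):

```python
import numpy as np

def quantize_int8(w):
    # Symmetric per-tensor int8 quantization: scale so the largest
    # magnitude maps to 127, then round to the nearest integer.
    scale = np.max(np.abs(w)) / 127.0
    q = np.round(w / scale).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    # Recover an approximate float tensor from the int8 codes.
    return q.astype(np.float32) * scale

w = np.random.default_rng(0).normal(size=(4, 4)).astype(np.float32)
q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)

# Reconstruction error is bounded by half a quantization step.
assert np.max(np.abs(w - w_hat)) <= scale / 2 + 1e-6
```

A real implementation would also need per-channel scales, quantized matmul kernels, and integration with the layer tree, which is why this is a substantial feature.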

## Changes

This PR adds missing attributes to the lists of handled/ignored configuration attributes in the model conversion functions for:

- Llama models (`llama_from_huggingface_model`)
- Mistral models (`mistral_from_huggingface_model`)
- GPT-NeoX...

I'm confused about the intended behavior of penzai.nn versus pz and pz.nn. Here's an example of the confusing behavior: when you import nn, you don't get everything in...

The following code outputs "Call 1 succeeded" and then hangs indefinitely:

```python
import dataclasses
import jax
import jax.numpy as jnp
from penzai import pz

@pz.pytree_dataclass
class Indexer(pz.Struct):
    index: int =...
```

When I run

```py
hf_model = transformers.LlamaForCausalLM.from_pretrained("Unbabel/TowerInstruct-7B-v0.2")
pz_model = penzai.models.transformer.variants.llama.llama_from_huggingface_model(hf_model)
```

the second line fails with

```sh
ValueError: Conversion of a LlamaForCausalLM does not support these configuration attributes: {'use_cache': False,...
```

## Bug Description

When attempting to convert a HuggingFace model to a Penzai model using `[llama/mistral/gpt_neox]_from_huggingface_model`, the conversion fails with a ValueError when the model configuration contains certain attributes that...
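The failure mode in these reports can be sketched in pure Python. The following is a hypothetical reconstruction of the handled/ignored-attribute check (all names here are illustrative, not penzai's actual internals): any config attribute outside both lists causes the conversion to refuse rather than silently proceed.

```python
# Hypothetical sketch of the check the *_from_huggingface_model
# converters appear to perform; attribute names are illustrative.
HANDLED_ATTRIBUTES = {"hidden_size", "num_attention_heads", "vocab_size"}
IGNORED_ATTRIBUTES = {"torch_dtype", "architectures"}

def check_config(config: dict) -> dict:
    # Anything the converter neither handles nor knowingly ignores
    # is treated as unsupported.
    unsupported = {
        k: v for k, v in config.items()
        if k not in HANDLED_ATTRIBUTES and k not in IGNORED_ATTRIBUTES
    }
    if unsupported:
        raise ValueError(
            "Conversion does not support these configuration attributes: "
            f"{unsupported}"
        )
    return {k: config[k] for k in HANDLED_ATTRIBUTES if k in config}

# A recognized config passes through untouched.
ok = check_config({"hidden_size": 4096, "torch_dtype": "float32"})

# An attribute outside both lists (e.g. use_cache) triggers the error.
try:
    check_config({"hidden_size": 4096, "use_cache": False})
    raised = False
except ValueError:
    raised = True
```

Under this reading, the PR above simply extends the two sets so that benign attributes like `use_cache` land in the ignored list instead of the error path.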

I am trying to create a simple linear layer as follows,

```
from penzai import pz
import jax

embed_axis = "embed_axis"
head_axis = "head_axis"
num_heads = 4
embed_size = 10...
```
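For reference, the contraction such a layer performs can be sketched in plain NumPy (penzai would express the same thing with named axes instead of einsum subscripts; the axis names and sizes here are taken from the snippet above, with a hypothetical projection size added for illustration):

```python
import numpy as np

embed_size, num_heads, proj_size = 10, 4, 3  # proj_size is illustrative

rng = np.random.default_rng(0)
x = rng.normal(size=(embed_size,))  # activation with an "embed" axis
w = rng.normal(size=(embed_size, num_heads, proj_size))  # embed -> (heads, proj)

# A linear layer over named axes contracts the shared "embed" axis and
# introduces the "heads" and "projection" output axes.
y = np.einsum("e,ehp->hp", x, w)
assert y.shape == (num_heads, proj_size)
```

Named-axis libraries make this same contraction explicit by tagging axes with names rather than relying on positional einsum subscripts.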