FYI: Equinox example with variadic generics
This is not so much an issue as an FYI. I was going through the Equinox examples and converted the BERT one to use a variadic-generics approach to typing. I know there's some thinking on this topic at https://github.com/google/jaxtyping/blob/main/docs/faq.md, so I thought it could be interesting for others to see (a first attempt at) what's possible with this approach at the moment. Most of the types are pretty precise (i.e. able to add and remove appropriately-sized dimensions), but I didn't really try to do anything sensible with pytrees. This is maybe just because I'm type-brained, but I found the additional type declarations quite helpful for my understanding. (I can show the typing stubs I'm using too if that would be useful.)
I also think doing this uncovered one discrepancy: the example has Float[Array, "seq_len hidden_size"] for the FeedForward block, but it seems like it should be just Float[Array, "hidden_size"], since the seq_len dimension is only added by the vmap later.
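(To give a flavour of the stubs, here's a rough sketch of the core idea. This is my own hypothetical ndarray stub, not jax's actual typing, and the real stubs need quite a few more overloads.)

# Hypothetical sketch of a shape-parameterised ndarray stub, not jax's real
# typing. Dimensions are int-bound type variables and the final parameter is
# the element type, so annotations read like ndarray[SeqLenT, EmbedSizeT, float].
from typing import Any, Generic, TypeVar, TypeVarTuple, Unpack

Dims = TypeVarTuple("Dims")
DTypeT = TypeVar("DTypeT")

class ndarray(Generic[Unpack[Dims], DTypeT]):
    @property
    def shape(self) -> tuple[int, ...]: ...
    # Elementwise ops preserve the shape parameters.
    def __add__(self, other: Any) -> "ndarray[Unpack[Dims], DTypeT]": ...

And the converted example itself: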
from __future__ import annotations
import functools
from math import ceil
import os
import re
from typing import (
Any,
Generic,
Iterator,
NewType,
Optional,
TypeVar,
TypeVarTuple,
TypedDict,
cast,
)
from datasets import DatasetDict, load_dataset
import einops
import equinox as eqx
import jax
import jax.numpy as jnp
from jax.numpy import ndarray
from jax.random import PRNGKey
import numpy as np
import optax
import tqdm
from transformers import AutoTokenizer
SeqLenT = TypeVar("SeqLenT", bound=int)
EmbedSizeT = TypeVar("EmbedSizeT", bound=int)
class EmbedderBlock(eqx.Module, Generic[EmbedSizeT]):
"""BERT embedder."""
token_embedder: eqx.nn.Embedding[EmbedSizeT]
segment_embedder: eqx.nn.Embedding[EmbedSizeT]
position_embedder: eqx.nn.Embedding[EmbedSizeT]
layernorm: eqx.nn.LayerNorm[EmbedSizeT]
dropout: eqx.nn.Dropout
def __init__(
self,
vocab_size: int,
max_length: int,
type_vocab_size: int,
embedding_size: EmbedSizeT,
hidden_size: EmbedSizeT,
dropout_rate: float,
        key: PRNGKey,
):
token_key, segment_key, position_key = jax.random.split(key, 3)
self.token_embedder = eqx.nn.Embedding(
num_embeddings=vocab_size, embedding_size=embedding_size, key=token_key
)
self.segment_embedder = eqx.nn.Embedding(
num_embeddings=type_vocab_size,
embedding_size=embedding_size,
key=segment_key,
)
self.position_embedder = eqx.nn.Embedding(
num_embeddings=max_length, embedding_size=embedding_size, key=position_key
)
self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
self.dropout = eqx.nn.Dropout(dropout_rate)
def __call__(
self,
token_ids: ndarray[SeqLenT, int],
position_ids: ndarray[SeqLenT, int],
segment_ids: ndarray[SeqLenT, int],
enable_dropout: bool = False,
key: Optional[PRNGKey] = None,
) -> ndarray[SeqLenT, EmbedSizeT, float]:
tokens = self.token_embedder(token_ids)
segments = self.segment_embedder(segment_ids)
positions = self.position_embedder(position_ids)
embedded_inputs: ndarray[SeqLenT, EmbedSizeT, float] = tokens + segments + positions
embedded_inputs = jax.vmap(self.layernorm)(embedded_inputs)
embedded_inputs = self.dropout(embedded_inputs, inference=not enable_dropout, key=key)
return embedded_inputs
class FeedForwardBlock(eqx.Module, Generic[EmbedSizeT]):
"""A single transformer feed forward block."""
IntermediateSize = NewType("IntermediateSize", int)
mlp: eqx.nn.Linear[EmbedSizeT, FeedForwardBlock.IntermediateSize]
output: eqx.nn.Linear[FeedForwardBlock.IntermediateSize, EmbedSizeT]
layernorm: eqx.nn.LayerNorm[EmbedSizeT]
dropout: eqx.nn.Dropout
def __init__(
self,
hidden_size: EmbedSizeT,
intermediate_size: int,
dropout_rate: float,
        key: PRNGKey,
):
mlp_key, output_key = jax.random.split(key)
self.mlp = eqx.nn.Linear(
in_features=hidden_size,
out_features=self.IntermediateSize(intermediate_size),
key=mlp_key,
)
self.output = eqx.nn.Linear(
in_features=self.IntermediateSize(intermediate_size),
out_features=hidden_size,
key=output_key,
)
self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
self.dropout = eqx.nn.Dropout(dropout_rate)
def __call__(
self,
inputs: ndarray[EmbedSizeT, float],
enable_dropout: bool = True,
key: Optional[PRNGKey] = None,
) -> ndarray[EmbedSizeT, float]:
hidden = self.mlp(inputs)
hidden = jax.nn.gelu(hidden)
output = self.output(hidden)
output = self.dropout(output, inference=not enable_dropout, key=key)
# Residual and layer norm.
output += inputs
output = self.layernorm(output)
return output
class AttentionBlock(eqx.Module, Generic[SeqLenT, EmbedSizeT]):
"""A single transformer attention block."""
NumHeads = NewType("NumHeads", int)
attention: eqx.nn.MultiheadAttention[
AttentionBlock.NumHeads,
SeqLenT,
SeqLenT,
EmbedSizeT,
EmbedSizeT,
EmbedSizeT,
EmbedSizeT,
]
layernorm: eqx.nn.LayerNorm[EmbedSizeT]
dropout: eqx.nn.Dropout
num_heads: AttentionBlock.NumHeads = eqx.field(static=True)
def __init__(
self,
hidden_size: EmbedSizeT,
num_heads: int,
dropout_rate: float,
attention_dropout_rate: float,
        key: PRNGKey,
):
self.num_heads = self.NumHeads(num_heads)
self.attention = eqx.nn.MultiheadAttention(
num_heads=self.num_heads,
query_size=hidden_size,
key_size=hidden_size,
value_size=hidden_size,
output_size=hidden_size,
use_query_bias=True,
use_key_bias=True,
use_value_bias=True,
use_output_bias=True,
dropout_p=attention_dropout_rate,
key=key,
)
self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
self.dropout = eqx.nn.Dropout(dropout_rate)
def __call__(
self,
inputs: ndarray[SeqLenT, EmbedSizeT, float],
mask: Optional[ndarray[SeqLenT, int]],
enable_dropout: bool = False,
key: Optional[PRNGKey] = None,
) -> ndarray[SeqLenT, EmbedSizeT, float]:
full_mask = None if mask is None else self._make_self_attention_mask(mask)
attention_key, dropout_key = (None, None) if key is None else jax.random.split(key)
attention_output = self.attention(
query=inputs,
key_=inputs,
value=inputs,
mask=full_mask,
inference=not enable_dropout,
key=attention_key,
)
        result = self.dropout(attention_output, inference=not enable_dropout, key=dropout_key)
result = result + inputs
result = jax.vmap(self.layernorm)(result)
return result
def _make_self_attention_mask(
self, mask: ndarray[SeqLenT, int]
) -> ndarray[AttentionBlock.NumHeads, SeqLenT, SeqLenT, float]:
square_mask = jnp.multiply(jnp.expand_dims(mask, axis=-1), jnp.expand_dims(mask, axis=-2))
xmask = jnp.expand_dims(square_mask, axis=-3)
return jnp.repeat(xmask, repeats=self.num_heads, axis=-3).astype(jnp.float32)
class TransformerLayer(eqx.Module, Generic[SeqLenT, EmbedSizeT]):
"""A single transformer layer."""
attention_block: AttentionBlock[SeqLenT, EmbedSizeT]
ff_block: FeedForwardBlock[EmbedSizeT]
def __init__(
self,
hidden_size: EmbedSizeT,
intermediate_size: int,
num_heads: int,
dropout_rate: float,
attention_dropout_rate: float,
key: PRNGKey,
):
attention_key, ff_key = jax.random.split(key)
self.attention_block = AttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
key=attention_key,
)
self.ff_block = FeedForwardBlock(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
dropout_rate=dropout_rate,
key=ff_key,
)
def __call__(
self,
inputs: ndarray[SeqLenT, EmbedSizeT, float],
mask: Optional[ndarray[SeqLenT, int]] = None,
*,
enable_dropout: bool = False,
key: Optional[PRNGKey] = None,
) -> ndarray[SeqLenT, EmbedSizeT, float]:
attn_key, ff_key = (None, None) if key is None else jax.random.split(key)
attention_output = self.attention_block(inputs, mask, enable_dropout=enable_dropout, key=attn_key)
seq_len = inputs.shape[0]
ff_keys = None if ff_key is None else jax.random.split(ff_key, num=seq_len)
output = jax.vmap(self.ff_block, in_axes=(0, None, 0))(attention_output, enable_dropout, ff_keys)
return output
class EncoderOut(TypedDict, Generic[SeqLenT, EmbedSizeT]):
embeddings: ndarray[SeqLenT, EmbedSizeT, float]
layers: list[ndarray[SeqLenT, EmbedSizeT, float]]
pooled: ndarray[EmbedSizeT, float]
class Encoder(eqx.Module, Generic[SeqLenT, EmbedSizeT]):
"""Full BERT encoder."""
embedder_block: EmbedderBlock[EmbedSizeT]
layers: list[TransformerLayer[SeqLenT, EmbedSizeT]]
pooler: eqx.nn.Linear[EmbedSizeT, EmbedSizeT]
def __init__(
self,
vocab_size: int,
max_length: int,
type_vocab_size: int,
embedding_size: EmbedSizeT,
hidden_size: EmbedSizeT,
intermediate_size: int,
num_layers: int,
num_heads: int,
dropout_rate: float,
attention_dropout_rate: float,
        key: PRNGKey,
):
embedder_key, layer_key, pooler_key = jax.random.split(key, num=3)
self.embedder_block = EmbedderBlock(
vocab_size=vocab_size,
max_length=max_length,
type_vocab_size=type_vocab_size,
embedding_size=embedding_size,
hidden_size=hidden_size,
dropout_rate=dropout_rate,
key=embedder_key,
)
layer_keys = jax.random.split(layer_key, num=num_layers)
self.layers = []
for layer_key in layer_keys:
self.layers.append(
TransformerLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
key=layer_key,
)
)
self.pooler = eqx.nn.Linear(in_features=hidden_size, out_features=hidden_size, key=pooler_key)
def __call__(
self,
token_ids: ndarray[SeqLenT, int],
position_ids: ndarray[SeqLenT, int],
segment_ids: ndarray[SeqLenT, int],
*,
enable_dropout: bool = False,
key: Optional[PRNGKey] = None,
) -> EncoderOut[SeqLenT, EmbedSizeT]:
emb_key, l_key = (None, None) if key is None else jax.random.split(key)
embeddings = self.embedder_block(
token_ids=token_ids,
position_ids=position_ids,
segment_ids=segment_ids,
enable_dropout=enable_dropout,
key=emb_key,
)
# We assume that all 0-values should be masked out.
mask = (token_ids != 0).astype(jnp.int32)
x = embeddings
layer_outputs: list[ndarray[SeqLenT, EmbedSizeT, float]] = []
for layer in self.layers:
cl_key, l_key = (None, None) if l_key is None else jax.random.split(l_key)
x = layer(x, mask, enable_dropout=enable_dropout, key=cl_key)
layer_outputs.append(x)
# BERT pooling.
# The first token in the last layer is the embedding of the "[CLS]" token.
first_token_last_layer = x[..., 0, :]
pooled = self.pooler(first_token_last_layer)
pooled = jnp.tanh(pooled)
return {"embeddings": embeddings, "layers": layer_outputs, "pooled": pooled}
NumClassesT = TypeVar("NumClassesT", bound=int)
class BertClassifier(eqx.Module, Generic[SeqLenT, NumClassesT]):
"""BERT classifier."""
EmbedSize = NewType("EmbedSize", int)
encoder: Encoder[SeqLenT, BertClassifier.EmbedSize]
classifier_head: eqx.nn.Linear[BertClassifier.EmbedSize, NumClassesT]
dropout: eqx.nn.Dropout
def __init__(
self,
config: Config,
num_classes: NumClassesT,
        key: PRNGKey,
):
encoder_key, head_key = jax.random.split(key)
hidden = self.EmbedSize(config["hidden_size"])
self.encoder = Encoder(
vocab_size=config["vocab_size"],
max_length=config["max_position_embeddings"],
type_vocab_size=config["type_vocab_size"],
embedding_size=hidden,
hidden_size=hidden,
intermediate_size=config["intermediate_size"],
num_layers=config["num_hidden_layers"],
num_heads=config["num_attention_heads"],
dropout_rate=config["hidden_dropout_prob"],
attention_dropout_rate=config["attention_probs_dropout_prob"],
key=encoder_key,
)
self.classifier_head = eqx.nn.Linear(in_features=hidden, out_features=num_classes, key=head_key)
self.dropout = eqx.nn.Dropout(config["hidden_dropout_prob"])
def __call__(
self,
inputs: Inputs[(), SeqLenT],
enable_dropout: bool = True,
key: Optional[PRNGKey] = None,
) -> ndarray[NumClassesT, float]:
seq_len = inputs["input_ids"].shape[-1]
position_ids = jnp.arange(seq_len)
e_key, d_key = (None, None) if key is None else jax.random.split(key)
pooled_output = self.encoder(
token_ids=inputs["input_ids"],
segment_ids=inputs["token_type_ids"],
position_ids=position_ids,
enable_dropout=enable_dropout,
key=e_key,
)["pooled"]
pooled_output = self.dropout(pooled_output, inference=not enable_dropout, key=d_key)
return self.classifier_head(pooled_output)
class Config(TypedDict):
vocab_size: int
hidden_size: int
num_hidden_layers: int
num_attention_heads: int
hidden_act: str
intermediate_size: int
hidden_dropout_prob: float
attention_probs_dropout_prob: float
max_position_embeddings: int
type_vocab_size: int
initializer_range: float
SeqLen = NewType("SeqLen", int)
NumClasses = NewType("NumClasses", int)
# Tiny-BERT config.
bert_config: Config = {
"vocab_size": 30522,
"hidden_size": 128,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"hidden_act": "gelu",
"intermediate_size": 512,
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"type_vocab_size": 2,
"initializer_range": 0.02,
}
key = jax.random.PRNGKey(5678)
model_key, train_key = jax.random.split(key)
classifier = BertClassifier[SeqLen, NumClasses](config=bert_config, num_classes=NumClasses(2), key=model_key)
# Download the checkpoint from
# https://github.com/patrick-kidger/equinox/blob/main/examples/bert_checkpoint.eqx
classifier_chkpt = eqx.tree_deserialise_leaves("bert_checkpoint.eqx", classifier)
tokenizer = AutoTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2", model_max_length=128)
def tokenize(example: Any):
return tokenizer(example["sentence"], padding="max_length", truncation=True)
DSSize = NewType("DSSize", int)
BatchSize = NewType("BatchSize", int)
NumDevices = NewType("NumDevices", int)
BatchPerDevice = NewType("BatchPerDevice", int)
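# A TypeVarTuple lets the same TypedDict carry any number of leading batch
# dimensions: none for a single example, (BatchSize,) after batching, or
# (NumDevices, BatchPerDevice) after sharding across devices.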
Shape = TypeVarTuple("Shape")
class Inputs(TypedDict, Generic[*Shape, SeqLenT]):
input_ids: ndarray[*Shape, SeqLenT, int]
token_type_ids: ndarray[*Shape, SeqLenT, int]
class InputsWithLabel(TypedDict, Generic[*Shape]):
input_ids: ndarray[*Shape, SeqLen, int]
token_type_ids: ndarray[*Shape, SeqLen, int]
label: ndarray[*Shape, int]
ds: DatasetDict[InputsWithLabel[DSSize]] = load_dataset("sst2")
ds = ds.map(tokenize, batched=True)
ds.set_format(type="jax", columns=["input_ids", "token_type_ids", "label"])
@eqx.filter_value_and_grad
def compute_loss(
classifier: BertClassifier[SeqLenT, NumClassesT],
inputs: InputsWithLabel[BatchPerDevice],
key: PRNGKey,
) -> ndarray[float]:
batch_size = inputs["input_ids"].shape[0]
batched_keys = jax.random.split(key, num=batch_size)
logits: ndarray[BatchPerDevice, NumClassesT, float] = jax.vmap(classifier, in_axes=(0, None, 0))(
inputs, True, batched_keys
)
return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=inputs["label"]))
def make_step(
model: BertClassifier[SeqLenT, NumClassesT],
inputs: InputsWithLabel[BatchPerDevice],
opt_state: optax.OptState,
key: PRNGKey,
tx: optax.GradientTransformation,
):
key, new_key = jax.random.split(key)
loss, grads = compute_loss(model, inputs, key)
grads = jax.lax.pmean(grads, axis_name="devices")
updates, opt_state = tx.update(grads, opt_state, model)
model = eqx.apply_updates(model, updates)
return loss, model, opt_state, new_key
def make_eval_step(
model: BertClassifier[SeqLenT, NumClassesT],
inputs: Inputs[BatchPerDevice, SeqLenT],
) -> ndarray[BatchPerDevice, NumClassesT, float]:
return jax.vmap(functools.partial(model, enable_dropout=False))(cast(Any, inputs))
def p_make_eval_step(
model: BertClassifier[SeqLenT, NumClassesT],
inputs: Inputs[NumDevices, BatchPerDevice, SeqLenT],
) -> ndarray[NumDevices, BatchPerDevice, NumClassesT, float]:
return eqx.filter_pmap(make_eval_step)(model, inputs)
epochs = 9
batch_size = 32
learning_rate = 1e-5
num_devices = jax.device_count()
assert batch_size % num_devices == 0, "The batch size must be a multiple of the number of devices."
tx = optax.adam(learning_rate=learning_rate)
tx = optax.chain(optax.clip_by_global_norm(1.0), tx)
opt_state = tx.init(classifier_chkpt)
def p_make_step(
model: BertClassifier[SeqLenT, NumClassesT],
inputs: InputsWithLabel[NumDevices, BatchPerDevice],
opt_state: optax.OptState,
key: PRNGKey,
) -> tuple[ndarray[NumDevices, float], BertClassifier[SeqLenT, NumClassesT], optax.OptState, PRNGKey]:
return eqx.filter_pmap(functools.partial(make_step, tx=tx), axis_name="devices")(model, inputs, opt_state, key)
# Replicate across devices.
opt_state = jax.device_put_replicated(opt_state, jax.local_devices())
model = jax.device_put_replicated(classifier_chkpt, jax.local_devices())
train_key = jax.device_put_replicated(train_key, jax.local_devices())
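# Tiny helper to pin a static type onto an otherwise-untyped expression
# (datasets' .iter() has no useful return type) without changing it at runtime.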
A = TypeVar("A")
def declare(_: type[A], x: A) -> A:
return x
for epoch in range(epochs):
with tqdm.tqdm(
declare(
Iterator[InputsWithLabel[BatchSize]],
ds["train"].iter(batch_size=batch_size, drop_last_batch=True),
),
total=ds["train"].num_rows // batch_size,
unit="steps",
desc=f"Epoch {epoch+1}/{epochs}",
) as tqdm_epoch:
for batch in tqdm_epoch:
token_ids, token_type_ids, label = (
batch["input_ids"],
batch["token_type_ids"],
batch["label"],
)
token_ids_by_dev: ndarray[NumDevices, BatchPerDevice, SeqLen, int] = einops.rearrange(
token_ids, "(b1 b2) s -> b1 b2 s", b1=num_devices
)
token_type_ids_by_dev: ndarray[NumDevices, BatchPerDevice, SeqLen, int] = einops.rearrange(
token_type_ids, "(b1 b2) s -> b1 b2 s", b1=num_devices
)
label_by_dev: ndarray[NumDevices, BatchPerDevice, int] = einops.rearrange(
label, "(b1 b2) -> b1 b2", b1=num_devices
)
inputs_by_dev: InputsWithLabel[NumDevices, BatchPerDevice] = {
"input_ids": token_ids_by_dev,
"token_type_ids": token_type_ids_by_dev,
"label": label_by_dev,
}
loss, model, opt_state, train_key = p_make_step(model, inputs_by_dev, opt_state, train_key)
tqdm_epoch.set_postfix(loss=np.sum(loss).item())
outputs: list[float] = []
for batch in tqdm.tqdm(
declare(
Iterator[InputsWithLabel[BatchSize]],
ds["validation"].iter(batch_size=batch_size),
),
unit="steps",
total=ceil(ds["validation"].num_rows / batch_size),
desc="Validation",
):
token_ids, token_type_ids = batch["input_ids"], batch["token_type_ids"]
label = batch["label"]
token_ids_by_dev: ndarray[NumDevices, BatchPerDevice, SeqLen, int] = einops.rearrange(
token_ids, "(b1 b2) s -> b1 b2 s", b1=num_devices
)
token_type_ids_by_dev: ndarray[NumDevices, BatchPerDevice, SeqLen, int] = einops.rearrange(
token_type_ids, "(b1 b2) s -> b1 b2 s", b1=num_devices
)
inputs: Inputs[NumDevices, BatchPerDevice, SeqLen] = {
"input_ids": token_ids_by_dev,
"token_type_ids": token_type_ids_by_dev,
}
logits: ndarray[NumDevices, BatchPerDevice, NumClasses, float] = p_make_eval_step(model, inputs)
flat_logits: ndarray[BatchSize, NumClasses, float] = logits.reshape((-1, NumClasses(2)))
correct: ndarray[BatchSize, bool] = np.argmax(flat_logits, axis=-1) == label
outputs.extend(list(correct.astype(float)))
print(f"Accuracy: {100 * sum(outputs) / len(outputs):.2f}%")
Thanks for the heads-up about the incorrect type annotation in the BERT example. I've just fixed that. (The examples don't actually have the runtime type checker enabled, which is why that slipped through.)
Personally I'm also a fan of using generics + TypeVars internally, but I generally don't use them in public APIs because (a) it adds unhelpful noise for the many users who aren't familiar with them, and (b) it can produce spurious static type errors if the end user doesn't specify the parameterising type (https://github.com/microsoft/pyright/discussions/5599), which is often more than people really want to do.
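For the curious, a minimal sketch of the kind of spurious error in (b), using a toy generic class for illustration (not Equinox's actual Linear):

# Illustrative only: a toy generic class standing in for a generic layer.
from typing import Generic, TypeVar

InT = TypeVar("InT", bound=int)
OutT = TypeVar("OutT", bound=int)

class Linear(Generic[InT, OutT]):
    def __init__(self, in_features: InT, out_features: OutT):
        self.in_features = in_features
        self.out_features = out_features

# Under pyright's strict mode, annotating with the bare class name is flagged
# (reportMissingTypeArgument), so users who don't care about the
# parameterisation still get complaints:
layer: Linear = Linear(in_features=3, out_features=4)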