equinox icon indicating copy to clipboard operation
equinox copied to clipboard

FYI: Equinox example with variadic generics

Open colehaus opened this issue 2 years ago • 1 comments

This is not so much an issue as an FYI. I was going through the Equinox examples and converted the BERT one to use a variadic generics approach to typing. I know there's some thinking on this topic at https://github.com/google/jaxtyping/blob/main/docs/faq.md. I thought it could be interesting for others to see (a first attempt at) what's possible with this approach at the moment. Most of the types are pretty precise (i.e. able to add and remove appropriately-sized dimensions), but I didn't really try to do anything sensible with pytrees. This is maybe just because I'm type-brained, but I found the additional type declarations quite helpful for my understanding. (I can show the typing stubs I'm using too if that would be useful.)

I also think doing this uncovered one discrepancy. The example has Float[Array, "seq_len hidden_size"] for the FeedForward block, but it seems like it should be just Float[Array, "hidden_size"] because the seq_len is only added by the vmap later.

from __future__ import annotations

import functools
from math import ceil
import os
import re
from typing import (
    Any,
    Generic,
    Iterator,
    NewType,
    Optional,
    TypeVar,
    TypeVarTuple,
    TypedDict,
    cast,
)


from datasets import DatasetDict, load_dataset
import einops
import equinox as eqx
import jax
import jax.numpy as jnp
from jax.numpy import ndarray
from jax.random import PRNGKey
import numpy as np
import optax
import tqdm
from transformers import AutoTokenizer

# Type variables naming tensor dimensions.  Each is bound to int so that a
# NewType over int (e.g. SeqLen / NumClasses below) can instantiate it.
SeqLenT = TypeVar("SeqLenT", bound=int)
EmbedSizeT = TypeVar("EmbedSizeT", bound=int)


class EmbedderBlock(eqx.Module, Generic[EmbedSizeT]):
    """BERT embedder.

    Sums token, segment (sentence A/B) and position embeddings for each
    position of a single (unbatched) sequence, then applies layer norm and
    dropout.
    """

    # NOTE: field declaration order defines the pytree leaf order that
    # eqx.tree_deserialise_leaves relies on when loading the checkpoint below.
    token_embedder: eqx.nn.Embedding[EmbedSizeT]
    segment_embedder: eqx.nn.Embedding[EmbedSizeT]
    position_embedder: eqx.nn.Embedding[EmbedSizeT]
    layernorm: eqx.nn.LayerNorm[EmbedSizeT]
    dropout: eqx.nn.Dropout

    def __init__(
        self,
        vocab_size: int,
        max_length: int,
        type_vocab_size: int,
        embedding_size: EmbedSizeT,
        hidden_size: EmbedSizeT,
        dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        """Initialise the three embedding tables plus layer norm and dropout.

        Args:
            vocab_size: number of distinct token ids.
            max_length: maximum sequence length (size of the position table).
            type_vocab_size: number of segment ids (2 for BERT's A/B sentences).
            embedding_size: width of the embedding vectors.
            hidden_size: width the layer norm operates over.  Assumed equal to
                ``embedding_size`` (true for the tiny-BERT config used below),
                since the norm is applied directly to the summed embeddings.
            dropout_rate: dropout probability applied after the layer norm.
            key: PRNG key, split three ways for the embedding tables.
        """
        token_key, segment_key, position_key = jax.random.split(key, 3)

        self.token_embedder = eqx.nn.Embedding(
            num_embeddings=vocab_size, embedding_size=embedding_size, key=token_key
        )
        self.segment_embedder = eqx.nn.Embedding(
            num_embeddings=type_vocab_size,
            embedding_size=embedding_size,
            key=segment_key,
        )
        self.position_embedder = eqx.nn.Embedding(
            num_embeddings=max_length, embedding_size=embedding_size, key=position_key
        )
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        token_ids: ndarray[SeqLenT, int],
        position_ids: ndarray[SeqLenT, int],
        segment_ids: ndarray[SeqLenT, int],
        enable_dropout: bool = False,
        key: Optional[PRNGKey] = None,
    ) -> ndarray[SeqLenT, EmbedSizeT, float]:
        """Embed one sequence of token/position/segment ids."""
        tokens = self.token_embedder(token_ids)
        segments = self.segment_embedder(segment_ids)
        positions = self.position_embedder(position_ids)
        # Element-wise sum of the three embeddings, per position.
        embedded_inputs: ndarray[SeqLenT, EmbedSizeT, float] = tokens + segments + positions
        # layernorm is defined over a single embedding vector; vmap over seq_len.
        embedded_inputs = jax.vmap(self.layernorm)(embedded_inputs)
        embedded_inputs = self.dropout(embedded_inputs, inference=not enable_dropout, key=key)
        return embedded_inputs


class FeedForwardBlock(eqx.Module, Generic[EmbedSizeT]):
    """A single transformer feed forward block.

    Operates on one position's hidden vector (the seq_len axis is added by a
    vmap in ``TransformerLayer``): Linear -> GELU -> Linear -> dropout, then
    a residual connection and layer norm.
    """

    # Opaque width of the inner MLP layer, distinguished from EmbedSizeT at
    # the type level.
    IntermediateSize = NewType("IntermediateSize", int)

    mlp: eqx.nn.Linear[EmbedSizeT, FeedForwardBlock.IntermediateSize]
    output: eqx.nn.Linear[FeedForwardBlock.IntermediateSize, EmbedSizeT]
    layernorm: eqx.nn.LayerNorm[EmbedSizeT]
    dropout: eqx.nn.Dropout

    def __init__(
        self,
        hidden_size: EmbedSizeT,
        intermediate_size: int,
        dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        """Build the up-/down-projection linear layers, layer norm and dropout."""
        mlp_key, output_key = jax.random.split(key)
        self.mlp = eqx.nn.Linear(
            in_features=hidden_size,
            out_features=self.IntermediateSize(intermediate_size),
            key=mlp_key,
        )
        self.output = eqx.nn.Linear(
            in_features=self.IntermediateSize(intermediate_size),
            out_features=hidden_size,
            key=output_key,
        )

        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        inputs: ndarray[EmbedSizeT, float],
        enable_dropout: bool = True,  # NOTE(review): defaults to True here but False in the other blocks -- confirm intended
        key: Optional[PRNGKey] = None,
    ) -> ndarray[EmbedSizeT, float]:
        """Apply the MLP to one position's vector, with residual + layer norm."""
        hidden = self.mlp(inputs)
        hidden = jax.nn.gelu(hidden)

        # Dropout is applied to the MLP output, before the residual add.
        output = self.output(hidden)
        output = self.dropout(output, inference=not enable_dropout, key=key)

        # Residual and layer norm.
        output += inputs
        output = self.layernorm(output)

        return output


class AttentionBlock(eqx.Module, Generic[SeqLenT, EmbedSizeT]):
    """A single transformer attention block.

    Self-attention over one (unbatched) sequence, followed by dropout, a
    residual connection and a per-position layer norm.
    """

    # Opaque number-of-heads dimension, distinguished at the type level.
    NumHeads = NewType("NumHeads", int)

    attention: eqx.nn.MultiheadAttention[
        AttentionBlock.NumHeads,
        SeqLenT,
        SeqLenT,
        EmbedSizeT,
        EmbedSizeT,
        EmbedSizeT,
        EmbedSizeT,
    ]
    layernorm: eqx.nn.LayerNorm[EmbedSizeT]
    dropout: eqx.nn.Dropout
    # static=True: a hyperparameter, not a trainable/serialised pytree leaf.
    num_heads: AttentionBlock.NumHeads = eqx.field(static=True)

    def __init__(
        self,
        hidden_size: EmbedSizeT,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        """Build self-attention (q/k/v/output sizes all ``hidden_size``, all
        with biases), plus the post-attention layer norm and dropout.
        """
        self.num_heads = self.NumHeads(num_heads)
        self.attention = eqx.nn.MultiheadAttention(
            num_heads=self.num_heads,
            query_size=hidden_size,
            key_size=hidden_size,
            value_size=hidden_size,
            output_size=hidden_size,
            use_query_bias=True,
            use_key_bias=True,
            use_value_bias=True,
            use_output_bias=True,
            dropout_p=attention_dropout_rate,
            key=key,
        )
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        inputs: ndarray[SeqLenT, EmbedSizeT, float],
        mask: Optional[ndarray[SeqLenT, int]],
        enable_dropout: bool = False,
        key: Optional[PRNGKey] = None,
    ) -> ndarray[SeqLenT, EmbedSizeT, float]:
        """Self-attend over ``inputs``; ``mask`` flags valid (non-pad) positions."""
        full_mask = None if mask is None else self._make_self_attention_mask(mask)
        # One key for attention's internal dropout, one for the output dropout.
        attention_key, dropout_key = (None, None) if key is None else jax.random.split(key)

        attention_output = self.attention(
            query=inputs,
            key_=inputs,
            value=inputs,
            mask=full_mask,
            inference=not enable_dropout,
            key=attention_key,
        )

        result = attention_output
        result = self.dropout(result, inference=not enable_dropout, key=dropout_key)
        # Residual connection, then per-position layer norm.
        result = result + inputs
        result = jax.vmap(self.layernorm)(result)
        return result

    def _make_self_attention_mask(
        self, mask: ndarray[SeqLenT, int]
    ) -> ndarray[AttentionBlock.NumHeads, SeqLenT, SeqLenT, float]:
        """Expand a per-position validity mask to a per-head (q, k) pair mask."""
        # Outer product: pair (q, k) is valid iff both positions are valid.
        square_mask = jnp.multiply(jnp.expand_dims(mask, axis=-1), jnp.expand_dims(mask, axis=-2))
        # Broadcast the same pair mask to every attention head.
        xmask = jnp.expand_dims(square_mask, axis=-3)
        return jnp.repeat(xmask, repeats=self.num_heads, axis=-3).astype(jnp.float32)


class TransformerLayer(eqx.Module, Generic[SeqLenT, EmbedSizeT]):
    """A single transformer layer: attention block, then feed-forward block."""

    attention_block: AttentionBlock[SeqLenT, EmbedSizeT]
    ff_block: FeedForwardBlock[EmbedSizeT]

    def __init__(
        self,
        hidden_size: EmbedSizeT,
        intermediate_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: PRNGKey,
    ):
        """Build the attention and feed-forward sub-blocks from one key."""
        attention_key, ff_key = jax.random.split(key)

        self.attention_block = AttentionBlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            key=attention_key,
        )
        self.ff_block = FeedForwardBlock(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            dropout_rate=dropout_rate,
            key=ff_key,
        )

    def __call__(
        self,
        inputs: ndarray[SeqLenT, EmbedSizeT, float],
        mask: Optional[ndarray[SeqLenT, int]] = None,
        *,
        enable_dropout: bool = False,
        key: Optional[PRNGKey] = None,
    ) -> ndarray[SeqLenT, EmbedSizeT, float]:
        """Run one (unbatched) sequence through attention then feed-forward."""
        attn_key, ff_key = (None, None) if key is None else jax.random.split(key)
        attention_output = self.attention_block(inputs, mask, enable_dropout=enable_dropout, key=attn_key)
        # The feed-forward block is defined per position; vmap it over seq_len,
        # giving each position its own dropout key.
        seq_len = inputs.shape[0]
        ff_keys = None if ff_key is None else jax.random.split(ff_key, num=seq_len)
        output = jax.vmap(self.ff_block, in_axes=(0, None, 0))(attention_output, enable_dropout, ff_keys)
        return output


class EncoderOut(TypedDict, Generic[SeqLenT, EmbedSizeT]):
    """Everything ``Encoder.__call__`` returns for one sequence."""

    # Output of the embedder block (before any transformer layer).
    embeddings: ndarray[SeqLenT, EmbedSizeT, float]
    # Output of each transformer layer, in order.
    layers: list[ndarray[SeqLenT, EmbedSizeT, float]]
    # tanh(pooler(...)) of the final layer's first ([CLS]) token.
    pooled: ndarray[EmbedSizeT, float]


class Encoder(eqx.Module, Generic[SeqLenT, EmbedSizeT]):
    """Full BERT encoder.

    Embeds the inputs, runs them through a stack of transformer layers and
    produces a pooled [CLS] representation.
    """

    embedder_block: EmbedderBlock[EmbedSizeT]
    layers: list[TransformerLayer[SeqLenT, EmbedSizeT]]
    pooler: eqx.nn.Linear[EmbedSizeT, EmbedSizeT]

    def __init__(
        self,
        vocab_size: int,
        max_length: int,
        type_vocab_size: int,
        embedding_size: EmbedSizeT,
        hidden_size: EmbedSizeT,
        intermediate_size: int,
        num_layers: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        """Build the embedder, ``num_layers`` transformer layers and the pooler."""
        embedder_key, layer_key, pooler_key = jax.random.split(key, num=3)
        self.embedder_block = EmbedderBlock(
            vocab_size=vocab_size,
            max_length=max_length,
            type_vocab_size=type_vocab_size,
            embedding_size=embedding_size,
            hidden_size=hidden_size,
            dropout_rate=dropout_rate,
            key=embedder_key,
        )

        # One independent initialisation key per transformer layer.
        layer_keys = jax.random.split(layer_key, num=num_layers)
        self.layers = []
        for layer_key in layer_keys:
            self.layers.append(
                TransformerLayer(
                    hidden_size=hidden_size,
                    intermediate_size=intermediate_size,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    attention_dropout_rate=attention_dropout_rate,
                    key=layer_key,
                )
            )

        self.pooler = eqx.nn.Linear(in_features=hidden_size, out_features=hidden_size, key=pooler_key)

    def __call__(
        self,
        token_ids: ndarray[SeqLenT, int],
        position_ids: ndarray[SeqLenT, int],
        segment_ids: ndarray[SeqLenT, int],
        *,
        enable_dropout: bool = False,
        key: Optional[PRNGKey] = None,
    ) -> EncoderOut[SeqLenT, EmbedSizeT]:
        """Encode one (unbatched) sequence.

        Returns the initial embeddings, every layer's output and the pooled
        [CLS] vector.
        """
        emb_key, l_key = (None, None) if key is None else jax.random.split(key)

        embeddings = self.embedder_block(
            token_ids=token_ids,
            position_ids=position_ids,
            segment_ids=segment_ids,
            enable_dropout=enable_dropout,
            key=emb_key,
        )

        # We assume that all 0-values should be masked out.
        mask = (token_ids != 0).astype(jnp.int32)

        x = embeddings
        layer_outputs: list[ndarray[SeqLenT, EmbedSizeT, float]] = []
        for layer in self.layers:
            # Fresh dropout key per layer; l_key is carried forward.
            cl_key, l_key = (None, None) if l_key is None else jax.random.split(l_key)
            x = layer(x, mask, enable_dropout=enable_dropout, key=cl_key)
            layer_outputs.append(x)

        # BERT pooling.
        # The first token in the last layer is the embedding of the "[CLS]" token.
        first_token_last_layer = x[..., 0, :]
        pooled = self.pooler(first_token_last_layer)
        pooled = jnp.tanh(pooled)

        return {"embeddings": embeddings, "layers": layer_outputs, "pooled": pooled}


# Number-of-output-classes dimension for the classifier head.
NumClassesT = TypeVar("NumClassesT", bound=int)


class BertClassifier(eqx.Module, Generic[SeqLenT, NumClassesT]):
    """BERT classifier.

    Encoder plus dropout and a linear head over the pooled [CLS] vector.
    """

    # Opaque hidden width taken from the config, distinguished at the type level.
    EmbedSize = NewType("EmbedSize", int)

    encoder: Encoder[SeqLenT, BertClassifier.EmbedSize]
    classifier_head: eqx.nn.Linear[BertClassifier.EmbedSize, NumClassesT]
    dropout: eqx.nn.Dropout

    def __init__(
        self,
        config: Config,
        num_classes: NumClassesT,
        key: jax.random.PRNGKey,
    ):
        """Build the encoder and the ``num_classes``-way head from ``config``."""
        encoder_key, head_key = jax.random.split(key)

        # embedding_size == hidden_size for this model.
        hidden = self.EmbedSize(config["hidden_size"])

        self.encoder = Encoder(
            vocab_size=config["vocab_size"],
            max_length=config["max_position_embeddings"],
            type_vocab_size=config["type_vocab_size"],
            embedding_size=hidden,
            hidden_size=hidden,
            intermediate_size=config["intermediate_size"],
            num_layers=config["num_hidden_layers"],
            num_heads=config["num_attention_heads"],
            dropout_rate=config["hidden_dropout_prob"],
            attention_dropout_rate=config["attention_probs_dropout_prob"],
            key=encoder_key,
        )
        self.classifier_head = eqx.nn.Linear(in_features=hidden, out_features=num_classes, key=head_key)
        self.dropout = eqx.nn.Dropout(config["hidden_dropout_prob"])

    def __call__(
        self,
        inputs: Inputs[(), SeqLenT],
        enable_dropout: bool = True,  # NOTE(review): training-mode default, unlike the False default elsewhere
        key: Optional[PRNGKey] = None,
    ) -> ndarray[NumClassesT, float]:
        """Classify one (unbatched) tokenized sequence, returning logits."""
        # Positions are simply 0..seq_len-1.
        seq_len = inputs["input_ids"].shape[-1]
        position_ids = jnp.arange(seq_len)

        e_key, d_key = (None, None) if key is None else jax.random.split(key)

        pooled_output = self.encoder(
            token_ids=inputs["input_ids"],
            segment_ids=inputs["token_type_ids"],
            position_ids=position_ids,
            enable_dropout=enable_dropout,
            key=e_key,
        )["pooled"]
        pooled_output = self.dropout(pooled_output, inference=not enable_dropout, key=d_key)

        return self.classifier_head(pooled_output)


class Config(TypedDict):
    """Model hyperparameters; key names mirror the HuggingFace BERT config."""

    vocab_size: int
    hidden_size: int
    num_hidden_layers: int
    num_attention_heads: int
    hidden_act: str  # not read by this script; the GELU is hard-coded
    intermediate_size: int
    hidden_dropout_prob: float
    attention_probs_dropout_prob: float
    max_position_embeddings: int
    type_vocab_size: int
    initializer_range: float  # not read by this script


# Concrete dimension tags for this script.
SeqLen = NewType("SeqLen", int)
NumClasses = NewType("NumClasses", int)

# Tiny-BERT config.
bert_config: Config = {
    "vocab_size": 30522,
    "hidden_size": 128,
    "num_hidden_layers": 2,
    "num_attention_heads": 2,
    "hidden_act": "gelu",
    "intermediate_size": 512,
    "hidden_dropout_prob": 0.1,
    "attention_probs_dropout_prob": 0.1,
    "max_position_embeddings": 512,
    "type_vocab_size": 2,
    "initializer_range": 0.02,
}

key = jax.random.PRNGKey(5678)
model_key, train_key = jax.random.split(key)
# Randomly initialised model; its parameters are replaced by the checkpoint
# below, but the pytree structure must match for deserialisation to work.
classifier = BertClassifier[SeqLen, NumClasses](config=bert_config, num_classes=NumClasses(2), key=model_key)

# Download the checkpoint from
# https://github.com/patrick-kidger/equinox/blob/main/examples/bert_checkpoint.eqx
classifier_chkpt = eqx.tree_deserialise_leaves("bert_checkpoint.eqx", classifier)

tokenizer = AutoTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2", model_max_length=128)


def tokenize(example: Any):
    """Tokenize one dataset row's "sentence" field, padded/truncated to the
    tokenizer's max length."""
    sentence = example["sentence"]
    return tokenizer(sentence, padding="max_length", truncation=True)


# Dimension tags for the data pipeline.
DSSize = NewType("DSSize", int)          # total rows in a dataset split
BatchSize = NewType("BatchSize", int)    # rows per global batch
NumDevices = NewType("NumDevices", int)
BatchPerDevice = NewType("BatchPerDevice", int)

# Variadic "rest of the shape" prefix used to batch the TypedDicts below.
Shape = TypeVarTuple("Shape")


class Inputs(TypedDict, Generic[*Shape, SeqLenT]):
    """Model inputs with an arbitrary batch-shape prefix ``*Shape``."""

    input_ids: ndarray[*Shape, SeqLenT, int]
    token_type_ids: ndarray[*Shape, SeqLenT, int]


class InputsWithLabel(TypedDict, Generic[*Shape]):
    """Training inputs plus the integer class label, batch-shaped by ``*Shape``."""

    input_ids: ndarray[*Shape, SeqLen, int]
    token_type_ids: ndarray[*Shape, SeqLen, int]
    label: ndarray[*Shape, int]


# SST-2 sentiment dataset; map() adds the tokenizer outputs as new columns.
ds: DatasetDict[InputsWithLabel[DSSize]] = load_dataset("sst2")
ds = ds.map(tokenize, batched=True)
# Yield jax arrays for just the columns the training loop needs.
ds.set_format(type="jax", columns=["input_ids", "token_type_ids", "label"])


@eqx.filter_value_and_grad
def compute_loss(
    classifier: BertClassifier[SeqLenT, NumClassesT],
    inputs: InputsWithLabel[BatchPerDevice],
    key: PRNGKey,
) -> ndarray[float]:
    """Mean softmax cross-entropy over one per-device batch.

    The decorator turns this into a (loss, gradients) function with respect
    to the classifier's array leaves.
    """
    n_examples = inputs["input_ids"].shape[0]
    # One dropout key per example in the batch.
    per_example_keys = jax.random.split(key, num=n_examples)
    batched_classifier = jax.vmap(classifier, in_axes=(0, None, 0))
    logits: ndarray[BatchPerDevice, NumClassesT, float] = batched_classifier(
        inputs, True, per_example_keys
    )
    losses = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=inputs["label"])
    return jnp.mean(losses)


def make_step(
    model: BertClassifier[SeqLenT, NumClassesT],
    inputs: InputsWithLabel[BatchPerDevice],
    opt_state: optax.OptState,
    key: PRNGKey,
    tx: optax.GradientTransformation,
):
    """One training step, intended to run under pmap: compute loss/grads,
    average gradients across devices, apply the optimizer update.

    Returns (loss, updated model, updated optimizer state, next PRNG key).
    """
    step_key, carry_key = jax.random.split(key)
    loss, grads = compute_loss(model, inputs, step_key)
    # Average gradients over the "devices" pmap axis so every replica
    # applies the identical update.
    grads = jax.lax.pmean(grads, axis_name="devices")

    updates, opt_state = tx.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state, carry_key


def make_eval_step(
    model: BertClassifier[SeqLenT, NumClassesT],
    inputs: Inputs[BatchPerDevice, SeqLenT],
) -> ndarray[BatchPerDevice, NumClassesT, float]:
    """Return per-example logits for a per-device batch, dropout disabled."""
    inference_model = functools.partial(model, enable_dropout=False)
    return jax.vmap(inference_model)(cast(Any, inputs))


def p_make_eval_step(
    model: BertClassifier[SeqLenT, NumClassesT],
    inputs: Inputs[NumDevices, BatchPerDevice, SeqLenT],
) -> ndarray[NumDevices, BatchPerDevice, NumClasses, float]:
    """Run ``make_eval_step`` on every device over device-sharded inputs."""
    pmapped_eval = eqx.filter_pmap(make_eval_step)
    return pmapped_eval(model, inputs)


# Training hyperparameters.
epochs = 9
batch_size = 32
learning_rate = 1e-5

num_devices = jax.device_count()
assert batch_size % num_devices == 0, "The batch size must be a multiple of the number of devices."

# Adam with global-norm gradient clipping; optimizer state is initialised
# from the checkpointed parameters.
tx = optax.adam(learning_rate=learning_rate)
tx = optax.chain(optax.clip_by_global_norm(1.0), tx)
opt_state = tx.init(classifier_chkpt)


def p_make_step(
    model: BertClassifier[SeqLenT, NumClassesT],
    inputs: InputsWithLabel[NumDevices, BatchPerDevice],
    opt_state: optax.OptState,
    key: PRNGKey,
) -> tuple[ndarray[NumDevices, float], BertClassifier[SeqLenT, NumClassesT], optax.OptState, PRNGKey]:
    """Run ``make_step`` on every device; closes over the global optimizer ``tx``."""
    step_fn = functools.partial(make_step, tx=tx)
    return eqx.filter_pmap(step_fn, axis_name="devices")(model, inputs, opt_state, key)


# Replicate across devices.
# Each of these gains a leading NumDevices axis, matching the device-sharded
# arguments expected by the filter_pmap-ed step functions.
opt_state = jax.device_put_replicated(opt_state, jax.local_devices())
model = jax.device_put_replicated(classifier_chkpt, jax.local_devices())
train_key = jax.device_put_replicated(train_key, jax.local_devices())

A = TypeVar("A")


def declare(_: type[A], x: A) -> A:
    """Identity function that ascribes static type ``A`` to ``x``.

    Used below to annotate dataset iterators whose inferred type would
    otherwise be too loose.
    """
    return x


# Fine-tuning loop: shard each global batch across devices and take one
# pmapped optimizer step per batch.
for epoch in range(epochs):
    with tqdm.tqdm(
        declare(
            Iterator[InputsWithLabel[BatchSize]],
            ds["train"].iter(batch_size=batch_size, drop_last_batch=True),
        ),
        total=ds["train"].num_rows // batch_size,
        unit="steps",
        desc=f"Epoch {epoch+1}/{epochs}",
    ) as tqdm_epoch:
        for batch in tqdm_epoch:
            token_ids, token_type_ids, label = (
                batch["input_ids"],
                batch["token_type_ids"],
                batch["label"],
            )

            # Split the leading batch axis into (devices, batch-per-device).
            token_ids_by_dev: ndarray[NumDevices, BatchPerDevice, SeqLen, int] = einops.rearrange(
                token_ids, "(b1 b2) s -> b1 b2 s", b1=num_devices
            )
            token_type_ids_by_dev: ndarray[NumDevices, BatchPerDevice, SeqLen, int] = einops.rearrange(
                token_type_ids, "(b1 b2) s -> b1 b2 s", b1=num_devices
            )
            label_by_dev: ndarray[NumDevices, BatchPerDevice, int] = einops.rearrange(
                label, "(b1 b2) -> b1 b2", b1=num_devices
            )

            inputs_by_dev: InputsWithLabel[NumDevices, BatchPerDevice] = {
                "input_ids": token_ids_by_dev,
                "token_type_ids": token_type_ids_by_dev,
                "label": label_by_dev,
            }
            loss, model, opt_state, train_key = p_make_step(model, inputs_by_dev, opt_state, train_key)

            # loss has one entry per device; summed here for display only.
            tqdm_epoch.set_postfix(loss=np.sum(loss).item())


# Validation: run the fine-tuned model over the validation split and report
# simple accuracy.
outputs: list[float] = []
for batch in tqdm.tqdm(
    declare(
        Iterator[InputsWithLabel[BatchSize]],
        ds["validation"].iter(batch_size=batch_size),
    ),
    unit="steps",
    total=ceil(ds["validation"].num_rows / batch_size),
    desc="Validation",
):
    token_ids, token_type_ids = batch["input_ids"], batch["token_type_ids"]
    label = batch["label"]

    # NOTE(review): unlike training, this iterator keeps the last partial
    # batch; the rearranges below assume each batch divides evenly by
    # num_devices -- confirm for multi-device runs.
    token_ids_by_dev: ndarray[NumDevices, BatchPerDevice, SeqLen, int] = einops.rearrange(
        token_ids, "(b1 b2) s -> b1 b2 s", b1=num_devices
    )
    token_type_ids_by_dev: ndarray[NumDevices, BatchPerDevice, SeqLen, int] = einops.rearrange(
        token_type_ids, "(b1 b2) s -> b1 b2 s", b1=num_devices
    )

    inputs: Inputs[NumDevices, BatchPerDevice, SeqLen] = {
        "input_ids": token_ids_by_dev,
        "token_type_ids": token_type_ids_by_dev,
    }

    logits: ndarray[NumDevices, BatchPerDevice, NumClasses, float] = p_make_eval_step(model, inputs)
    # Collapse the device axis back into a flat batch of logits.
    flat_logits: ndarray[BatchSize, NumClasses, float] = logits.reshape((-1, NumClasses(2)))
    correct: ndarray[BatchSize, bool] = np.argmax(flat_logits, axis=-1) == label
    outputs.extend(list(correct.astype(float)))

print(f"Accuracy: {100 * sum(outputs) / len(outputs):.2f}%")

colehaus avatar Aug 21 '23 01:08 colehaus

Thanks for the heads-up about the missed type annotation in the BERT example. I've just fixed that. (The examples don't actually have the runtime type checker enabled, which is why that slipped through.)

Personally I'm also a fan of using generics+typevars internally, but I generally don't tend to use them in public APIs because (a) I realise that this adds unhelpful noise to the many users who aren't familiar, and (b) it actually means you get spurious static type errors if the end user doesn't specify the parameterising type (https://github.com/microsoft/pyright/discussions/5599), which is often more than people really want to do.

patrick-kidger avatar Aug 21 '23 11:08 patrick-kidger