
Interesting Problems of Accuracy & Inference Speed with run_eval_needle.sh

Open · Treemann opened this issue on Apr 18, 2024 · 0 comments

Hi, I prepared the pg19 dataset and ran run_eval_needle.sh with different settings of parameters. I have some questions about the experiment results and hope someone could help.

Device: NVIDIA V100/A100 GPUs

Script:

#! /bin/bash

export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd "$PROJECT_DIR"
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"

export llama_tokenizer_path="weights/LWM-Text-Chat-1M-Jax/tokenizer.model"
export lwm_text_checkpoint="weights/LWM-Text-Chat-1M-Jax/params"
export haystack_file="data/pg19.jsonl"
export output_file="eval_needle_1m_jax.log"

export CUDA_VISIBLE_DEVICES=0,1,2,3

chunk_size=1024
ctx_len=6144
use_block=True
python3 -u scripts/eval_needle.py \
    --mesh_dim='!1,1,2,2' \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --update_llama_config="dict(theta=50000000,max_sequence_length=1048576,scan_attention=${use_block},scan_query_chunk_size=${chunk_size},scan_key_chunk_size=${chunk_size},scan_mlp=${use_block},scan_mlp_chunk_size=${chunk_size},scan_layers=True)" \
    --load_checkpoint="params::$lwm_text_checkpoint" \
    --tokenizer.vocab_file="$llama_tokenizer_path" \
    --output_file="$output_file" \
    --haystack_file="$haystack_file" \
    --max_tokens_per_batch=5000 \
    --context_lengths_min=${ctx_len} \
    --context_lengths_max=${ctx_len} \
    --n_context_length_intervals=1 \
    --n_document_depth_intervals=2 \
    --n_rounds=3
read  # wait for input so the shell stays open after the run finishes
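For context, here is a minimal sketch of how I understand a mesh_dim string such as '1,1,2,2' to map onto the 4 GPUs. The axis names (dp, fsdp, tp, sp) follow the EasyLM convention that LWM builds on; the helper below is illustrative only, not LWM's actual parsing code, and I simply strip the '!' prefix without modeling its meaning.

import numpy as np
import jax
from jax.sharding import Mesh

def make_mesh(mesh_dim: str) -> Mesh:
    # '1,2,2,1' -> 2-way FSDP x 2-way tensor parallelism on 4 devices;
    # '1,1,2,2' -> 2-way tensor x 2-way sequence parallelism (ring attention).
    dims = [int(x) for x in mesh_dim.lstrip('!').split(',')]
    devices = np.array(jax.devices()).reshape(dims)
    return Mesh(devices, axis_names=('dp', 'fsdp', 'tp', 'sp'))

print(make_mesh('1,1,2,2').shape)  # -> {'dp': 1, 'fsdp': 1, 'tp': 2, 'sp': 2}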

Some questions about the experiment results:

  1. The model produced correct predictions with mesh_dim='1,2,2,1' and dtype=fp32 or bf16. However, when sequence parallelism was enabled with mesh_dim='1,1,2,2', the predictions were wrong:

     mesh_dim   fp32      fp64    fp16    bf16
     1,2,2,1    correct   wrong   wrong   correct
     1,1,2,2    wrong     wrong   wrong   wrong
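One pattern in this table that might matter: fp16 is wrong even where bf16 is correct, which is consistent with fp16's much smaller dynamic range (exp/softmax over long contexts overflows fp16 long before bf16). A quick check of the representable ranges:

import jax.numpy as jnp

print(jnp.finfo(jnp.float16).max)   # 65504.0 -- exp() overflows to inf very easily
print(jnp.finfo(jnp.bfloat16).max)  # ~3.39e38 -- same exponent range as fp32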

On further investigation, I found that the attention outputs began to contain NaN values from some middle layer onward, and the final outputs of the network were all NaN.

...
attn_output shape: (array(1, dtype=int32), array(8192, dtype=int32), array(4096, dtype=int32)), attn_output:
[[[-0.00195352 -0.00402817  0.00310729 ...  0.00402906 -0.00059524 -0.00334746]
  [-0.00195352 -0.00402817  0.00310729 ...  0.00402906 -0.00059524 -0.00334746]
  [-0.00195352 -0.00402817  0.00310729 ...  0.00402906 -0.00059524 -0.00334746]
  ...
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]]]
...
attn_output shape: (array(1, dtype=int32), array(1, dtype=int32), array(4096, dtype=int32)), attn_output:
[[[nan nan nan ... nan nan nan]]]
...
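For anyone trying to reproduce this, JAX's built-in NaN checking can help localize the first NaN-producing primitive; a minimal sketch (the check helper is hypothetical, not part of LWM):

import jax
import jax.numpy as jnp

# Raise as soon as any primitive produces a NaN; JAX re-runs the offending
# op un-jitted, so this is slow and works best outside pmap.
jax.config.update("jax_debug_nans", True)

# Alternatively, print per-tensor statistics from inside jitted code:
def check(x, tag):
    jax.debug.print("{t}: max_abs={m}, has_nan={n}",
                    t=tag, m=jnp.max(jnp.abs(x)), n=jnp.any(jnp.isnan(x)))
    return x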

  2. Inference was very slow with mesh_dim='1,1,2,2'. By counting the printed attention outputs, I found that FlaxLLaMABlock was called far more times with mesh_dim='1,1,2,2' than with mesh_dim='1,2,2,1' (see the sketch after the table):

     mesh_dim   calls to FlaxLLaMABlock
     1,2,2,1    544
     1,1,2,2    65568
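My guess is that the blow-up in call count is inherent to blockwise/ring attention: with scan_attention=True the inner attention block executes once per (query chunk, key chunk) pair, and sequence parallelism adds ring communication steps on top. Below is a self-contained sketch of chunked attention with a streaming softmax, in the spirit of scan_attention; it is illustrative only, not LWM's implementation.

import jax.numpy as jnp

def chunked_attention(q, k, v, chunk):
    # q, k, v: [seq, dim]; causal masking omitted for brevity.
    seq, dim = q.shape
    scale = dim ** -0.5
    out = []
    for qs in range(0, seq, chunk):
        qc = q[qs:qs + chunk] * scale
        m = jnp.full((qc.shape[0],), -jnp.inf)   # running row-max
        l = jnp.zeros((qc.shape[0],))            # running softmax normalizer
        acc = jnp.zeros((qc.shape[0], dim))      # running weighted sum
        for ks in range(0, seq, chunk):          # inner block: one call per KV chunk
            s = qc @ k[ks:ks + chunk].T
            m_new = jnp.maximum(m, s.max(axis=-1))
            p = jnp.exp(s - m_new[:, None])      # exp in low precision is fragile
            corr = jnp.exp(m - m_new)            # rescale previous partial results
            l = l * corr + p.sum(axis=-1)
            acc = acc * corr[:, None] + p @ v[ks:ks + chunk]
            m = m_new
        out.append(acc / l[:, None])
    return jnp.concatenate(out, axis=0)

With ctx_len=6144 and chunk_size=1024 there are 6 query chunks times 6 key chunks per layer, so call counts grow quadratically in ctx_len/chunk_size; I am not sure this fully explains the exact numbers above, but a much larger count under '1,1,2,2' seems directionally expected, and the low-precision exp is also the usual place where fp16 overflows to inf and turns into NaN.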

Has anyone had similar findings, or could anyone share insight into these questions?

Treemann · Apr 18 '24 09:04