Interesting Problems with Accuracy & Inference Speed in run_eval_needle.sh
Hi, I prepared the pg19 dataset and ran `run_eval_needle.sh` with different parameter settings. I have some questions about the experiment results and hope someone can help.

Device: NVIDIA V100/A100 GPUs

Script:
```bash
#! /bin/bash
export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd "$PROJECT_DIR"
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
export llama_tokenizer_path="weights/LWM-Text-Chat-1M-Jax/tokenizer.model"
export lwm_text_checkpoint="weights/LWM-Text-Chat-1M-Jax/params"
export haystack_file="data/pg19.jsonl"
export output_file="eval_needle_1m_jax.log"
export CUDA_VISIBLE_DEVICES=0,1,2,3

chunk_size=1024
ctx_len=6144
use_block=True

python3 -u scripts/eval_needle.py \
    --mesh_dim='!1,1,2,2' \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --update_llama_config="dict(theta=50000000,max_sequence_length=1048576,scan_attention=${use_block},scan_query_chunk_size=${chunk_size},scan_key_chunk_size=${chunk_size},scan_mlp=${use_block},scan_mlp_chunk_size=${chunk_size},scan_layers=True)" \
    --load_checkpoint="params::$lwm_text_checkpoint" \
    --tokenizer.vocab_file="$llama_tokenizer_path" \
    --output_file="$output_file" \
    --haystack_file="$haystack_file" \
    --max_tokens_per_batch=5000 \
    --context_lengths_min=${ctx_len} \
    --context_lengths_max=${ctx_len} \
    --n_context_length_intervals=1 \
    --n_document_depth_intervals=2 \
    --n_rounds=3
read
```
Some questions about the experiment results:

- The model seemed to produce correct predictions with `mesh_dim='1,2,2,1'` and `dtype='fp32'` or `'bf16'`. However, when sequence parallelism was adopted with `mesh_dim='1,1,2,2'`, the predictions were wrong. (My rough understanding of the mesh layout is sketched after the table.)
| mesh_dim | fp32 | fp64 | fp16 | bf16 |
|---|---|---|---|---|
| 1,2,2,1 | correct | wrong | wrong | correct |
| 1,1,2,2 | wrong | wrong | wrong | wrong |
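For context, my reading of the mesh string (an assumption based on EasyLM/tux-style mesh parsing; I have not checked the axis names against this repo's code) is that the four numbers map onto `(dp, fsdp, tp, sp)` device axes, so `'1,1,2,2'` enables 2-way sequence parallelism:

```python
# Minimal sketch of how I understand mesh_dim, assuming EasyLM/tux-style
# axis names ('dp', 'fsdp', 'tp', 'sp') -- the names are my assumption.
import numpy as np
import jax
from jax.sharding import Mesh

def make_mesh(mesh_dim: str, axis_names=("dp", "fsdp", "tp", "sp")):
    dims = [int(x) for x in mesh_dim.lstrip("!").split(",")]
    devices = np.asarray(jax.devices()).reshape(dims)  # needs prod(dims) devices
    return Mesh(devices, axis_names)

# '1,2,2,1': 2-way fsdp x 2-way tensor parallelism, no sequence parallelism
# '1,1,2,2': 2-way tensor parallelism x 2-way sequence parallelism
print(make_mesh("1,1,2,2").shape)  # {'dp': 1, 'fsdp': 1, 'tp': 2, 'sp': 2}
```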
With further investigation, I found that the attention outputs began to contain NaN values from some middle layer onward, and the final outputs of the network were all NaN:
```
... attn_output shape: (array(1, dtype=int32), array(8192, dtype=int32), array(4096, dtype=int32)),
attn_output: [[[-0.00195352 -0.00402817  0.00310729 ...  0.00402906 -0.00059524 -0.00334746]
  [-0.00195352 -0.00402817  0.00310729 ...  0.00402906 -0.00059524 -0.00334746]
  [-0.00195352 -0.00402817  0.00310729 ...  0.00402906 -0.00059524 -0.00334746]
  ...
  [ nan nan nan ... nan nan nan]
  [ nan nan nan ... nan nan nan]
  [ nan nan nan ... nan nan nan]]]
...
... attn_output shape: (array(1, dtype=int32), array(1, dtype=int32), array(4096, dtype=int32)),
attn_output: [[[nan nan nan ... nan nan nan]]]
...
```
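For anyone trying to reproduce this: the NaN check described above amounts to something like the following (a debugging sketch; `check_nan` is my own helper, not part of the repo). `jax.debug.print` works inside `jit`-compiled and scanned code:

```python
# Debugging sketch: report from inside jit/scan whether a tensor has NaNs.
import jax
import jax.numpy as jnp

def check_nan(tag, x):
    jax.debug.print(
        tag + ": any_nan={nan}, max_abs={mx}",
        nan=jnp.any(jnp.isnan(x)),
        mx=jnp.max(jnp.abs(x)),
    )

# e.g. inside the attention forward pass:
#   check_nan("attn_output", attn_output)

# Alternatively, make JAX raise on the first NaN it produces:
# jax.config.update("jax_debug_nans", True)
```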
- The inference speed was very slow with `mesh_dim='1,1,2,2'`. By counting the printed attention outputs, I found that with `mesh_dim='1,1,2,2'`, `FlaxLLaMABlock` was called far more times than with `mesh_dim='1,2,2,1'` (some back-of-the-envelope factoring follows the table):
| mesh_dim | FlaxLLaMABlock calls |
|---|---|
| 1,2,2,1 | 544 |
| 1,1,2,2 | 65568 |
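One observation on these counts (back-of-the-envelope only; the interpretation is my speculation, not verified against the code): both numbers are exact multiples of 32, the layer count of LLaMA-7B, so the gap seems to come from how often each layer's block is invoked rather than from the layer loop itself:

```python
# Factor the observed call counts over the 32 layers of LLaMA-7B.
N_LAYERS = 32

for mesh_dim, total in [("1,2,2,1", 544), ("1,1,2,2", 65568)]:
    per_layer, rem = divmod(total, N_LAYERS)
    print(f"mesh_dim={mesh_dim}: {per_layer} calls/layer (remainder {rem})")

# -> 17 vs 2049 calls per layer. 2049 = 1 + 2048 would be consistent with
# one prefill pass plus one block call per decoded token, but that is only
# a guess on my part.
```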
Has anyone had similar findings, or could anyone share some insight into these questions?