
Bad performance for three generation modes.

Open shiwanghua opened this issue 8 months ago • 14 comments

○ kb: {'rouge1': 0.10795632401934566, 'rouge2': 0.005502676702905536, 'rougeL': 0.10485264583478511, 'rougeLsum': 0.10448227546441474, 'bert_score_precision': 0.5239911143481731, 'bert_score_recall': 0.5586187492311001, 'bert_score_f1': 0.537801007553935, 'mem_cost': 22861053952}
○ ICL (kb_size=50): {'rouge1': 0.10135071880712321, 'rouge2': 0.06694877154125624, 'rougeL': 0.09862427623499491, 'rougeLsum': 0.09862427623499491, 'bert_score_precision': 0.32884539023041726, 'bert_score_recall': 0.4558024096488953, 'bert_score_f1': 0.3758223417401314, 'mem_cost': 25.775390625}
○ zero-shot: {'rouge1': 0.030773897832609545, 'rouge2': 0.0015176552801414067, 'rougeL': 0.028298936576585366, 'rougeLsum': 0.027872349274998064, 'bert_score_precision': 0.3853674686700106, 'bert_score_recall': 0.4561145968735218, 'bert_score_f1': 0.4070150001347065, 'mem_cost': 30863785984}

Used enron as the train set and synthetic as the eval set.

args.command is "generation".

niter=900, loss=1.565

base model is Meta-Llama-3-8B

the performance of ICL is also very bad ??? kb_size is 50. OOM when kb_size=80.

shiwanghua avatar May 28 '25 12:05 shiwanghua

the performance of ICL is also very bad ???

This is a little bit confusing, have you checked the output quality? ICL should be very good

xidulu avatar May 28 '25 13:05 xidulu

the performance of ICL is also very bad ???

This is a little bit confusing, have you checked the output quality? ICL should be very good

I have checked it. I also used another trained model and got bad performance in kb mode as well:

niter = 20000
train set: synthetic
eval set: synthetic
loss: 1.232
base model is Meta-Llama-3-8B

The parameters are basically at their default values.

--sep_query_head was added during training, but query_head_path was not passed during evaluation because I can't find it.

So now I'm training on the synthetic set with the Meta-Llama-3-8B-Instruct model, without setting the --sep_query_head parameter.

shiwanghua avatar May 29 '25 02:05 shiwanghua

ICL should not involve any trained parameters, since you don't use any learned KV vectors, but just append the selected context to the prompt.

And notice that for KBLaM, if you don't inject any encoded KV vectors, then the model should be exactly the pre-trained model.
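For reference, a minimal sketch of what ICL mode amounts to, mirroring the prompt construction in perform_eval of the eval.py listing further below; build_icl_prompt is a hypothetical helper, not part of the repository:

def build_icl_prompt(test_kb, question, instruction):
    # ICL: concatenate the sampled KB facts as plain text and prepend them to the
    # question; the pre-trained model is used as-is, with no learned KV vectors.
    prompt_strs = "".join(f"{row['key_string']} is {row['description']}; " for row in test_kb)
    return instruction + prompt_strs + question

KB mode instead passes the encoded key/value embeddings via answer_question(..., kb=kb_embedding, ...); with kb=None the model is exactly the pre-trained checkpoint.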

xidulu avatar May 29 '25 02:05 xidulu

Maybe it's because I use the all-MiniLM-L6-v2 embedding model.

shiwanghua avatar May 30 '25 02:05 shiwanghua

train synthetic, eval synthetic, Meta-Llama-3-8B, 20k iterations:

vary kb_size, kb generation:
○ 10: {'rouge1': 0.0659, 'rouge2': 0.0, 'rougeL': 0.0597, 'rougeLsum': 0.0597, 'bert_score_precision': 0.3326, 'bert_score_recall': 0.3724, 'bert_score_f1': 0.3505, 'mem_cost': 18.84}
○ 20: {'rouge1': 0.1062, 'rouge2': 0.0, 'rougeL': 0.0861, 'rougeLsum': 0.0861, 'bert_score_precision': 0.4700, 'bert_score_recall': 0.614, 'bert_score_f1': 0.5252, 'mem_cost': 20.27}
○ 50: {'rouge1': 0.1463, 'rouge2': 0.0197, 'rougeL': 0.1359, 'rougeLsum': 0.1359, 'bert_score_precision': 0.4294, 'bert_score_recall': 0.6425, 'bert_score_f1': 0.497, 'mem_cost': 21.89}
○ 100: {'rouge1': 0.2725, 'rouge2': 0.0776, 'rougeL': 0.2367, 'rougeLsum': 0.2367, 'bert_score_precision': 0.5319, 'bert_score_recall': 0.7293, 'bert_score_f1': 0.6133, 'mem_cost': 22.34}
○ 200: {'rouge1': 0.3907, 'rouge2': 0.1729, 'rougeL': 0.3485, 'rougeLsum': 0.3485, 'bert_score_precision': 0.5582, 'bert_score_recall': 0.7837, 'bert_score_f1': 0.6506, 'mem_cost': 22.35}
○ 500: {'rouge1': 0.4762, 'rouge2': 0.2715, 'rougeL': 0.4317, 'rougeLsum': 0.4317, 'bert_score_precision': 0.5736, 'bert_score_recall': 0.8204, 'bert_score_f1': 0.6736, 'mem_cost': 22.46}
○ 1000: {'rouge1': 0.4919, 'rouge2': 0.2894, 'rougeL': 0.4483, 'rougeLsum': 0.4483, 'bert_score_precision': 0.5733, 'bert_score_recall': 0.8202, 'bert_score_f1': 0.6733, 'mem_cost': 22.63}
○ 2000: {'rouge1': 0.307, 'rouge2': 0.1335, 'rougeL': 0.2757, 'rougeLsum': 0.2757, 'bert_score_precision': 0.5274, 'bert_score_recall': 0.7239, 'bert_score_f1': 0.6085, 'mem_cost': 22.96}
○ 5000: {'rouge1': 0.0415, 'rouge2': 0.0008, 'rougeL': 0.0383, 'rougeLsum': 0.0383, 'bert_score_precision': 0.4228, 'bert_score_recall': 0.4485, 'bert_score_f1': 0.4293, 'mem_cost': 24.85}
○ 10000: {'rouge1': 0.022, 'rouge2': 0.0004, 'rougeL': 0.0215, 'rougeLsum': 0.0215, 'bert_score_precision': 0.2719, 'bert_score_recall': 0.4171, 'bert_score_f1': 0.318, 'mem_cost': 28.23}

ICL:
○ 10: {'rouge1': 0.0551, 'rouge2': 0.0499, 'rougeL': 0.0551, 'rougeLsum': 0.0551, 'bert_score_precision': 0.3110, 'bert_score_recall': 0.5240, 'bert_score_f1': 0.3769, 'mem_cost': 22.10}
○ 20: {'rouge1': 0.0042, 'rouge2': 0.0, 'rougeL': 0.0028, 'rougeLsum': 0.00278, 'bert_score_precision': 0.3282, 'bert_score_recall': 0.3888, 'bert_score_f1': 0.3409, 'mem_cost': 22.29}
○ 50: {'rouge1': 0.0589, 'rouge2': 0.0409, 'rougeL': 0.0587, 'rougeLsum': 0.0587, 'bert_score_precision': 0.3130, 'bert_score_recall': 0.4453, 'bert_score_f1': 0.3577, 'mem_cost': 23.92}
○ 100: OOM

zeroshot:
○ 10: {'rouge1': 0.03, 'rouge2': 0.0, 'rougeL': 0.0252, 'rougeLsum': 0.0252, 'bert_score_precision': 0.4035, 'bert_score_recall': 0.4695, 'bert_score_f1': 0.4267, 'mem_cost': 22.05}
○ 20: {'rouge1': 0.0379, 'rouge2': 0.0, 'rougeL': 0.0299, 'rougeLsum': 0.0299, 'bert_score_precision': 0.4002, 'bert_score_recall': 0.4654, 'bert_score_f1': 0.4195, 'mem_cost': 25.73}
○ 50: {'rouge1': 0.0432, 'rouge2': 0.002, 'rougeL': 0.0419, 'rougeLsum': 0.0419, 'bert_score_precision': 0.3936, 'bert_score_recall': 0.449, 'bert_score_f1': 0.409, 'mem_cost': 25.74}
○ 100: {'rouge1': 0.0253, 'rouge2': 0.0, 'rougeL': 0.0237, 'rougeLsum': 0.0237, 'bert_score_precision': 0.3928, 'bert_score_recall': 0.44, 'bert_score_f1': 0.4049, 'mem_cost': 23.64}
○ 200: {'rouge1': 0.0287, 'rouge2': 0.0007, 'rougeL': 0.0259, 'rougeLsum': 0.0258, 'bert_score_precision': 0.3736, 'bert_score_recall': 0.4463, 'bert_score_f1': 0.3957, 'mem_cost': 23.67}
○ 500: {'rouge1': 0.033, 'rouge2': 0.002, 'rougeL': 0.0301, 'rougeLsum': 0.03, 'bert_score_precision': 0.3621, 'bert_score_recall': 0.4539, 'bert_score_f1': 0.391, 'mem_cost': 27.35}
○ 1000: {'rouge1': 0.0263, 'rouge2': 0.0014, 'rougeL': 0.0246, 'rougeLsum': 0.0245, 'bert_score_precision': 0.3655, 'bert_score_recall': 0.4457, 'bert_score_f1': 0.3901, 'mem_cost': 27.52}

shiwanghua avatar May 30 '25 07:05 shiwanghua

train synthetic, eval synthetic, Meta-Llama-3-8B, 20k iterations (results quoted above)

Hello, have you run eval.py generation kb mode? The output quality is very poor. Have you fixed the bug in eval.py?

Chloe-mxxxxc avatar May 30 '25 08:05 Chloe-mxxxxc

Hello, have you run eval.py generation kb mode? The output quality is very poor. Have you fixed the bug in eval.py?

I have run the five evaluation modes successfully. I don't know where it goes wrong; the results are strange, especially for ICL.

Besides, I implemented a new evaluation mode that can simultaneously compare the string outputs and ROUGE/BERT scores of four generation modes (KB, ICL, zeroshot, original model) on a self-constructed dataset.
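A rough sketch of such a comparison loop, assuming the perform_eval function from the eval.py listing below; compare_modes is a hypothetical wrapper, and the original-model baseline would need its own branch:

def compare_modes(model, tokenizer, kb_retriever, encoder_spec, kb_config, kb_size=200, seed=2025):
    # Run the same sampled KB through each generation mode and collect the scores side by side.
    results = {}
    for mode in ["kb", "icl", "zeroshot"]:
        _, scores = perform_eval(model, tokenizer, kb_retriever, encoder_spec, kb_config,
                                 eval_mode=mode, kb_size=kb_size, seed=seed)
        results[mode] = scores
    return results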

shiwanghua avatar May 30 '25 12:05 shiwanghua

KB mode, evaluating synthetic data:
○ 200 & kb_scale_factor=None: {'rouge1': 0.3907, 'rouge2': 0.1729, 'rougeL': 0.3485, 'rougeLsum': 0.3485, 'bert_score_precision': 0.5582, 'bert_score_recall': 0.7837, 'bert_score_f1': 0.6506, 'mem_cost': 22.35}
○ 200 & kb_scale_factor=100: {'rouge1': 0.2744, 'rouge2': 0.081, 'rougeL': 0.2447, 'rougeLsum': 0.2447, 'bert_score_precision': 0.532, 'bert_score_recall': 0.7295, 'bert_score_f1': 0.6128, 'mem_cost': 19.41} -- Worse
○ 200 & kb_scale_factor=200: {'rouge1': 0.4044, 'rouge2': 0.1786, 'rougeL': 0.3582, 'rougeLsum': 0.3582, 'bert_score_precision': 0.5615, 'bert_score_recall': 0.7947, 'bert_score_f1': 0.6568, 'mem_cost': 18.94} -- Better
○ 200 & kb_scale_factor=200 & max_new_tokens=300: {'rouge1': 0.3886, 'rouge2': 0.171, 'rougeL': 0.34, 'rougeLsum': 0.34, 'bert_score_precision': 0.5596, 'bert_score_recall': 0.7827, 'bert_score_f1': 0.6512, 'mem_cost': 18.94} -- Worse

KB mode, evaluating enron data (generalization dataset):
○ 200: {'rouge1': 0.1109, 'rouge2': 0.0107, 'rougeL': 0.0964, 'rougeLsum': 0.0964, 'bert_score_precision': 0.4666, 'bert_score_recall': 0.5715, 'bert_score_f1': 0.5015, 'mem_cost': 19.42}
○ 200 & kb_scale_factor=100: {'rouge1': 0.117, 'rouge2': 0.0145, 'rougeL': 0.1027, 'rougeLsum': 0.1027, 'bert_score_precision': 0.4678, 'bert_score_recall': 0.5757, 'bert_score_f1': 0.5033, 'mem_cost': 19.24} -- Better

shiwanghua avatar May 30 '25 13:05 shiwanghua

Hello, have you run eval.py generation kb mode? The output quality is very poor. Have you fixed the bug in eval.py?

I have found the reason.

Using Meta-Llama-3-8B-Instruct trained on the synthetic dataset, the performance on the training dataset is good:

○ 21000 iterations, kb_size=200, kb_scale_factor=800: {'rouge1': 0.7618, 'rouge2': 0.5479, 'rougeL': 0.6964, 'rougeLsum': 0.6964, 'bert_score_precision': 0.8801, 'bert_score_recall': 0.8798, 'bert_score_f1': 0.8795, 'time_cost': 307.9}
○ 21000 iterations, kb_size=500, kb_scale_factor=900: {'rouge1': 0.7632, 'rouge2': 0.568, 'rougeL': 0.7099, 'rougeLsum': 0.7099, 'bert_score_precision': 0.8897, 'bert_score_recall': 0.8865, 'bert_score_f1': 0.8876, 'time_cost': 616.58}
○ 21k iterations, kb_size=5k, kb_scale_factor=1k: {'rouge1': 0.5245, 'rouge2': 0.3807, 'rougeL': 0.4874, 'rougeLsum': 0.4874, 'bert_score_precision': 0.7669, 'bert_score_recall': 0.7712, 'bert_score_f1': 0.7684, 'mem_cost': 21.5, 'time_cost': 551}
○ 21k iterations, kb_size=10k, kb_scale_factor=1.1k: {'rouge1': 0.2793, 'rouge2': 0.19, 'rougeL': 0.2628, 'rougeLsum': 0.2628, 'bert_score_precision': 0.6405, 'bert_score_recall': 0.6544, 'bert_score_f1': 0.6467, 'mem_cost': 24.25, 'time_cost': 489}

However, on the enron dataset, generalization is still very poor, even worse than zeroshot:

○ kbsize100,kb_scale_factor900:{'rouge1': 0.1632, 'rouge2': 0.0213, 'rougeL': 0.1422, 'rougeLsum': 0.1422, 'bert_score_precision': 0.6452, 'bert_score_recall': 0.6076, 'bert_score_f1': 0.624, 'mem_cost': 18.91, 'time_cost': 66.64}
○ kbsize200,kb_scale_factor800:{'rouge1': 0.1352, 'rouge2': 0.0113, 'rougeL': 0.1181, 'rougeLsum': 0.1181, 'bert_score_precision': 0.6327, 'bert_score_recall': 0.5857, 'bert_score_f1': 0.6062, 'mem_cost': 18.94, 'time_cost': 134.46}
○ **zeroshot**-kbsize200:{'rouge1': 0.1268, 'rouge2': 0.0148, 'rougeL': 0.1116, 'rougeLsum': 0.1116, 'bert_score_precision': 0.5868, 'bert_score_recall': 0.5783, 'bert_score_f1': 0.5799, 'mem_cost': 18.94, 'time_cost': 352.54}
○ **ICL**-kbsize100: {'rouge1': 0.9445, 'rouge2': 0.9131, 'rougeL': 0.9445, 'rougeLsum': 0.9433, 'bert_score_precision': 0.9261, 'bert_score_recall': 0.9439, 'bert_score_f1': 0.9338, 'mem_cost': 37.69, 'time_cost': 184.79}
○ **ICL**-kbsize120: {'rouge1': 0.9552, 'rouge2': 0.9225, 'rougeL': 0.954, 'rougeLsum': 0.9527, 'bert_score_precision': 0.9354, 'bert_score_recall': 0.9525, 'bert_score_f1': 0.9434, 'mem_cost': 38.87, 'time_cost': 226.91}

shiwanghua avatar Jun 04 '25 06:06 shiwanghua

This is a little bit confusing, have you checked the output quality? ICL should be very good

That is because the base model is Meta-Llama-3-8B instead of Meta-Llama-3-8B-Instruct, which leads to the poor ICL results.

shiwanghua avatar Jun 04 '25 06:06 shiwanghua

Hello, have you run eval.py generation kb mode? The output quality is very poor. Have you fixed the bug in eval.py?

I have found the reason.

Using Meta-Llama-3-8B-Instruct trained on the synthetic dataset, the performance on the training dataset is good, but on the enron dataset generalization is still very poor, even worse than zeroshot (results quoted above).

I have three questions:
1. How did you determine the value of kb_scale_factor?
2. Can you post some GT and PRED results?
3. I want to confirm: when running eval.py with the Meta-Llama-3-8B-Instruct model, args.command == "generation", and --eval_mode == "kb", can it output meaningful information?
Thank you very much.

Chloe-mxxxxc avatar Jun 04 '25 10:06 Chloe-mxxxxc

I have three questions: 1. How did you determine the value of kb_scale_factor? 2. Can you post some GT and PRED results? 3. I want to confirm that when running eval.py with the Meta-Llama-3-8B-Instruct model, args.command == "generation", --eval_mode=="kb", can it output meaningful information? Thank you very much.

1. Vary kb_scale_factor, for example (a loop sketch for this sweep follows at the end of this comment):
○ 21000 iterations, kb_size=500: {'rouge1': 0.7393, 'rouge2': 0.5261, 'rougeL': 0.6796, 'rougeLsum': 0.6796, 'bert_score_precision': 0.8774, 'bert_score_recall': 0.8823, 'bert_score_f1': 0.8793, 'mem_cost': 19.04}
○ 21000 iterations, kb_size=500, scale_factor=400: {'rouge1': 0.7147, 'rouge2': 0.4932, 'rougeL': 0.6542, 'rougeLsum': 0.6542, 'bert_score_precision': 0.8657, 'bert_score_recall': 0.8743, 'bert_score_f1': 0.8695, 'time_cost': 641.25}
○ 21000 iterations, kb_size=500, scale_factor=500: {'rouge1': 0.7372, 'rouge2': 0.5237, 'rougeL': 0.6787, 'rougeLsum': 0.6787, 'bert_score_precision': 0.8738, 'bert_score_recall': 0.8796, 'bert_score_f1': 0.8762, 'time_cost': 633.41}
○ 21000 iterations, kb_size=500, scale_factor=600: {'rouge1': 0.7492, 'rouge2': 0.5489, 'rougeL': 0.6915, 'rougeLsum': 0.6915, 'bert_score_precision': 0.8793, 'bert_score_recall': 0.8838, 'bert_score_f1': 0.8811, 'time_cost': 626.67}
○ 21000 iterations, kb_size=500, scale_factor=700: {'rouge1': 0.7588, 'rouge2': 0.5629, 'rougeL': 0.703, 'rougeLsum': 0.703, 'bert_score_precision': 0.8834, 'bert_score_recall': 0.8858, 'bert_score_f1': 0.8842, 'time_cost': 625.35}
○ 21000 iterations, kb_size=500, scale_factor=800: {'rouge1': 0.7662, 'rouge2': 0.5707, 'rougeL': 0.711, 'rougeLsum': 0.711, 'bert_score_precision': 0.8868, 'bert_score_recall': 0.8868, 'bert_score_f1': 0.8864, 'time_cost': 620.02}
○ 21000 iterations, kb_size=500, scale_factor=900: {'rouge1': 0.7632, 'rouge2': 0.568, 'rougeL': 0.7099, 'rougeLsum': 0.7099, 'bert_score_precision': 0.8897, 'bert_score_recall': 0.8865, 'bert_score_f1': 0.8876, 'time_cost': 616.58}
○ 21000 iterations, kb_size=500, scale_factor=1000: {'rouge1': 0.7679, 'rouge2': 0.5588, 'rougeL': 0.7107, 'rougeLsum': 0.7107, 'bert_score_precision': 0.8879, 'bert_score_recall': 0.8833, 'bert_score_f1': 0.8851, 'time_cost': 613.64}
○ 21000 iterations, kb_size=500, scale_factor=1100: {'rouge1': 0.7637, 'rouge2': 0.5551, 'rougeL': 0.7062, 'rougeLsum': 0.7062, 'bert_score_precision': 0.8854, 'bert_score_recall': 0.8795, 'bert_score_f1': 0.8819, 'time_cost': 610.87}
2. Results when kb_size=500, kb_scale_factor=1000 (these outputs have been truncated by the regular expression in eval.py):
PREDICTION-381: to foster innovation, support entrepreneurs, and drive economic growth.
GT-381: to foster technological innovation, support entrepreneurs, and drive economic growth
PREDICTION-382: to facilitate business meetings, seminars, and corporate events.
GT-382: to facilitate business meetings, seminars, and corporate events
PREDICTION-383: a podcast series that dives into historical events and their lasting impact.
GT-383: a podcast that delves into historical events and their lasting impact
PREDICTION-384: enrich cultural diversity through artistic expressions.
GT-384: enrich the cultural landscape through diverse artistic expressions
PREDICTION-385: ensure fast internet, provide reliable connectivity, and offer easy setup.
GT-385: ensure reliable connectivity, provide fast speeds, and offer easy setup
PREDICTION-386: to promote a healthy and active lifestyle.
GT-386: to promote a healthy and active lifestyle
PREDICTION-387: to provide aid to those in need and support the healthcare system.
GT-387: to provide aid and support to those in need
PREDICTION-388: to provide fitness enthusiasts with stylish and functional workout gear.
GT-388: to provide fitness enthusiasts with stylish and functional workout gear
PREDICTION-389: create relaxing and sustainable living environments, support artisanal craftsmanship, and promote eco-friendly production.
GT-389: create a relaxing ambiance, promote sustainable materials, and support artisanal craftsmanship
PREDICTION-390: to offer a guilt-free coffee experience.
GT-390: to offer a guilt-free coffee experience
PREDICTION-391: to preserve and maintain architectural heritage.
GT-391: to maintain the architectural heritage of communities
PREDICTION-392: to create a more inclusive and equitable society.
GT-392: to create a more just and inclusive society
PREDICTION-393: to preserve biodiversity and educate the public about wildlife conservation.
GT-393: to conserve biodiversity and educate the public on wildlife preservation
PREDICTION-394: clean water, reduce contaminants, and improve health.
GT-394: clean up water quality, reduce contaminants, and promote health
PREDICTION-395: a light and soft feathered bird with gentle movements.
GT-395: a sparrow with soft, velvety feathers
PREDICTION-396: provide high-quality fitness classes, support physical fitness, and foster a supportive fitness community.
GT-396: provide high-quality fitness instruction, support physical health, and foster a community of fitness enthusiasts
PREDICTION-397: reduce screen time, promote relaxation, and provide a compact device.
GT-397: reduce screen time, promote relaxation, and provide a space for unplugging
PREDICTION-398: to highlight the significance of hard work and determination.
GT-398: to highlight the significance of hard work and determination
PREDICTION-399: a underwater display system for augmented reality.
GT-399: a holographic display system for underwater visuals
PREDICTION-400: a travel agency specializing in eco-friendly and sustainable tourism.
GT-400: a travel agency specializing in eco-friendly and sustainable tourism
3. You should train the model into a KBLaM-type model with its own format; the generalization results are bad.
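A minimal sketch of the kb_scale_factor sweep from point 1, in the style of the loops in eval_generate of the eval.py listing below (names assumed from that script):

# Sweep kb_scale_factor for a fixed kb_size and print the scores for each setting.
for sf in [400, 500, 600, 700, 800, 900, 1000, 1100]:
    kb_config.kb_scale_factor = sf
    _, scores = perform_eval(model, tokenizer, kb_retriever, encoder_spec, kb_config,
                             eval_mode="kb", kb_size=500, seed=seed)
    print(f"kb_scale_factor={sf}: {scores}")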

shiwanghua avatar Jun 04 '25 11:06 shiwanghua

The results of the multi_entites test are also very bad:

○ multi_entites=1,scale_factor=900:{'rouge1': 0.7632, 'rouge2': 0.568, 'rougeL': 0.7099, 'rougeLsum': 0.7099, 'bert_score_precision': 0.8897, 'bert_score_recall': 0.8865, 'bert_score_f1': 0.8876, 'time_cost': 616.58}
○ multi_entites=2,scale_factor=900:{'rouge1': 0.3554, 'rouge2': 0.1774, 'rougeL': 0.3034, 'rougeLsum': 0.3034, 'bert_score_precision': 0.671, 'bert_score_recall': 0.5647, 'bert_score_f1': 0.6124, 'mem_cost': 19.04, 'time_cost': 581.56}
○ multi_entites=3,scale_factor=900:{'rouge1': 0.1891, 'rouge2': 0.0903, 'rougeL': 0.1596, 'rougeLsum': 0.1596, 'bert_score_precision': 0.4821, 'bert_score_recall': 0.3776, 'bert_score_f1': 0.4227, 'mem_cost': 19.04, 'time_cost': 537.01}
○ multi_entites=5,scale_factor=900:{'rouge1': 0.0476, 'rouge2': 0.0162, 'rougeL': 0.0396, 'rougeLsum': 0.0396, 'bert_score_precision': 0.2022, 'bert_score_recall': 0.1545, 'bert_score_f1': 0.1747, 'mem_cost': 19.12, 'time_cost': 468.62}
○ multi_entites=8,scale_factor=900:{'rouge1': 0.048, 'rouge2': 0.0075, 'rougeL': 0.0399, 'rougeLsum': 0.0399, 'bert_score_precision': 0.383, 'bert_score_recall': 0.2958, 'bert_score_f1': 0.3334, 'mem_cost': 19.32, 'time_cost': 528.36}
○ multi_entites=10,scale_factor=900:{'rouge1': 0.0567, 'rouge2': 0.006, 'rougeL': 0.0477, 'rougeLsum': 0.0477, 'bert_score_precision': 0.5758, 'bert_score_recall': 0.4374, 'bert_score_f1': 0.4966, 'mem_cost': 19.51, 'time_cost': 605.59}

shiwanghua avatar Jun 04 '25 11:06 shiwanghua


"""Script for evaluating KB models"""

import argparse
import json
import os
import re
import shutil
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import time
import datetime
import torch
import transformers
from tqdm import tqdm
from transformers import AutoTokenizer, logging, AutoModelForCausalLM, pipeline

from kblam.kb_encoder import KBEncoder
from kblam.models.kblam_config import KBLaMConfig
from kblam.models.llama3_model import KblamLlamaForCausalLM
from kblam.models.phi3_model import KBLaMPhi3ForCausalLM
from kblam.utils.data_utils import aug_row, generate_multi_entity_qa  # aug_row is needed when --fancy_question is set
from kblam.utils.eval_utils import (
    instruction_prompts,
    instruction_prompts_multi_entities,
    zero_shot_prompt,
    zero_shot_prompt_multi_entities,
    _format_Q_llama,
    _format_Q_phi3,
    model_prune_format_mapping,
    answer_question,
    softmax,
)
from kblam.utils.train_utils import get_kb_embd

import nltk
# nltk.download("wordnet", download_dir='/home/ubisec/nltk_data')
logging.set_verbosity_warning()

from rouge_score import rouge_scorer
rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL', 'rougeLsum'], use_stemmer=True)
import evaluate
# rouge = evaluate.load("rouge")

# bert_score = evaluate.load("bertscore")
# from bert_score import score as bert_score
from bert_score import BERTScorer
bert_score = BERTScorer(
    model_type="/models/deberta-xlarge-mnli",  # local model directory path
    num_layers=40,  # recommended layers: 'microsoft/deberta-xlarge-mnli': 40, 'microsoft/deberta-xlarge': 18, 'microsoft/deberta-large-mnli': 18, 'microsoft/deberta-large': 16, 'microsoft/deberta-base': 9, 'microsoft/deberta-base-mnli': 9
    device='cuda:1',
    lang="en",
)


def get_evaluate_rouge(scorer, predictions, references):
    rouge1, rouge2, rougel, rougelsum = [], [], [], []
    for pred, ref in zip(predictions, references):
        score = scorer.score(ref, pred)
        rouge1.append(score['rouge1'].fmeasure)
        rouge2.append(score['rouge2'].fmeasure)
        rougel.append(score['rougeL'].fmeasure)
        rougelsum.append(score['rougeLsum'].fmeasure)
    return {'rouge1': np.mean(rouge1), 'rouge2':  np.mean(rouge2), 'rougeL':  np.mean(rougel), 'rougeLsum':  np.mean(rougelsum)}

class KBRetriever:
    def __init__(
        self,
        encoder: KBEncoder,
        dataset: List[Dict],
        precomputed_embed_keys_path: Optional[str] = None,
        precomputed_embed_values_path: Optional[str] = None,
    ):
        self.encoder = encoder
        self.dataset = dataset
        if precomputed_embed_keys_path is not None:
            self.key_embds = np.load(precomputed_embed_keys_path).astype("float32")
        else:
            self.key_embds = None
        if precomputed_embed_values_path is not None:
            self.value_embds = np.load(precomputed_embed_values_path).astype("float32")
        else:
            self.value_embds = None

        if precomputed_embed_keys_path is not None:
            assert len(dataset) == len(self.key_embds)

    def _use_cached_embd(self):
        if self.key_embds is not None and self.value_embds is not None:
            return True
        else:
            return False

    def get_key_embeddings(self, batch_indices):
        if self._use_cached_embd():
            return get_kb_embd(
                self.encoder,
                batch_indices,
                precomputed_embd=(self.key_embds, self.value_embds),
            )
        else:
            return get_kb_embd(self.encoder, batch_indices, kb_dict=self.dataset)


def perform_eval(
    model: KBLaMPhi3ForCausalLM | KblamLlamaForCausalLM,
    tokenizer: transformers.PreTrainedTokenizer,
    kb_retriever: KBRetriever,
    encoder_model_spec: str,
    kb_config: KBLaMConfig,
    eval_mode: str = "kb",
    kb_size: int = 250,
    seed: int = 1,
    topk_size: int = -1,
    multi_entites: int = -1,
    remove_sorry: bool = False,
):
    np.random.seed(seed)
    kb_idx = np.random.randint(0, len(kb_retriever.dataset), kb_size)
    test_kb = [kb_retriever.dataset[idx] for idx in kb_idx]
    kb_embedding = ()
    key_str = [row["key_string"] for row in test_kb]
    value_str = [row["description"] for row in test_kb]
    prompt_strs = ""
    for k, v in zip(key_str, value_str):
        prompt_strs += f"{k} is {v}; "
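    # prompt_strs concatenates every sampled fact as plain text; it is prepended to the
    # question only in ICL mode, while KB mode passes the encoded embeddings instead.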

    kb_embedding = kb_retriever.get_key_embeddings(kb_idx)

    model_outputs = []
    answers = []
    full_outputs = []
    subset_size = min(
        400, len(test_kb)
    )  # Regardless of KB size, test at most 400 questions; otherwise it will be too slow
    for row in tqdm(test_kb[:subset_size]):
        if multi_entites == -1:
            Q = row["Q"]
            answer = row["A"]
        else:
            kb_subset_idx = np.random.randint(0, len(test_kb), multi_entites)
            Q, answer = generate_multi_entity_qa(
                [test_kb[i]["name"] for i in kb_subset_idx],
                [test_kb[i]["description_type"] for i in kb_subset_idx],
                [test_kb[i]["description"] for i in kb_subset_idx],
            )

        if eval_mode == "kb":
            model_output = answer_question(
                tokenizer,
                model,
                Q,
                kb=kb_embedding,
                # topk_size=topk_size,
                kb_config=kb_config,
            ).split(Q)[1]
        elif eval_mode == "icl":
            if multi_entites != -1:
                ins_prompt = instruction_prompts_multi_entities
            else:
                ins_prompt = instruction_prompts
            model_output = answer_question(
                tokenizer,
                model,
                ins_prompt + prompt_strs + Q,
                kb=None,
                kb_config=kb_config,
            ).split(Q)[1]
        elif eval_mode == "zeroshot":
            if multi_entites != -1:
                ins_prompt = zero_shot_prompt_multi_entities
            else:
                ins_prompt = zero_shot_prompt
            model_output = answer_question(
                tokenizer, model, ins_prompt + Q, kb=None, kb_config=kb_config
            ).split(Q)[1]
        # print(model_output)
        if remove_sorry:
            if "sorry" in model_output:
                continue
        full_outputs.append((model_output, answer))
        if multi_entites == -1:
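            # Prune the output to the text after "The {property} of {name} is ...",
            # so ROUGE/BERTScore compare only the descriptions.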
            pattern = r'The\s+\w+\s+of\s+[^"]+\s+is\s+(.+)'
            match = re.search(pattern, model_output)
            answers.append(row["description"])
            if match:
                model_output = match.group(1)
        else:
            pattern = r"(?:is|are) (.*?)(?:\.|;)"
            matches = re.findall(pattern, model_output)
            model_output = "; ".join(matches)
            answers.append(";".join(re.findall(r"(?:is|are) (.*?);", answer)))
        model_outputs.append(model_output)

    print(f"KB size: {kb_size}, mode: {eval_mode}")
    for i,(pred, gt) in enumerate(zip(model_outputs, answers)):
        print(f"PREDICTION-{i+1}: {pred}")
        print(f"GT-{i+1}: {gt}")
    # rouge_scores = rouge.compute(predictions=model_outputs, references=answers)
    rouge_scores = get_evaluate_rouge(rouge, model_outputs, answers)
    print(f'rouge_scores={rouge_scores}')

    results_dict = {k: round(float(v),4) for k, v in rouge_scores.items()}

    # bertscore = bert_score.compute(
    #     predictions=model_outputs,
    #     references=answers,
    #     lang="en",
    #     model_type="microsoft/deberta-xlarge-mnli",
    # )
    # P, R, F1 = bert_score(
    #     cands=model_outputs,
    #     refs=answers,
    #     model_type="/models/deberta-xlarge-mnli",
    #     lang="en",
    #     verbose=True,  # show progress
    #     device="cuda:1" if torch.cuda.is_available() else "cpu",  # select device automatically
    # )

    P, R, F1 = bert_score.score(model_outputs, answers, batch_size=10, verbose=True)
    bertscore = {
        "precision": P.tolist(),
        "recall": R.tolist(),
        "f1": F1.tolist(),
    }
    
    for i, (k, v) in enumerate(bertscore.items()):
        if isinstance(v, list):
            results_dict[f"bert_score_{k}"] = round(float(np.mean(v)),4)
            print(f'bertscore-{i+1}', k, np.mean(v), v)
        else:
            print(f'bertscore-{i+1}', k, v)
    results = ""
    for i, (a, A) in enumerate(full_outputs):
        results += f"Model output {i+1}: {a}\nTrue answer: {A}\n-------\n"
    if eval_mode == "kb":
        eval_mode = encoder_model_spec + eval_mode

    return results, results_dict


def perform_eval_refusal(
    model: KBLaMPhi3ForCausalLM | KblamLlamaForCausalLM,
    tokenizer: transformers.PreTrainedTokenizer,
    kb_retriever: KBRetriever,
    kb_config: Optional[KBLaMConfig] = None,
    eval_mode: str = "kb",
    kb_size: int = 200,
    seed: int = 1,
    outlier_ratio: float = 0.2,  # fraction of questions whose answers are not in the KB and should be refused
    topk_size: int = -1,
    question_size: int = 100,
):
    instruction_prompts = (
        'Please answer questions based on the given text with format: "The {property} of {name} is {description}",'
        ' if relevant information cannot be found in the text, please respond "I am sorry I cannot find relevant information in the KB".'
    )
    zero_shot_prompt = """
    Please answer the question in a very compact manner with format: The {property} of {name} is {description}
    """

    np.random.seed(seed)
    kb_idx = np.random.randint(0, len(kb_retriever.dataset), kb_size)
    test_kb = [kb_retriever.dataset[idx] for idx in kb_idx]
    key_str = [row["key_string"] for row in test_kb]
    value_str = [row["description"] for row in test_kb]
    prompt_strs = ""
    for k, v in zip(key_str, value_str):
        prompt_strs += f"{k} is {v}; "

    kb_embedding = kb_retriever.get_key_embeddings(kb_idx)

    model_outputs = []
    answers = []
    # answer_question
    outlier_idx = np.arange(len(kb_retriever.dataset))
    outlier_idx = outlier_idx[~np.isin(outlier_idx, kb_idx)]  # keep only indices that are not in kb_idx
    np.random.shuffle(outlier_idx)
    question_size = min(kb_size, question_size)
    outlier_idx = outlier_idx[: int(question_size * outlier_ratio)]
    test_kb = test_kb[: int(question_size * (1 - outlier_ratio))] + [
        kb_retriever.dataset[idx] for idx in outlier_idx
    ]
    change_point = int(question_size * (1 - outlier_ratio))  # the first change_point questions should be answered, not refused
    for i, row in tqdm(enumerate(test_kb)):
        Q = row["Q"]
        if eval_mode == "kb":
            model_output = answer_question(
                tokenizer,
                model,
                Q,
                kb=kb_embedding,
                # topk_size=topk_size,
                kb_config=kb_config,
            ).split(Q)[1]

        elif eval_mode == "icl":
            model_output = answer_question(
                tokenizer,
                model,
                instruction_prompts + prompt_strs + Q,
                kb=None,
                kb_config=kb_config,
            ).split(Q)[1]
        elif eval_mode == "zeroshot":
            model_output = answer_question(
                tokenizer,
                model,
                zero_shot_prompt + Q,
                kb=None,
                kb_config=kb_config,
            ).split(Q)[1]
        model_outputs.append(model_output)
        if i < change_point:
            answers.append(row["description"])
        else:
            answers.append("cannot find relevant information in the KB")
    true_label = [0] * change_point + [1] * int(question_size * outlier_ratio)
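    # A response counts as a refusal if it contains "sorry" or matches "cannot/can't find ... information".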
    pattern = r"(cannot|can't)\s+find\s+.*\s+information"
    prediction = [int("sorry" in model_output or bool(re.search(pattern, model_output))) for model_output in model_outputs]
    print(f"KB size: {kb_size}, mode: {eval_mode}, question_size: {question_size}, outlier ratio: {outlier_ratio}")
    results = ""
    for a, A in zip(model_outputs, answers):
        results += f"Model output: {a}\nTrue answer: {A}\n-------\n"
    return results, np.array([prediction, true_label])


parser = argparse.ArgumentParser(description="Evaluation script")

# Add arguments that will be shared across all subcommands
parent_parser = argparse.ArgumentParser(add_help=False)

parent_parser.add_argument(
    "--dataset_dir", type=str, help="Directory containing the dataset", default='../datasets'
)
parent_parser.add_argument(
    "--encoder_dir", type=str, help="Directory containing the encoder model", default='/codes/KBLaM/experiments/output-250530/stage1_lr_0.0001_KBTokenLayerFreq3_UseOutlier1_SepQueryHead_UseDataAug__KeyFromkey_all-MiniLM-L6-v2_synthetic_llama3_step_21000_encoder/encoder.pt'
)
parent_parser.add_argument(
    "--encoder_spec",
    type=str,
    default="/models/all-MiniLM-L6-v2",
    help="Specification for the encoder model",
)
parent_parser.add_argument(
    "--fancy_instruction",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Whether to use fancy instructions",
)
parent_parser.add_argument(
    "--kb_layer_frequency",
    type=int,
    default=3,
    help="Frequency of knowledge base layers",
)
parent_parser.add_argument(
    "--kb_scale_factor",
    type=int,
    default=None,
    help="Scaling factor for knowledge base",
)
parent_parser.add_argument(
    "--kb_size", type=int, default=200, help="Size of the knowledge base"
)
parent_parser.add_argument(  # used to load the tokenizer
    "--llm_base_dir",
    type=str,
    default='/models/Meta-Llama-3-8B-Instruct',
    help="llm to load, can be HF location or local directory",
)
parent_parser.add_argument(
    "--llm_type",
    type=str,
    default="llama3",
    choices=["llama3", "phi3"],
    help="Type of language model to use",
)
parent_parser.add_argument(
    "--model_dir", type=str, default='/codes/KBLaM/experiments/output-250530/stage1_lr_0.0001_KBTokenLayerFreq3_UseOutlier1_SepQueryHead_UseDataAug__KeyFromkey_all-MiniLM-L6-v2_synthetic_llama3_step_21000', help="Directory containing the model"
)
parent_parser.add_argument("--save_dir", type=str, default='eval-output-250526', help="Directory to save outputs")
parent_parser.add_argument("--seed", type=int, default=2025, help="Random seed for reproducibility")
parent_parser.add_argument(
    # "--test_dataset", type=str, default='synthetic.json', help="Source of test KB (assumes KV pair format)"
    "--test_dataset", type=str, default='enron.json', help="Source of test KB (assumes KV pair format)"
)
parent_parser.add_argument(
    "--precomputed_embed_keys_path", type=str, 
    # default='/codes/KBLaM/datasets/synthetic_all-MiniLM-L6-v2_embd_key.npy',
    default='/codes/KBLaM/datasets/enron_all-MiniLM-L6-v2_embd_key.npy',
    help="Path to precomputed key embeddings"
)
parent_parser.add_argument(
    "--precomputed_embed_values_path",
    type=str,
    # default='/codes/KBLaM/datasets/synthetic_all-MiniLM-L6-v2_embd_value.npy',
    default='/codes/KBLaM/datasets/enron_all-MiniLM-L6-v2_embd_value.npy',
    help="Path to precomputed value embeddings",
)
parent_parser.add_argument(
    "--query_head_path", type=str, default=None, help="Path to load KB head from"
)

# Create subparsers
subparsers = parser.add_subparsers(dest="command", required=True)

# Create the parser for the generation command
gen_parser = subparsers.add_parser(
    "generation", parents=[parent_parser], help="Evaluate generation"
)
gen_parser.add_argument(
    "--eval_mode",
    type=str,
    choices=["kb", "icl", "zeroshot"],
    default="kb",
    help="Evaluation mode: knowledge base, in-context learning, or zero-shot",
)
gen_parser.add_argument(
    "--exp_config_name",
    type=str,
    default="generation_results-250526",
    help="Name of the experiment configuration",
)
gen_parser.add_argument(
    "--kb_token_layer_frequency",
    type=int,
    default=None,
    help="Frequency of knowledge base token layers",
)
gen_parser.add_argument(
    "--multi_entites",
    type=int,
    default=-1,
    help="Number of entities to process (-1 for unlimited)",
)
gen_parser.add_argument(
    "--no_outlier",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Use checkpoints trained without outliers",
)
gen_parser.add_argument(
    "--remove_sorry",
    action=argparse.BooleanOptionalAction,
    default=False,
    help='Filter out "sorry" answers from the output',
)
gen_parser.add_argument(
    "--topk_size", type=int, default=-1, help="Size of top-k selection (-1 for all)"
)


# Create the parser for the accuracy command
acc_parser = subparsers.add_parser(
    "accuracy", parents=[parent_parser], help="Evaluate accuracy"
)

acc_parser.add_argument(
    "--attn_save_dir", type=str, default="accuracy_results_attn_save_dir_250526", help="Directory to save attention masks"
)
acc_parser.add_argument(
    "--exp_config_name",
    type=str,
    default="accuracy_results-250526",
    help="Name of the experiment configuration",
)
acc_parser.add_argument(
    "--fancy_question",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Enable fancy question format",
)
acc_parser.add_argument(
    "--log_save_dir", type=str, default="accuracy_results_log_250526", help="Directory to save accuracy results"
)
acc_parser.add_argument(
    "--test_batch_size", type=int, default=50, help="Batch size for testing"
)
acc_parser.add_argument(
    "--use_shift_match",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Enable shift matching",
)

# Create the parser for the accuracy eval
acc_results_parser = subparsers.add_parser(
    "acc_results", parents=[acc_parser], help="run accuracy eval", add_help=False
)


# Create the parser for the refusal command
ref_parser = subparsers.add_parser(
    "refusal", parents=[parent_parser], help="Evaluate refusal"
)
ref_parser.add_argument(
    "--eval_mode",
    type=str,
    choices=["kb", "icl", "zeroshot"],
    default="kb",
    help="Evaluation mode: knowledge base, in-context learning, or zero-shot",
)
ref_parser.add_argument(
    "--exp_config_name",
    type=str,
    default="refusal_results-250526",
    help="Name of the experiment configuration",
)
ref_parser.add_argument(
    "--kb_token_layer_frequency",
    type=int,
    default=None,
    help="Frequency of knowledge base token layers",
)
ref_parser.add_argument(
    "--multi_entites",
    type=int,
    default=-1,
    help="Number of entities to process (-1 for unlimited)",
)
ref_parser.add_argument(
    "--no_outlier",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Use checkpoints trained without outliers",
)
ref_parser.add_argument(
    "--remove_sorry",
    action=argparse.BooleanOptionalAction,
    default=False,
    help='Filter out "sorry" answers from the output',
)
ref_parser.add_argument(
    "--topk_size", type=int, default=-1, help="Size of top-k selection (-1 for all)"
)

# Create the parser for the standard command
basic_parser = subparsers.add_parser(
    "standard", parents=[parent_parser], help="Evaluate basic performance"
)
basic_parser.add_argument(
    "--attn_summary_save_dir",
    type=str,
    default="standard_summary_save_dir_250526",
    help="Directory to save attention masks",
)
basic_parser.add_argument(
    "--eval_mode",
    type=str,
    choices=["kb", "icl", "zeroshot"],
    default="kb",
    help="Evaluation mode: knowledge base, in-context learning, or zero-shot",
)
basic_parser.add_argument(
    "--exp_config_name",
    type=str,
    default="standard_results-250526",
    help="Name of the experiment configuration",
)
basic_parser.add_argument(
    "--exp_config_str", type=str, default='attention_file', help="Experiment configuration string"
)
basic_parser.add_argument(
    "--kb_token_layer_frequency",
    type=int,
    default=None,
    help="Frequency of knowledge base token layers",
)
basic_parser.add_argument(
    "--no_outlier",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Use checkpoints trained without outliers",
)
basic_parser.add_argument(
    "--sample_size", default=5, type=int, help="Number of samples to process"
)
basic_parser.add_argument(
    "--subset_size", default=100, type=int, help="Size of the data subset to use"
)
basic_parser.add_argument(
    "--topk_size", type=int, default=-1, help="Size of top-k selection (-1 for all)"
)


compare_parser = subparsers.add_parser(
    "compare", parents=[parent_parser], help="compare texts"
)

def eval_generate():
    """Evaluate generation using KB"""
    args = parser.parse_args()

    dataset_dir = args.dataset_dir
    encoder_model_spec = args.encoder_spec
    encoder_path = args.encoder_dir
    eval_mode = args.eval_mode
    exp_config = args.exp_config_name
    kb_layer_frequency = args.kb_layer_frequency
    kb_scale_factor = args.kb_scale_factor
    kb_size = args.kb_size
    llm_base_dir = args.llm_base_dir
    llm_type = args.llm_type
    model_path = args.model_dir
    seed = args.seed
    test_dataset = args.test_dataset
    query_head_path = args.query_head_path
    precomputed_embed_keys_path = args.precomputed_embed_keys_path
    precomputed_embed_values_path = args.precomputed_embed_values_path

    dataset = json.load(open(os.path.join(dataset_dir, test_dataset)))

    tokenizer, encoder, model, kb_config = _prepare_models(
        encoder_model_spec,
        encoder_path,
        llm_type,
        llm_base_dir,
        model_path,
        query_head_path,
        kb_layer_frequency,
        kb_scale_factor,
    )

    kb_retriever = KBRetriever(
        encoder,
        dataset,
        precomputed_embed_keys_path=precomputed_embed_keys_path,
        precomputed_embed_values_path=precomputed_embed_values_path,
    )
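    # Sweep eval_mode x kb_size x kb_scale_factor below; mem_cost is peak reserved GPU
    # memory in GiB and time_cost the wall-clock seconds for one configuration.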
    
    for em in ["zeroshot"]: # "kb", "icl", "zeroshot"
        eval_mode = em
        for ks in [100, 200, 500, 1000, 2000, 5000, 10000]: # 50, 100, 500, 1000, .
            kb_size = ks
            for sf in [None, 1000]: # 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1500, 2000
                if ks!=500 and sf==1000:
                    continue
                kb_config.kb_scale_factor = sf
                bt1 = time.time()
                gen_results, score_results = perform_eval(
                    model,
                    tokenizer,
                    kb_retriever,
                    encoder_model_spec,
                    kb_config,
                    eval_mode,
                    seed=seed,
                    kb_size=kb_size,
                    topk_size=args.topk_size,
                    multi_entites=args.multi_entites,
                )
                mem_cost = round(torch.cuda.max_memory_reserved("cuda:1")/1024**3,2)
                score_results["mem_cost"] = mem_cost

                (Path(args.save_dir) / exp_config).mkdir(exist_ok=True, parents=True)  # creating this directory seems unnecessary
                write_to_json(score_results, Path(args.save_dir) / f"{exp_config}.json")
                with open(os.path.join(args.save_dir, exp_config + ".txt"), "w") as text_file:
                    text_file.write(gen_results)
                et1 = time.time()
                score_results["time_cost"] = round(et1-bt1, 2)
                print(f'eval_mode={eval_mode}, kb_size={kb_size}, kb_scale_factor={kb_config.kb_scale_factor}, score_results={score_results}')
    for em in ["icl"]: # "kb", "icl", "zeroshot"
        eval_mode = em
        for ks in [20, 50, 100, 120, 150, 200, 250]: # 50, 100, 500, 1000, .
            kb_size = ks
            for sf in [None]: # 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1500, 2000
                kb_config.kb_scale_factor = sf
                bt1 = time.time()
                gen_results, score_results = perform_eval(
                    model,
                    tokenizer,
                    kb_retriever,
                    encoder_model_spec,
                    kb_config,
                    eval_mode,
                    seed=seed,
                    kb_size=kb_size,
                    topk_size=args.topk_size,
                    multi_entites=args.multi_entites,
                )
                mem_cost = round(torch.cuda.max_memory_reserved("cuda:1")/1024**3,2)
                score_results["mem_cost"] = mem_cost

                (Path(args.save_dir) / exp_config).mkdir(exist_ok=True, parents=True)  # creating this directory seems unnecessary
                write_to_json(score_results, Path(args.save_dir) / f"{exp_config}.json")
                with open(os.path.join(args.save_dir, exp_config + ".txt"), "w") as text_file:
                    text_file.write(gen_results)
                et1 = time.time()
                score_results["time_cost"] = round(et1-bt1, 2)
                print(f'eval_mode={eval_mode}, kb_size={kb_size}, kb_scale_factor={kb_config.kb_scale_factor}, score_results={score_results}')


def _prepare_models(
    encoder_spec,
    encoder_path,
    llm_type,
    llm_base_dir,
    model_path,
    query_head_path,
    kb_layer_frequency,
    kb_scale_factor,
):
    tokenizer = AutoTokenizer.from_pretrained(
        llm_base_dir, trust_remote_code=True, padding_side="left"
    )
    tokenizer.pad_token = "^"

    if llm_type == "llama3":
        if query_head_path:
            model = KblamLlamaForCausalLM.from_pretrained(
                model_path,
                device_map="cuda:1",
                torch_dtype="auto",
                trust_remote_code=True,
            )
            model.load_query_head(query_head_path)
        else:
            model = KblamLlamaForCausalLM.from_pretrained(
                model_path,
                device_map="cuda:1",
                torch_dtype="auto",
                trust_remote_code=True,
            )
    else:
        model = KBLaMPhi3ForCausalLM.from_pretrained(
            model_path,
            device_map="cuda:1",
            torch_dtype="auto",
            trust_remote_code=True,
        )
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    model.generation_config.eos_token_id = tokenizer.eos_token_id
    model.eval()

    # config = model.config.to_dict()
    kb_config = KBLaMConfig(
        sep_query_head=True,
        kb_layer_frequency=kb_layer_frequency,
        kb_scale_factor=kb_scale_factor,
    )
    # config.update(kb_config.to_dict())
    # new_config = KBLaMConfig(**config)
    # model.config = new_config

    encoder = KBEncoder(
        encoder_name=encoder_spec.upper(),
        projector_type="linear",
        endpoint_url="",
        out_dim=model.config.hidden_size
        * (model.config.num_hidden_layers // kb_layer_frequency + 1),
        frozen_base_model=True,
        projector_kwargs={"mlp_depth": 1, "mlp_hidden_dim": 512},
        device=torch.device("cuda:1"),
    )

    encoder.load_state_dict(torch.load(encoder_path))
    return tokenizer, encoder, model, kb_config


def eval_accuracy(
    tokenizer,
    kb_retriever,
    model,
    dataset,
    exp_config,
    fancy_question,
    kb_config,
    kb_size,
    llm_type,
    test_batch_size,
    save_dir,
    attn_save_dir,
):
    """Evaluate accuracy using KB"""

    if kb_size == len(dataset):
        dataset_subset_idx = range(len(dataset))
    elif kb_size > len(dataset):
        raise IndexError(
            f"The KB size {kb_size} is greater than the dataset size {len(dataset)}"
        )
    else:
        dataset_subset_idx = np.random.choice(len(dataset), kb_size, replace=False)

    dataset_subset = [dataset[i] for i in dataset_subset_idx]

    kb_embedding_real = kb_retriever.get_key_embeddings(dataset_subset_idx)

    format_func_map = {"llama3": _format_Q_llama, "phi3": _format_Q_phi3}

    if not fancy_question:
        input_strs_gen = (dataset_subset[i]["Q"] for i in range(test_batch_size))
    else:
        input_strs_gen = (aug_row(dataset_subset[i]) for i in range(test_batch_size))
    input_strs = [format_func_map[llm_type](ex) for ex in input_strs_gen]

    tokenizer_output = tokenizer(input_strs, return_tensors="pt", padding=True).to(
        "cuda:1"
    )
    input_ids, attention_masks = (
        tokenizer_output["input_ids"],
        tokenizer_output["attention_mask"],
    )
    
    os.makedirs(attn_save_dir, exist_ok=True)

    with torch.autograd.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_masks,
            kb_kvs=kb_embedding_real,
            max_new_tokens=60,
            tokenizer=tokenizer,
            output_attentions=True,
            save_attention_weights=True,
            kb_config=kb_config,
            attention_save_loc=attn_save_dir,
            attention_file_base_name=exp_config,
        )
        outputs = tokenizer.batch_decode(outputs.squeeze(), skip_special_tokens=False)

    save_path = Path(save_dir)
    save_path.mkdir(exist_ok=True, parents=True)

    with open(save_path / f"{exp_config}_acc.txt", "w+") as text_file:
        for output in outputs:
            output_string = output.strip("^")
            text_file.write(f"{str(output_string)}\n")

    accs = []
    with torch.autograd.no_grad():
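        # For each KB-attention layer, take the KB entry with the highest summed attention
        # as the retrieval prediction and compare it against the question's own index.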
        for idx in range(0, 32, kb_config.kb_layer_frequency):
            weight = np.load(os.path.join(attn_save_dir, f"{exp_config}_{idx}.npy"))
            weight = weight[..., :kb_size]
            label = np.arange(test_batch_size)
            weight = weight.reshape(test_batch_size, -1, kb_size)
            acc = (weight.sum(1).argmax(1) == label).mean()
            top_5_predictions = torch.topk(torch.from_numpy(weight.sum(1)), 5, dim=1)[1]
            top_5_acc = (top_5_predictions.numpy() == label[:, None]).any(1).mean()
            print(f"layer{idx} ACC & TOP 5 ACC: {idx} {(acc, top_5_acc)}")
            # print(f"min: {np.min(weight)}  max: {np.max(weight)}\n")
            accs.append(
                {
                    "idx": idx,
                    "acc": float(acc),
                    "top5acc": float(top_5_acc),
                }
            )

    np.save(
        save_path / f"{exp_config}_acc.npy",
        np.array([(a["acc"], a["top5acc"]) for a in accs]),
    )

    return accs


def eval_accuracy_cli():
    """Evaluate accuracy using KB"""
    args = parser.parse_args()

    dataset_dir = args.dataset_dir
    encoder_path = args.encoder_dir
    encoder_spec = args.encoder_spec
    exp_config = args.exp_config_name
    fancy_question = args.fancy_question
    kb_layer_frequency = args.kb_layer_frequency
    kb_scale_factor = args.kb_scale_factor
    kb_size = args.kb_size
    llm_base_dir = args.llm_base_dir
    llm_type = args.llm_type
    model_path = args.model_dir
    test_batch_size = args.test_batch_size
    test_dataset = args.test_dataset
    precomputed_embed_keys_path = args.precomputed_embed_keys_path
    precomputed_embed_values_path = args.precomputed_embed_values_path

    query_head_path = args.query_head_path
    tokenizer, encoder, model, kb_config = _prepare_models(
        encoder_spec,
        encoder_path,
        llm_type,
        llm_base_dir,
        model_path,
        query_head_path,
        kb_layer_frequency,
        kb_scale_factor,
    )
    dataset = json.load(open(os.path.join(dataset_dir, test_dataset)))

    kb_retriever = KBRetriever(
        encoder,
        dataset,
        precomputed_embed_keys_path=precomputed_embed_keys_path,
        precomputed_embed_values_path=precomputed_embed_values_path,
    )

    eval_accuracy(
        tokenizer,
        kb_retriever,
        model,
        dataset,
        exp_config,
        fancy_question,
        kb_config,
        kb_size,
        llm_type,
        test_batch_size,
        args.log_save_dir,
        args.attn_save_dir,
    )


def write_to_json(
    data: Any, filepath: str, indent: int = 4, encoding: str = "utf-8"
) -> None:
    """
    Write data to a JSON file with error handling and formatting options.

    Args:
        data: Data to write to the JSON file
        filepath: Path where the JSON file should be saved
        indent: Number of spaces for indentation (default: 4)
        encoding: File encoding (default: 'utf-8')

    Errors are caught and printed instead of being raised.
    """

    try:
        # Convert string path to Path object
        file_path = Path(filepath)

        # Write the JSON file
        with open(file_path, "w", encoding=encoding) as f:
            json.dump(
                data,
                f,
                indent=indent,
                sort_keys=True,  # For consistent output
                default=str,  # Handle non-serializable objects by converting to string
            )

    except Exception as e:
        print(f"Error writing JSON file: {str(e)}")


def run_accuracy_evalution():
    args = parser.parse_args()

    dataset_dir = args.dataset_dir
    encoder_path = args.encoder_dir
    encoder_spec = args.encoder_spec
    exp_config = args.exp_config_name
    fancy_question = args.fancy_question
    kb_layer_frequency = args.kb_layer_frequency
    kb_scale_factor = args.kb_scale_factor
    llm_base_dir = args.llm_base_dir
    llm_type = args.llm_type
    model_path = args.model_dir
    test_dataset = args.test_dataset

    query_head_path = args.query_head_path
    precomputed_embed_keys_path = args.precomputed_embed_keys_path
    precomputed_embed_values_path = args.precomputed_embed_values_path

    tokenizer, encoder, model, kb_config = _prepare_models(
        encoder_spec,
        encoder_path,
        llm_type,
        llm_base_dir,
        model_path,
        query_head_path,
        kb_layer_frequency,
        kb_scale_factor,
    )

    dataset = json.load(open(os.path.join(dataset_dir, test_dataset)))
    kb_retriever = KBRetriever(
        encoder,
        dataset,
        precomputed_embed_keys_path=precomputed_embed_keys_path,
        precomputed_embed_values_path=precomputed_embed_values_path,
    )

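    # Sweep over increasing KB sizes; the attention-weight dump directory is cleared between runs
    # so results from different KB sizes do not mix.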
    xs = [50, 100, 200, 400, 800] # , 1600, 3200, 6400
    accuracy_results = []
    for x in xs:
        print(f"kb_size {x}")

        accs = eval_accuracy(
            tokenizer,
            kb_retriever,
            model,
            dataset,
            exp_config,
            fancy_question,
            kb_config,
            x,
            llm_type,
            min(x, 200),
            args.log_save_dir,
            args.attn_save_dir,
        )
        shutil.rmtree(args.attn_save_dir)
        os.mkdir(args.attn_save_dir)
        accuracy_results.append({"kb_size": x, "accuracy_results": accs})
    write_to_json(
        accuracy_results, os.path.join(args.log_save_dir, "accuracy_results.json")
    )


def eval_refusal():
    """Evaluate refusal to answer questions for which the answer does not exist in the KB"""
    args = parser.parse_args()
    dataset_dir = args.dataset_dir
    encoder_model_spec = args.encoder_spec
    encoder_path = args.encoder_dir
    eval_mode = args.eval_mode
    exp_config = args.exp_config_name
    kb_layer_frequency = args.kb_layer_frequency
    kb_scale_factor = args.kb_scale_factor
    kb_size = args.kb_size
    llm_base_dir = args.llm_base_dir
    llm_type = args.llm_type
    model_path = args.model_dir
    seed = args.seed
    test_dataset = args.test_dataset
    precomputed_embed_keys_path = args.precomputed_embed_keys_path
    precomputed_embed_values_path = args.precomputed_embed_values_path
    query_head_path = args.query_head_path

    dataset = json.load(open(os.path.join(dataset_dir, test_dataset)))

    tokenizer, encoder, model, kb_config = _prepare_models(
        encoder_model_spec,
        encoder_path,
        llm_type,
        llm_base_dir,
        model_path,
        query_head_path,
        kb_layer_frequency,
        kb_scale_factor,
    )

    kb_retriever = KBRetriever(
        encoder,
        dataset,
        precomputed_embed_keys_path=precomputed_embed_keys_path,
        precomputed_embed_values_path=precomputed_embed_values_path,
    )

    gen_results, refusal_results = perform_eval_refusal(
        model,
        tokenizer,
        kb_retriever,
        eval_mode=eval_mode,
        seed=seed,
        kb_size=kb_size,
        topk_size=args.topk_size,
        kb_config=kb_config,
    )

    np.save(os.path.join(args.save_dir, "OutLierTest" + exp_config), refusal_results)
    with open(
        os.path.join(args.save_dir, "OutLierTest" + exp_config + ".txt"), "w"
    ) as text_file:
        text_file.write(gen_results)


def eval():
    """Evaluate the KB model"""
    args = parser.parse_args()
    attn_summary_save_dir = args.attn_summary_save_dir
    dataset_dir = args.dataset_dir
    encoder_model_spec = args.encoder_spec
    encoder_path = args.encoder_dir
    exp_config_str = args.exp_config_str
    kb_layer_frequency = args.kb_layer_frequency
    kb_scale_factor = args.kb_scale_factor
    kb_size = args.kb_size
    llm_base_dir = args.llm_base_dir
    llm_type = args.llm_type
    model_path = args.model_dir
    output_dir = args.save_dir
    sample_size = args.sample_size
    seed = args.seed
    subset_size = args.subset_size
    test_dataset = args.test_dataset
    precomputed_embed_keys_path = args.precomputed_embed_keys_path
    precomputed_embed_values_path = args.precomputed_embed_values_path
    query_head_path = args.query_head_path
    sep_query_head = True
    actual_kb_token_layer_frequency = 3

    if kb_size == -1:
        kb_size = None

    # validation_part_start_idx = 120000 if 'gpt' in test_dataset else 0
    dataset = json.load(open(os.path.join(dataset_dir, test_dataset)))

    if sep_query_head:
        print("Having seperate query head for KB!")

    torch.manual_seed(seed)
    np.random.seed(seed)

    os.environ["ATTN_SAVE_DIR"] = output_dir
    os.environ["EVAL_MODE"] = "1"

    tokenizer, encoder, model, kb_config = _prepare_models(
        encoder_model_spec,
        encoder_path,
        llm_type,
        llm_base_dir,
        model_path,
        query_head_path,
        kb_layer_frequency,
        kb_scale_factor,
    )

    for param in model.parameters():
        param.requires_grad = False

    # Set up the encoder
    encoder = KBEncoder(
        encoder_name=encoder_model_spec.upper(),
        projector_type="linear",
        endpoint_url="",
        out_dim=model.config.hidden_size  # type: ignore
        * (model.config.num_hidden_layers // actual_kb_token_layer_frequency + 1),  # type: ignore
        frozen_base_model=True,
        device=torch.device("cuda:1"),
    )
    encoder.load_state_dict(torch.load(encoder_path))

    kb_retriever = KBRetriever(
        encoder,
        dataset,
        precomputed_embed_keys_path=precomputed_embed_keys_path,
        precomputed_embed_values_path=precomputed_embed_values_path,
    )
    no_kb_predictions = []
    predictions = []
    answer = []

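    # Each trial samples a random subset of the dataset as the KB, encodes it, and generates answers
    # both without any KB and with the true KB key/value embeddings injected.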
    for _ in range(sample_size):
        print("******")
        dataset_subset_idx = np.random.choice(len(dataset), subset_size, replace=False)
        dataset_subset = [dataset[i] for i in dataset_subset_idx]
        encoder.eval()
        with torch.autograd.no_grad():
            kb_embedding_real = kb_retriever.get_key_embeddings(dataset_subset_idx)
            kb_embedding_key, kb_embedding_val = kb_embedding_real
            kb_embedding_real = (kb_embedding_key, kb_embedding_val)

        format_func_map = {"llama3": _format_Q_llama, "phi3": _format_Q_phi3}

        input_strs = [
            format_func_map[llm_type](dataset_subset[i]["Q"])
            for i in range(subset_size)
        ]

        tokenizer_output = tokenizer(input_strs, return_tensors="pt", padding=True).to(
            "cuda:1"
        )
        input_ids, attention_masks = (
            tokenizer_output["input_ids"],
            tokenizer_output["attention_mask"],
        )
        kb_embedding_real = (kb_embedding_real[0], kb_embedding_real[1])

        config_str = f"{exp_config_str}__kb_{subset_size}__seed_{seed}"
        with torch.autograd.no_grad():
            outputs_no_kb = model.generate(
                input_ids=input_ids,
                attention_mask=attention_masks,
                kb_kvs=None,
                max_new_tokens=40,
                tokenizer=tokenizer,
                output_attentions=False,
                kb_config=kb_config,
            )

            outputs_true_kb = model.generate(
                input_ids=input_ids,
                attention_mask=attention_masks,
                kb_kvs=kb_embedding_real,
                max_new_tokens=40,
                tokenizer=tokenizer,
                output_attentions=True,
                save_attention_weights=True,
                attention_save_loc=output_dir,
                attention_file_base_name=config_str,
                kb_config=kb_config,
            )
        print("eval() begin to batch_decode")
        outputs_no_kb = tokenizer.batch_decode(outputs_no_kb, skip_special_tokens=False)

        outputs_true_kb = tokenizer.batch_decode(
            outputs_true_kb, skip_special_tokens=False
        )
        print("KB:")
        for i in range(subset_size):
            print(
                "{} : {}".format(
                    dataset_subset[i]["name"], dataset_subset[i]["description"]
                )
            )

        for m in model_prune_format_mapping:
            if isinstance(model, m):
                prune_str = model_prune_format_mapping[m]

        print("------------------")
        for i in range(subset_size):
            print("True KB: ", prune_str(outputs_true_kb[i]))
            print("True answer: ", dataset_subset[i]["A"])
            no_kb_predictions.append(
                prune_str(outputs_no_kb[i]).split(dataset_subset[i]["Q"])[1]
            )
            predictions.append(
                prune_str(outputs_true_kb[i]).split(dataset_subset[i]["Q"])[1]
            )
            answer.append(dataset_subset[i]["A"])
            print("--------------------")
        print("******")

    # rogue_score = rouge.compute(predictions=predictions, references=answer)
    rouge_score = get_evaluate_rouge(rouge, predictions, answer)
    Path(args.attn_summary_save_dir).mkdir(exist_ok=True, parents=True)
    np.savez(
        os.path.join(attn_summary_save_dir, f"{config_str}_rouge.npy"), **rouge_score
    )
    print(f'eval(): rouge_score_kb={rouge_score}')
    
    # rouge_score_no_kb = rouge.compute(predictions=no_kb_predictions, references=answer)
    rouge_score_no_kb = get_evaluate_rouge(rouge, no_kb_predictions, answer)
    np.savez(
        os.path.join(attn_summary_save_dir, f"{config_str}_rouge_no_kb.npy"),
        **rouge_score_no_kb,
    )
    print(f'eval(): rouge_score_no_kb={rouge_score_no_kb}')

    # Start inspecting attention masks
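    # Layers are grouped into six bands; within each band, the dumped KB-attention weights from every
    # kb-injection layer (idx % 3 == 0) are aggregated to compute top-1 / top-5 retrieval accuracy
    # and a softmax confidence per band.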
    ranges = [(0, 6), (6, 12), (12, 18), (18, 24), (24, 30), (30, 32)]

    save_dir = output_dir
    Path(args.save_dir).mkdir(exist_ok=True, parents=True)

    accs, confidences = [], []
    for left, right in ranges:
        weights = []
        kb_size = subset_size
        for idx in range(32)[left:right]:
            if idx % 3 == 0:
                weight = np.load(os.path.join(save_dir, f"{config_str}_{idx}.npy"))
                weights.append(weight[..., :kb_size].reshape(kb_size, -1, kb_size))
        print(f'len(weights)={len(weights)}, weights[0].shape={weights[0].shape}, kb_size=subset_size={kb_size}')
        weights = np.stack(weights)
        weights = weights.transpose(1, 0, 2, 3).reshape(kb_size, -1, kb_size)
        print(f'weights.shape={weights.shape}')
        acc = (weights.sum(1).argmax(1) == np.arange(kb_size)).mean()
        top_5_predictions = torch.topk(torch.from_numpy(weights.sum(1)), 5, dim=1)[1]
        top_5_acc = (
            (top_5_predictions == torch.arange(kb_size)[:, None]).any(1).float().mean()
        )
        accs.append((acc, top_5_acc))
        confidence = softmax(weights.mean(1), -1).max()
        print(f'left={left}, right={right}, acc={acc}, top_5_acc={top_5_acc}, confidence={confidence}')
        confidences.append(confidence)
    print(f'confidences={confidences}')
    print(f'accs={accs}')
    np.save(
        os.path.join(attn_summary_save_dir, f"{config_str}_acc.npy"), np.array(accs)
    )
    np.save(
        os.path.join(attn_summary_save_dir, f"{config_str}_conf.npy"),
        np.array(confidences),
    )


def compare_texts_case():
    """Compare generated text of using KB, not using KB, and original model"""
    args = parser.parse_args()

    encoder_model_spec = args.encoder_spec
    encoder_path = args.encoder_dir
    kb_layer_frequency = args.kb_layer_frequency
    kb_scale_factor = args.kb_scale_factor
    kb_size = args.kb_size
    llm_base_dir = args.llm_base_dir
    llm_type = args.llm_type
    model_path = args.model_dir
    seed = args.seed
    query_head_path = args.query_head_path

    dataset = [{"Q": "What's the description of apple?", "A": "The description of apple is an red circle that can eat.", "key_string": "the description of apple", "description": "an red circle that can eat"}, 
               {"Q": "What's the purpose of software testing?", "A": "The purpose of software testing is to reduce bugs of the program.", "key_string": "the purpose of software testing", "description": "to reduce bugs of the program"},
               {"Q": "What is the objectives of Agent?", "A": "The objectives of agent is to enhance the capability of LLM", "key_string": "the objectives of agent", "description": "to enhance the capability of LLM"}]
    
    # 1. Load the tokenizer separately so the special tokens can be configured first
    tokenizer = AutoTokenizer.from_pretrained('/models/Meta-Llama-3-8B')
    # 2. Make sure the tokenizer has a pad_token (fall back to eos_token, otherwise add a new token)
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({'pad_token': '^'})
    # 3. Load the model and resize its embedding layer if a new token was added
    model = AutoModelForCausalLM.from_pretrained(
        '/models/Meta-Llama-3-8B',
        torch_dtype=torch.bfloat16,
        device_map="cuda:1"
    )
    # If a new token was added, resize the model's embedding layer to match the tokenizer
    if len(tokenizer) > model.config.vocab_size:
        print(f'len(tokenizer)={len(tokenizer)}')
        model.resize_token_embeddings(len(tokenizer))
    else:
        print(f'tokenizer.pad_token_id={tokenizer.pad_token_id}')
    model.config.pad_token_id = tokenizer.pad_token_id
    # 4. Pass the configured tokenizer when creating the pipeline
    pipeline = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,  # key point: use the already-configured tokenizer
        device_map="cuda:1",
        model_kwargs={"torch_dtype": torch.bfloat16}
    )
    original_model_output = []
    for item in dataset:
        output = pipeline(item['Q'], max_new_tokens=150, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
        original_model_output.append(output[0]['generated_text'])
    del pipeline, model, tokenizer  # drop the references
    import gc
    gc.collect()  # collect garbage on the Python side
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


    tokenizer, encoder, model, kb_config = _prepare_models(
        encoder_model_spec,
        encoder_path,
        llm_type,
        llm_base_dir,
        model_path,
        query_head_path,
        kb_layer_frequency,
        kb_scale_factor,
    )

    kb_retriever = KBRetriever(
        encoder,
        dataset,
    )

    np.random.seed(seed)
    key_str = [row["key_string"] for row in dataset]
    value_str = [row["description"] for row in dataset]
    prompt_strs = ""
    for k, v in zip(key_str, value_str):
        prompt_strs += f"{k} is {v}; "
    
    kb_embedding = kb_retriever.get_key_embeddings(np.array([i for i in range(len(dataset))]))

    model_outputs_kb = []
    model_outputs_icl = []
    model_outputs_zeroshot = []
    answers = []
    qi = 0
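    # For each QA pair, generate an answer in three settings (KB embeddings injected, in-context learning
    # with the serialized KB prepended, and zero-shot), strip the echoed question, and compare against
    # the base-model output collected above.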
    for row in tqdm(dataset):
        Q = row["Q"]
        answer = row["A"]

        model_output_kb = answer_question(
            tokenizer,
            model,
            Q,
            kb=kb_embedding,
            kb_config=kb_config,
        )
        if Q in model_output_kb:
            print(f'qi={qi}, raw model_output_kb={model_output_kb}')
            model_output_kb = model_output_kb.split(Q)[1]
        else:
            print(f'qi={qi}, Q not in model_output_kb')
        
        model_output_icl = answer_question(
            tokenizer,
            model,
            instruction_prompts + prompt_strs + Q,
            kb=None,
            kb_config=kb_config,
        )
        if Q in model_output_icl:
            print(f'qi={qi}, raw model_output_icl={model_output_icl}')
            model_output_icl = model_output_icl.split(Q)[1]
        else:
            print(f'qi={qi}, Q not in model_output_icl')
        
        model_output_zeroshot = answer_question(
            tokenizer, model, zero_shot_prompt + Q, kb=None, kb_config=kb_config
        )
        if Q in model_output_zeroshot:
            print(f'qi={qi}, raw model_output_zeroshot={model_output_zeroshot}')
            model_output_zeroshot = model_output_zeroshot.split(Q)[1]
        else:
            print(f'qi={qi}, Q not in model_output_zeroshot')

        print(f'\nQuestion-{qi+1}: {Q}\nAnswer: {answer}\nKB: {model_output_kb}\nICL: {model_output_icl}\nZEROSHOT: {model_output_zeroshot}\nOriginal: {original_model_output[qi]}\n\n')
        
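        # Keep only the predicted description by stripping the templated prefix
        # ("The <property> of <name> is ...") before scoring.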
        pattern = r'The\s+\w+\s+of\s+[^"]+\s+is\s+(.+)'
        match = re.search(pattern, model_output_kb)
        if match:
            model_output_kb = match.group(1)
        else:
            print(f'[error] model_output_kb {model_output_kb} not match pattern')
        model_outputs_kb.append(model_output_kb)
        
        match = re.search(pattern, model_output_icl)
        if match:
            model_output_icl = match.group(1)
        else:
            print(f'[error] model_output_icl {model_output_icl} not match pattern')
        model_outputs_icl.append(model_output_icl)
        
        match = re.search(pattern, model_output_zeroshot)
        if match:
            model_output_zeroshot = match.group(1)
        else:
            print(f'[error] model_output_zeroshot {model_output_zeroshot} not match pattern')
        model_outputs_zeroshot.append(model_output_zeroshot)
        
        answers.append(row["description"])
        qi+=1

    rouge_scores_kb = get_evaluate_rouge(rouge, model_outputs_kb, answers)
    rouge_scores_icl = get_evaluate_rouge(rouge, model_outputs_icl, answers)
    rouge_scores_zeroshot = get_evaluate_rouge(rouge, model_outputs_zeroshot, answers)
    rouge_original = get_evaluate_rouge(rouge, original_model_output, answers)
    print(f'rouge_scores_kb={rouge_scores_kb}')
    print(f'rouge_scores_icl={rouge_scores_icl}')
    print(f'rouge_scores_zeroshot={rouge_scores_zeroshot}')
    print(f'rouge_original={rouge_original}')
    
    P, R, F1 = bert_score.score(model_outputs_kb, answers, batch_size=10,verbose=True)
    P, R, F1 = P.tolist(), R.tolist(), F1.tolist()
    print(f'KB-P: {np.mean(P)}, {P}')
    print(f'KB-R: {np.mean(R)}, {R}')
    print(f'KB-F1: {np.mean(F1)}, {F1}')
    
    P, R, F1 = bert_score.score(model_outputs_icl, answers, batch_size=10,verbose=True)
    P, R, F1 = P.tolist(), R.tolist(), F1.tolist()
    print(f'ICL-P: {np.mean(P)}, {P}')
    print(f'ICL-R: {np.mean(R)}, {R}')
    print(f'ICL-F1: {np.mean(F1)}, {F1}')
    
    P, R, F1 = bert_score.score(model_outputs_zeroshot, answers, batch_size=10,verbose=True)
    P, R, F1 = P.tolist(), R.tolist(), F1.tolist()
    print(f'ZEROSHOT-P: {np.mean(P)}, {P}')
    print(f'ZEROSHOT-R: {np.mean(R)}, {R}')
    print(f'ZEROSHOT-F1: {np.mean(F1)}, {F1}')
    
    P, R, F1 = bert_score.score(original_model_output, answers, batch_size=10,verbose=True)
    P, R, F1 = P.tolist(), R.tolist(), F1.tolist()
    print(f'ORIGINAL-P: {np.mean(P)}, {P}')
    print(f'ORIGINAL-R: {np.mean(R)}, {R}')
    print(f'ORIGINAL-F1: {np.mean(F1)}, {F1}')
    
    mem_cost = torch.cuda.max_memory_reserved("cuda:1")/1024**3
    print(f'mem_cost={mem_cost}GB')
    

def main():
    st = time.time()
    print(f'begin time: {datetime.datetime.now()}')
    args = parser.parse_args()
    print(f'args={args}')
    if args.command == "generation":
        eval_generate()
    elif args.command == "accuracy":
        eval_accuracy_cli()
    elif args.command == "acc_results":
        run_accuracy_evalution()
    elif args.command == "refusal":
        eval_refusal()
    elif args.command == "standard":
        eval()
    elif args.command == "compare":
        compare_texts_case()
    else:
        raise ValueError(f"command {args.command} not recognised")

    total_time = round(time.time() - st,2)
    print(f'end time: {datetime.datetime.now()}, total cost {total_time}s')

if __name__ == "__main__":
    main()

shiwanghua avatar Jun 06 '25 08:06 shiwanghua

I can't get good results using Qwen3-Embedding-8B either.

lllyyyqqq avatar Dec 22 '25 02:12 lllyyyqqq

@lllyyyqqq can you elaborate more on the performance issues you are seeing? I suspect that Qwen3-Embedding-8B's output dimension may be a bit too large, so the adapter has trouble learning. You could consider switching to a model with a smaller embedding dimension.
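
For a rough sense of scale, here is a minimal sketch (the hidden size, layer count, kb_layer_frequency, and the encoder output dimensions below are assumptions on my side, not verified values) of how the linear projector's parameter count grows with the encoder's output dimension, using the same out_dim formula as in the eval script above:

def projector_params(encoder_dim, hidden_size=4096, num_layers=32, kb_layer_frequency=3):
    # One linear projector maps an encoder embedding to the concatenated
    # per-KB-layer key (or value) vectors: out_dim = hidden_size * (num_layers // freq + 1)
    out_dim = hidden_size * (num_layers // kb_layer_frequency + 1)
    return encoder_dim * out_dim + out_dim  # weight + bias

for name, dim in [("all-MiniLM-L6-v2 (384)", 384),
                  ("Qwen3-Embedding-4B (assumed ~2560)", 2560),
                  ("Qwen3-Embedding-8B (assumed ~4096)", 4096)]:
    print(f"{name}: ~{projector_params(dim) / 1e6:.0f}M parameters per projector")

Under these assumptions the 8B encoder gives an adapter roughly 10x larger than with MiniLM, which could be part of why the loss comes down so slowly.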

xidulu avatar Dec 22 '25 02:12 xidulu

From what I can observe, the training loss decreases slowly to 1.3 after 30000 steps. Predictions and ground truth refer to different objects, and the predicted sentences are incomplete. The reason I chose Qwen3-Embedding-8B is that in the paper the test results improve as the embedding dimension increases, so I picked an even larger one. But that makes training harder. I might redo it with the 4B model, whose dimension falls within the range covered by the paper's experiments.

Specific results: Evaluation on generation/kb mode: {'rouge1': 0.18947318770740823, 'rouge2': 0.016747509800141377, 'rougeL': 0.17989379150859153, 'rougeLsum': 0.17986779382904017, 'bert_score_precision': 0.6285847148299217, 'bert_score_recall': 0.6257531017065048, 'bert_score_f1': 0.6260100820660591, 'mem_cost': 21659385856}

Evaluation on generation/ICL mode: {'rouge1': 0.9621221181584352, 'rouge2': 0.9556224008959437, 'rougeL': 0.9619207044305509, 'rougeLsum': 0.9617364319350254, 'bert_score_precision': 0.9368505561351776, 'bert_score_recall': 0.9653788071870804, 'bert_score_f1': 0.9501989161968232, 'mem_cost': 34915483648}

Training logs: (two screenshots of the training curves attached, not reproduced here)

lllyyyqqq avatar Dec 22 '25 02:12 lllyyyqqq

I have switched to the 4B model, and it already looks much better after 200 training steps, so the embedding dimension does seem to be the issue.

lllyyyqqq avatar Dec 22 '25 08:12 lllyyyqqq

Amazing!! Good to hear that!!

xidulu avatar Dec 22 '25 08:12 xidulu