In the ScreenSpot benchmark, how is the platform field of each sample determined, and is history information used during inference?
In the ScreenSpot benchmark, how is the platform field of each sample determined? For example, how are samples whose original data_source is web or tool assigned to Mac, Mobile, or WIN? Also, is history information used during inference? When I reproduce the evaluation on ScreenSpot, my accuracy (about 70%) falls well short of the reported 85.4%.
During training, the platform field is omitted with some probability, so if you are unsure which platform a sample was collected on, you can either leave the platform field out or fall back to the default platform WIN. For history information, see the "History 字段" (History field) section of the prompt-concatenation document: https://zhipu-ai.feishu.cn/wiki/D9FTwQ78fitS3CkZHUjcKEWTned. If you can share more details of your evaluation setup, such as whether you use model quantization and the exact prompt-concatenation code, we can help you track down the problem.
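For reference, a minimal sketch of that prompt assembly, pieced together from the concatenation used in the code below; the exact History layout is defined in the linked document, so treat the names and ordering here as illustrative:

def build_query(task, history=None, platform=None,
                format_str="(Answer in Action-Operation format.)"):
    # History entries are (grounded_step, action) pairs, joined as
    # "index. step<TAB>action" lines; the canonical layout is in the
    # "History 字段" section of the linked wiki.
    history_str = ""
    if history:
        history_str = "\nHistory steps: " + "".join(
            f"\n{i}. {step}\t{action}" for i, (step, action) in enumerate(history)
        )
    # The platform line is optional: omit it when the capture platform is
    # unknown, or fall back to the default "WIN".
    platform_str = f"(Platform: {platform})\n" if platform else ""
    return f"Task: {task}{history_str}\n{platform_str}{format_str}"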
I ran inference with the cogagent-9b-20241220 weights on the ScreenSpot benchmark, without the history field and without quantization. My correctness criterion: a prediction counts as correct if the center of the predicted box lies inside the ground-truth box.
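For concreteness, that criterion as a minimal self-contained check (assuming both boxes are already in the same pixel coordinate space):

def center_hit(pred_box, gt_box):
    # Both boxes are (x_min, y_min, x_max, y_max) in the same pixel space.
    cx = (pred_box[0] + pred_box[2]) / 2
    cy = (pred_box[1] + pred_box[3]) / 2
    return gt_box[0] <= cx <= gt_box[2] and gt_box[1] <= cy <= gt_box[3]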
The inference code is as follows:

import argparse
import csv
import json
import logging
import os
import re

import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
# from transformers import BitsAndBytesConfig  # only needed for the quantization options below

def main_ScreenSpot():
    """
    A continuous interactive demo using the CogAgent1.5 model with selectable format prompts.
    The output_image_path is interpreted as a directory. For each round of interaction, the
    annotated image will be saved in the directory with the filename:
    {original_image_name_without_extension}_{round_number}.png

    Example:
    python cli_demo_my.py --model_dir ../cogagent-9b-20241220 --platform "Mac" --max_length 4096 --top_k 1 \
        --output_image_path ./results --format_key status_action_op_sensitive
    """
    parser = argparse.ArgumentParser(
        description="Continuous interactive demo with CogAgent model and selectable format."
    )
    parser.add_argument(
        # A default is provided, so the flag does not also need required=True.
        "--model_dir", default="../cogagent-9b-20241220", help="Path or identifier of the model."
    )
    parser.add_argument(
        "--platform",
        default="Mac",
        help="Platform information string (e.g., 'Mac', 'WIN').",
    )
    parser.add_argument(
        "--max_length", type=int, default=4096, help="Maximum generation length."
    )
    parser.add_argument(
        "--top_k", type=int, default=1, help="Top-k sampling parameter."
    )
    parser.add_argument(
        "--output_image_path",
        default="image_results",
        help="Directory to save the annotated images.",
    )
    parser.add_argument(
        "--output_pred_path",
        default="./ScreenSpot.csv",
        help="Path of the CSV file to save predictions.",
    )
    parser.add_argument(
        "--format_key",
        default="status_action_op_sensitive",
        help="Key to select the prompt format.",
    )
    parser.add_argument(
        "--task_path",
        default="../ScreenSpot/ScreenSpot_combined.json",
        help="Path to the ScreenSpot task JSON.",
    )
    parser.add_argument(
        "--image_path",
        default="../ScreenSpot/images",
        help="Directory containing the ScreenSpot images.",
    )
    args = parser.parse_args()
    # Dictionary mapping format keys to format strings
    format_dict = {
        "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)",
        "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)",
        "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)",
        "status_action_op": "(Answer in Status-Action-Operation format.)",
        "action_op": "(Answer in Action-Operation format.)",
    }
    # Ensure the provided format_key is valid
    if args.format_key not in format_dict:
        raise ValueError(
            f"Invalid format_key. Available keys are: {list(format_dict.keys())}"
        )
    # Ensure the output directory exists
    os.makedirs(args.output_image_path, exist_ok=True)
    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        # quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # For INT8 quantization
        # quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # For INT4 quantization
    ).eval()
    # Initialize platform and selected format strings
    platform_str = f"(Platform: {args.platform})\n"
    format_str = format_dict[args.format_key]
    # Initialize history lists
    history_step = []
    history_action = []
    pred_result_list = []
    round_num = 1
    with open(args.task_path, "r") as file:
        data = json.load(file)
    task_id = 0
    for item in tqdm(data):
        for one in item["annotations"]:
            logging.info(f"Starting inference for sample {task_id}...")
            pred_result = one.copy()
            task = one["objective_reference"]
            img_path = os.path.join(args.image_path, one["image_id"])
            # Map the ScreenSpot data_source to a platform tag:
            # ios/android -> Mobile, macos -> Mac, everything else
            # (windows, web, tool, ...) -> WIN.
            if one["data_source"] in ["ios", "android"]:
                platform = "Mobile"
            elif one["data_source"] in ["macos"]:
                platform = "Mac"
            else:
                platform = "WIN"
            platform_str = f"(Platform: {platform})\n"
            try:
                image = Image.open(img_path).convert("RGB")
            except Exception:
                logging.info("Invalid image path. Please try again.")
                continue
            # Verify history lengths match
            if len(history_step) != len(history_action):
                raise ValueError("Mismatch in lengths of history_step and history_action.")
            # Format history steps for output (note: history accumulates across
            # unrelated samples here, and the query below does not include it)
            history_str = "\nHistory steps: "
            for index, (step, action) in enumerate(zip(history_step, history_action)):
                history_str += f"\n{index}. {step}\t{action}"
            # Compose the query with task, platform, and selected format instructions
            # query = f"Task: {task}{history_str}\n{platform_str}{format_str}"
            query = f"Task: {task}\n{platform_str}{format_str}"
            logging.info(f"Round {round_num} query:\n{query}")
            inputs = tokenizer.apply_chat_template(
                [{"role": "user", "image": image, "content": query}],
                add_generation_prompt=True,
                tokenize=True,
                return_tensors="pt",
                return_dict=True,
            ).to(model.device)
            # Generation parameters (with top_k=1, sampling is effectively greedy)
            gen_kwargs = {
                "max_length": args.max_length,
                "do_sample": True,
                "top_k": args.top_k,
            }
            # Generate response
            with torch.no_grad():
                outputs = model.generate(**inputs, **gen_kwargs)
                outputs = outputs[:, inputs["input_ids"].shape[1]:]
                response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Extract grounded operation and action
            grounded_pattern = r"Grounded Operation:\s*(.*)"
            action_pattern = r"Action:\s*(.*)"
            matches_history = re.search(grounded_pattern, response)
            matches_actions = re.search(action_pattern, response)
            if matches_history:
                grounded_operation = matches_history.group(1)
                history_step.append(grounded_operation)
            if matches_actions:
                action_operation = matches_actions.group(1)
                history_action.append(action_operation)
            # Extract bounding boxes from the response
            box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]"
            matches = re.findall(box_pattern, response)
            if matches:
                # Normalized 0-1000 coordinates, rescaled to [0, 1] for drawing
                boxes = [[int(x) / 1000 for x in match] for match in matches]
                # Extract base name of the user's input image (without extension)
                base_name = os.path.splitext(os.path.basename(img_path))[0]
                # Construct the output file name with round number
                output_file_name = f"{base_name}_{round_num}.png"
                output_path = os.path.join(args.output_image_path, output_file_name)
                draw_boxes_on_image(image, boxes, output_path)
                pred_boxes = [[int(x) for x in match] for match in matches]
                logging.info(f"Annotated image saved at: {output_path}")
                pred_result.update({"pred_box": pred_boxes, "platform": platform, "Model_response": response})
            else:
                logging.info("No bounding boxes found in the response.")
                pred_result.update({"pred_box": "", "Model_response": response})
            task_id += 1
            round_num += 1
            logging.info(pred_result)
            pred_result_list.append(pred_result)
    if os.path.exists(args.output_pred_path):
        # Remove any stale prediction file first
        os.remove(args.output_pred_path)
    keys = pred_result_list[0].keys()
    with open(args.output_pred_path, "w", encoding="utf-8", newline="") as output_file:
        dict_writer = csv.DictWriter(output_file, keys)
        dict_writer.writeheader()
        dict_writer.writerows(pred_result_list)
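As a quick sanity check of the box-extraction regex above, here is a self-contained snippet on a made-up response line (the response text is illustrative, not real model output):

import re

box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]"
sample = "Grounded Operation: CLICK(box=[[123,456,789,812]])"
print(re.findall(box_pattern, sample))  # [('123', '456', '789', '812')]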
The evaluation code is as follows:

import csv

import numpy as np
from PIL import Image

def draw_boxes_on_image(img_path, box):
    # Despite its name, this helper draws nothing: it rescales a box from the
    # model's 0-1000 normalized coordinates to pixel coordinates.
    image = Image.open(img_path).convert("RGB")
    x_min = int(box[0] / 1000 * image.width)
    y_min = int(box[1] / 1000 * image.height)
    x_max = int(box[2] / 1000 * image.width)
    y_max = int(box[3] / 1000 * image.height)
    return [x_min, y_min, x_max, y_max]
def evaluate(df):
    data_dict_list = []
    # The predictions CSV above was written as UTF-8, so read it back as UTF-8
    # (reading it as GBK risks decode errors on non-ASCII responses).
    with open(df, "r", encoding="utf-8") as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            data_dict_list.append(row)
    true_num = 0
    wrong_num = 0
    for id, item in enumerate(data_dict_list):
        try:
            # Take the first predicted box (0-1000 normalized) and the annotation
            pred_box = np.array(eval(item["pred_box"])[0])
            true_box = np.array(eval(item["bounding_box"]))
            # Rescale the prediction to pixel coordinates
            img_path = f"./ScreenSpot/images/{item['image_id']}"
            pred_box = draw_boxes_on_image(img_path, pred_box)
            x_min, y_min, x_max, y_max = pred_box
            # Convert the annotation from (x, y, w, h) to corner coordinates
            true_box = [true_box[0], true_box[1], true_box[0] + true_box[2], true_box[1] + true_box[3]]
            pred_center_x = (x_min + x_max) / 2
            pred_center_y = (y_min + y_max) / 2
            # Count a hit when the predicted center lies inside the ground-truth box
            if (true_box[0] <= pred_center_x <= true_box[2]) and (true_box[1] <= pred_center_y <= true_box[3]):
                true_num += 1
            else:
                print(true_box, pred_center_x, pred_center_y)
                print("miss", item["platform"], item["data_source"])
                wrong_num += 1
        except Exception:
            # Rows that fail to parse are skipped here but still count in the
            # denominator below, which deflates the reported accuracy.
            continue
    print(true_num / len(data_dict_list), true_num, wrong_num, len(data_dict_list))
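One way to localize the gap against the reported 85.4% would be a per-platform breakdown of the same criterion; a minimal sketch, assuming each CSV row carries the platform tag written during inference (accuracy_by_platform and is_hit are illustrative names, not part of the original scripts):

from collections import defaultdict

def accuracy_by_platform(rows, is_hit):
    # rows: dicts that carry a 'platform' key (as written to the CSV above);
    # is_hit: predicate implementing the center-in-box criterion per row.
    hits = defaultdict(int)
    totals = defaultdict(int)
    for row in rows:
        totals[row["platform"]] += 1
        if is_hit(row):
            hits[row["platform"]] += 1
    return {p: hits[p] / totals[p] for p in totals}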
I would also like to see the prompt used to evaluate the grounding performance of cogagentv2 on ScreenSpot. Thanks.