
How is the platform field determined for the ScreenSpot test data, and is history information used during inference?

Open lijianhong-code opened this issue 1 year ago • 5 comments

In the ScreenSpot test, how is the platform field of the data determined? For example, how are samples whose original source is web or tool assigned to Mac, Mobile, or WIN? Also, is history information used during inference? The score I reproduced on ScreenSpot (70.%) differs considerably from the 85.4% stated in the report.

lijianhong-code avatar Jan 05 '25 11:01 lijianhong-code

During model training there is a certain probability that the platform field is not included, so if you are unsure which platform the data was collected on, you can either omit the platform field or try the default platform WIN. For history information, please refer to the "History field" section of the prompt-concatenation document: https://zhipu-ai.feishu.cn/wiki/D9FTwQ78fitS3CkZHUjcKEWTned. If you can share more details of your evaluation setup, such as whether model quantization was used and the exact prompt-concatenation code, we can help you find potential problems.
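
For illustration, here is a minimal sketch of how the query string can be assembled with or without the platform field and with history steps, following the "Task / History steps / Platform / format" pattern used in the demo code later in this thread; the task text and history entries below are hypothetical placeholders:

# Hypothetical task and history entries, used only to illustrate prompt assembly.
task = "Search for CogAgent on GitHub"
history_step = ["CLICK(box=[[100,200,150,230]], element_info='search box')"]  # placeholder grounded operation
history_action = ["Clicked the search box."]                                  # placeholder action description

# Build the history string exactly as in the demo code below.
history_str = "\nHistory steps: "
for index, (step, action) in enumerate(zip(history_step, history_action)):
    history_str += f"\n{index}. {step}\t{action}"

# The platform field may be omitted entirely (empty string) or set, e.g., to the default WIN.
platform_str = "(Platform: WIN)\n"
format_str = "(Answer in Action-Operation format.)"

query = f"Task: {task}{history_str}\n{platform_str}{format_str}"
print(query)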

jasonnoy avatar Jan 06 '25 08:01 jasonnoy

I ran inference and evaluation on the ScreenSpot benchmark with the cogagent-9b-20241220 weights, without adding a history field and without quantization. The correctness criterion is: a prediction counts as correct if the center of the predicted box lies inside the ground-truth bounding box.

lijianhong-code avatar Jan 06 '25 08:01 lijianhong-code

My inference code is as follows:

# Imports assumed by the snippet below; draw_boxes_on_image is the
# image-annotation helper from the repo's demo code (not shown here).
import argparse
import csv
import json
import logging
import os
import re

import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

def main_ScreenSpot():
    """
    A continuous interactive demo using the CogAgent1.5 model with selectable format prompts.
    The output_image_path is interpreted as a directory. For each round of interaction,
    the annotated image will be saved in the directory with the filename:
    {original_image_name_without_extension}_{round_number}.png

    Example:
    python cli_demo_my.py --model_dir ../cogagent-9b-20241220 --platform "Mac" --max_length 4096 --top_k 1 \
                     --output_image_path ./results --format_key status_action_op_sensitive
    """

parser = argparse.ArgumentParser(
    description="Continuous interactive demo with CogAgent model and selectable format."
)
parser.add_argument(
    "--model_dir", default='../cogagent-9b-20241220',required=True, help="Path or identifier of the model."
)
parser.add_argument(
    "--platform",
    default="Mac",
    help="Platform information string (e.g., 'Mac', 'WIN').",
)
parser.add_argument(
    "--max_length", type=int, default=4096, help="Maximum generation length."
)
parser.add_argument(
    "--top_k", type=int, default=1, help="Top-k sampling parameter."
)
parser.add_argument(
    "--output_image_path",
    default="image_results",
    help="Directory to save the annotated images.",
)
parser.add_argument(
    "--output_pred_path",
    default="./ScreenSpot.csv",
    help="Directory to save the annotated images.",
)
parser.add_argument(
    "--format_key",
    default="status_action_op_sensitive",
    help="Key to select the prompt format.",
)
parser.add_argument(
    "--task_path",
    default="../ScreenSpot/ScreenSpot_combined.json",
    help="Key to select the prompt format.",
)
parser.add_argument(
    "--image_path",
    default="../ScreenSpot/images",
    help="Key to select the prompt format.",
)
args = parser.parse_args()

# Dictionary mapping format keys to format strings
format_dict = {
    "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)",
    "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)",
    "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)",
    "status_action_op": "(Answer in Status-Action-Operation format.)",
    "action_op": "(Answer in Action-Operation format.)",
}

# Ensure the provided format_key is valid
if args.format_key not in format_dict:
    raise ValueError(
        f"Invalid format_key. Available keys are: {list(format_dict.keys())}"
    )

# Ensure the output directory exists
os.makedirs(args.output_image_path, exist_ok=True)

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    args.model_dir,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
    # quantization_config=BitsAndBytesConfig(load_in_8bit=True), # For INT8 quantization
    # quantization_config=BitsAndBytesConfig(load_in_4bit=True), # For INT4 quantization
).eval()
# Initialize platform and selected format strings
platform_str = f"(Platform: {args.platform})\n"
format_str = format_dict[args.format_key]

# Initialize history lists and the container for prediction results
history_step = []
history_action = []
pre_result_list = []

round_num = 1

with open(args.task_path, 'r') as file:
    data = json.load(file)

task_id = 0
for item in tqdm(data):
    for one in item['annotations']:
        logging.info(f'Starting inference for sample {task_id}...')
        pred_result = one.copy()
        task = one['objective_reference']
        img_path = args.image_path + '/' + one['image_id']
        # Map the original data_source to one of the three platform labels
        if one["data_source"] in ['ios', 'android']:
            platform = 'Mobile'
        elif one["data_source"] in ['macos']:
            platform = 'Mac'
        else:
            platform = 'WIN'
        platform_str = f"(Platform: {platform})\n"

        try:
            image = Image.open(img_path).convert("RGB")
        except Exception:
            logging.info("Invalid image path. Please try again.")
            continue

        # Verify history lengths match
        if len(history_step) != len(history_action):
            raise ValueError("Mismatch in lengths of history_step and history_action.")

        # Format history steps for output
        history_str = "\nHistory steps: "
        for index, (step, action) in enumerate(zip(history_step, history_action)):
            history_str += f"\n{index}. {step}\t{action}"

        # Compose the query with task, platform, and selected format instructions
        # (the history string is intentionally not appended in this evaluation)
        # query = f"Task: {task}{history_str}\n{platform_str}{format_str}"
        query = f"Task: {task}\n{platform_str}{format_str}"

        logging.info(f"Round {round_num} query:\n{query}")

        inputs = tokenizer.apply_chat_template(
            [{"role": "user", "image": image, "content": query}],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).to(model.device)
        # Generation parameters (with the default top_k=1 this is effectively greedy decoding)
        gen_kwargs = {
            "max_length": args.max_length,
            "do_sample": True,
            "top_k": args.top_k,
        }

        # Generate response
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
            outputs = outputs[:, inputs["input_ids"].shape[1]:]
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Extract grounded operation and action
        grounded_pattern = r"Grounded Operation:\s*(.*)"
        action_pattern = r"Action:\s*(.*)"
        matches_history = re.search(grounded_pattern, response)
        matches_actions = re.search(action_pattern, response)

        if matches_history:
            grounded_operation = matches_history.group(1)
            history_step.append(grounded_operation)
        if matches_actions:
            action_operation = matches_actions.group(1)
            history_action.append(action_operation)

        # Extract bounding boxes from the response
        box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]"
        matches = re.findall(box_pattern, response)

        if matches:
            boxes = [[int(x) / 1000 for x in match] for match in matches]

            # Extract base name of the user's input image (without extension)
            base_name = os.path.splitext(os.path.basename(img_path))[0]
            # Construct the output file name with round number
            output_file_name = f"{base_name}_{round_num}.png"
            output_path = os.path.join(args.output_image_path, output_file_name)

            draw_boxes_on_image(image, boxes, output_path)

            pred_boxes = [[int(x) for x in match] for match in matches]
            logging.info(f"Annotated image saved at: {output_path}")
            pred_result.update({"pred_box": pred_boxes, "platform": platform, "Model_response": response})
        else:
            logging.info("No bounding boxes found in the response.")
            # Store the raw string (not a set) and keep the same keys as the success branch
            pred_result.update({"pred_box": '', "platform": platform, "Model_response": response})

        task_id+=1
        round_num += 1
        logging.info(pred_result)
        pre_result_list.append(pred_result)

if os.path.exists(args.output_pred_path):
    # If the output file already exists, delete it first
    os.remove(args.output_pred_path)

keys = pre_result_list[0].keys()
with open(args.output_pred_path, 'w', encoding='utf-8', newline='') as output_file:
    dict_writer = csv.DictWriter(output_file, keys)
    dict_writer.writeheader()
    dict_writer.writerows(pre_result_list)

lijianhong-code avatar Jan 06 '25 08:01 lijianhong-code

My evaluation code is as follows:

# Imports assumed by the evaluation snippet below.
import csv

import numpy as np
from PIL import Image

def draw_boxes_on_image(img_path, box):
    # Despite its name, this helper converts a box in 0-1000 normalized
    # coordinates into pixel coordinates of the given image.
    image = Image.open(img_path).convert("RGB")
    x_min = int(box[0] / 1000 * image.width)
    y_min = int(box[1] / 1000 * image.height)
    x_max = int(box[2] / 1000 * image.width)
    y_max = int(box[3] / 1000 * image.height)
    return [x_min, y_min, x_max, y_max]

def evaluate(df):
    data_dict_list = []
    # The CSV is written with encoding='utf-8' by the inference script, so read it back as UTF-8
    with open(df, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            data_dict_list.append(row)

    true_num = 0   # predictions whose center lies inside the ground-truth box
    corr_num = 0   # predictions whose center lies outside the ground-truth box
    for id, item in enumerate(data_dict_list):
        try:
            pred_box = np.array(eval(item['pred_box'])[0])
            true_box = np.array(eval(item['bounding_box']))
            # Convert the predicted 0-1000 normalized box to pixel coordinates
            img_path = f"./ScreenSpot/images/{item['image_id']}"
            pred_box = draw_boxes_on_image(img_path, pred_box)
            x_min, y_min, x_max, y_max = pred_box
            # Ground-truth annotation is [x, y, w, h]; convert to [x_min, y_min, x_max, y_max]
            true_box = [true_box[0], true_box[1], true_box[0] + true_box[2], true_box[1] + true_box[3]]
            pred_center_x = (x_min + x_max) / 2
            pred_center_y = (y_min + y_max) / 2

            # Count as correct if the predicted center point lies inside the ground-truth box
            if (true_box[0] <= pred_center_x <= true_box[2]) and (true_box[1] <= pred_center_y <= true_box[3]):
                true_num += 1
            else:
                print(true_box, pred_center_x, pred_center_y)
                print('Wrong:', item['platform'], item['data_source'])
                corr_num += 1
        except:
            # Rows whose prediction cannot be parsed are skipped here,
            # but they still count in the denominator below.
            continue
    print(true_num / len(data_dict_list), true_num, corr_num, len(data_dict_list))
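
A possible invocation, assuming the CSV produced by the inference script above at its default path:

if __name__ == "__main__":
    evaluate('./ScreenSpot.csv')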

lijianhong-code avatar Jan 06 '25 08:01 lijianhong-code

I also hope to see the prompt used for evaluating the grounding performance of CogAgent v2 on ScreenSpot. Thanks.

zhangxgu avatar Feb 06 '25 09:02 zhangxgu