
Z3 python bitvector performance cliff

Open intrigus-lgtm opened this issue 2 years ago • 5 comments

Hi, I use Z3 via the Python bindings and have noticed an interesting performance cliff.

Depending on "THE_VALUE" (which selects how many instructions to translate), performance looks like this:

    THE_VALUE   real time   mov      call     xor
    400000      0m7,980s    288711   111283   6
    420000      0m15,706s   303147   116845   8
    430000      1m21,680s   310362   119628   10
    440000      4m24,924s   317576   122414   10
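To put a number on the cliff, the sketch below converts the real times listed above (minutes and seconds, with a decimal comma as in the original output) to seconds and compares each run with the 400000-instruction baseline. The helper to_seconds is hypothetical; the timing strings are copied from the report.

```python
import re

# Wall-clock times from the runs above, in "XmY,ZZZs" form
# (the comma is a decimal separator).
TIMINGS = {
    400000: "0m7,980s",
    420000: "0m15,706s",
    430000: "1m21,680s",
    440000: "4m24,924s",
}


def to_seconds(real: str) -> float:
    """Convert a string such as '4m24,924s' to seconds (hypothetical helper)."""
    match = re.fullmatch(r"(\d+)m(\d+),(\d+)s", real)
    if match is None:
        raise ValueError(f"unexpected time format: {real!r}")
    minutes, secs, frac = match.groups()
    return int(minutes) * 60 + float(f"{secs}.{frac}")


baseline_n = 400000
baseline_t = to_seconds(TIMINGS[baseline_n])
for n, real in TIMINGS.items():
    t = to_seconds(real)
    growth_n = n / baseline_n  # relative number of translated instructions
    growth_t = t / baseline_t  # relative wall-clock time
    print(f"{n}: {t:8.3f} s  instructions x{growth_n:.3f}  time x{growth_t:.1f}")
```

Going from 400000 to 440000 instructions (10% more) multiplies the wall-clock time by roughly 33.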

With all instructions (3 million), Z3 runs out of memory at around 60 GB after roughly 30 minutes. The problem occurs when adding the constraint to Z3, i.e. in s.add(registers["eax"] == 0):

[i] Time to add constraint: 222.20120286941528

With cvc5, the same script runs in almost no time.

Reproduction steps

pip install z3-solver
pip install cvc5
cd $(mktemp -d)
wget https://github.com/Z3Prover/z3/files/13553176/z3_python_bitvector_performance_cliff.tar.gz
tar -xzf z3_python_bitvector_performance_cliff.tar.gz
time python solv.py # be patient!
time python solv_cvc5.py

foobar.cvc5.smt2 contains the SMT-LIB2 statements produced with cvc5, and foobar.smt2 those produced with Z3. Note how much larger the cvc5 file is.
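A size gap like this is easier to picture with a toy model of how the script builds its terms: each gate reads two registers that were themselves produced by earlier gates, so subterms are heavily shared. The sketch below assumes that model and uses hypothetical tuple-based nodes, not Z3 or cvc5 objects; it is an illustration, not a diagnosis of either solver's printer. It counts nodes once per shared subterm (DAG size) and as if every shared subterm were written out in full (tree size).

```python
from typing import Dict, Optional, Union

# Hypothetical term representation: leaves are register names, inner
# nodes are (op, left, right) tuples; reused values are shared objects.
Term = Union[str, tuple]


def build_chain(steps: int) -> Term:
    """Repeatedly feed two 'registers' into a binary gate, like edi/esi."""
    a, b = "edi", "esi"
    for _ in range(steps):
        a, b = b, ("and", a, b)  # the new value reuses both previous values
    return b


def dag_size(term: Term) -> int:
    """Number of distinct nodes: each shared subterm is counted once."""
    seen = set()
    stack = [term]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        if isinstance(node, tuple):
            stack.extend(node[1:])
    return len(seen)


def tree_size(term: Term, memo: Optional[Dict[int, int]] = None) -> int:
    """Number of nodes if every shared subterm were printed in full."""
    if memo is None:
        memo = {}
    if id(term) not in memo:
        if isinstance(term, tuple):
            memo[id(term)] = 1 + sum(tree_size(c, memo) for c in term[1:])
        else:
            memo[id(term)] = 1
    return memo[id(term)]


for steps in (10, 20, 40):
    t = build_chain(steps)
    print(f"steps={steps}: DAG {dag_size(t)} nodes, tree {tree_size(t)} nodes")
```

In this model the DAG grows by one node per step while the expanded tree grows like the Fibonacci numbers, so a printer that names shared subterms (for example with SMT-LIB let bindings) and one that writes them out in full can produce files of very different sizes for the same term.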

z3_python_bitvector_performance_cliff.tar.gz

solv.py (originally from https://github.com/pwning/public-writeup/tree/master/codegate2023/GateCodeGate):

from z3 import *

with open("./gatecodegate.asm", "r") as f:
    code = f.read()

code = (
    code.split(" ___isoc99_scanf", 1)[1]
    .split("lea     rdi, aCorrectCodegat", 1)[0]
    .strip()
)

code = code.split("\n")
code = [x.strip() for x in code]

registers = {
    x: BitVec(x, 32)
    for x in [
        "eax",
        "ebx",
        "ecx",
        "edx",
        "esi",
        "edi",
        "[rsp+1C8h+var_4C]",
        "[rsp+1C8h+var_50]",
        "[rsp+1C8h+var_54]",
        "[rsp+1C8h+var_58]",
        "[rsp+1C8h+var_5C]",
        "[rsp+1C8h+var_60]",
        "[rsp+1C8h+var_64]",
        "[rsp+1C8h+var_68]",
    ]
}

gates = {
    "gate_and": lambda x, y: x & y,
    "gate_or": lambda x, y: x | y,
    "gate_xor": lambda x, y: x ^ y,
    "gate_not": lambda x, y: ~x,
    "gate_shl": lambda x, y: x << y,
    "gate_shr": lambda x, y: LShR(x, y),
}


def getval(x, line):
    if x in registers:
        return registers[x]
    elif x.endswith("h"):
        return BitVecVal(int(x[:-1], 16), 32)
    elif x.isdigit():
        return BitVecVal(int(x), 32)
    else:
        raise Exception(f"Unknown operand: {x}; {line}")


print(f"[i] {len(code)} lines of code")

# Pick one THE_VALUE (the last uncommented assignment wins); the comments
# record the measured wall-clock time and opcode counts for each run.
# THE_VALUE = 400000 # real     0m7,980s stats: {'mov': 288711, 'call': 111283, 'xor': 6}
# THE_VALUE = 420000 # real    0m15,706s stats: {'mov': 303147, 'call': 116845, 'xor': 8}
# THE_VALUE = 430000 # real    1m21,680s stats: {'mov': 310362, 'call': 119628, 'xor': 10}
THE_VALUE = 440000 # real    4m24,924s stats: {'mov': 317576, 'call': 122414, 'xor': 10}
stats = {}

for line_num, line in enumerate(code[0:THE_VALUE]):
    line = line.replace(",", "").split()
    opcode = line[0]
    if opcode == "mov":
        stats["mov"] = stats.get("mov", 0) + 1
        dst = line[1]
        src = line[2]
        registers[dst] = getval(src, (line_num, line))

    elif opcode == "xor":
        stats["xor"] = stats.get("xor", 0) + 1
        dst = line[1]
        src = line[2]

        registers[dst] = registers[dst] ^ getval(src, (line_num, line))

    elif opcode == "call":
        stats["call"] = stats.get("call", 0) + 1
        gate = line[1]

        result = gates[gate](registers["edi"], registers["esi"])
        registers["eax"] = result

    else:
        raise Exception(f"Unknown opcode: {opcode}; {line}")


print("[i] Finished parsing code")
s = Solver()
s.set("smtlib2_log", "log.smt2")
import time
before = time.time()
s.add(registers["eax"] == 0)
print(f"[i] Time to add constraint: {time.time() - before}")
# print(registers["eax"])
print(f"THE_VALUE: {THE_VALUE}")
print(f"stats: {stats}")
before = time.time()
open("foobar.smt2", "w").write(s.sexpr())
print(f"[i] Time to write smt2: {time.time() - before}")
exit(0)  # stop after writing the SMT2 file; the check below is not run
print(s.check())
print(s.model())

intrigus-lgtm, Dec 05 '23 00:12

To do: rewrite your script in C++ to make it easier to profile with perf. @nunoplopes may be interested, since this is more likely to be about hash-table performance, so it could serve as a way to test replacing the hash table and/or hash functions.

NikolajBjorner, Dec 05 '23 03:12
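The comment above ties the cliff to hash-table performance. Z3 hash-conses its terms (structurally equal terms are shared), so creating a term involves a hash-table lookup. The toy table below is a hypothetical, stdlib-only HashConsTable, not Z3's implementation; it shows the mechanism with a pluggable hash function, which is what a hash-function experiment varies, and counts key comparisons ("probes") so a weak and a better-mixed hash can be compared.

```python
import zlib
from typing import Callable, List, Tuple

# A node key is (operator, ids of child nodes); leaves have no children.
Key = Tuple[str, Tuple[int, ...]]


class HashConsTable:
    """Toy hash-consing table: structurally equal nodes get the same id."""

    def __init__(self, num_buckets: int, hash_fn: Callable[[Key], int]) -> None:
        # Separate chaining: each bucket is a list of (key, node id) pairs.
        self.buckets: List[List[Tuple[Key, int]]] = [[] for _ in range(num_buckets)]
        self.hash_fn = hash_fn
        self.next_id = 0
        self.probes = 0  # key comparisons made while searching buckets

    def mk(self, op: str, *children: int) -> int:
        key: Key = (op, children)
        bucket = self.buckets[self.hash_fn(key) % len(self.buckets)]
        for existing_key, node_id in bucket:
            self.probes += 1
            if existing_key == key:
                return node_id  # a structurally equal node already exists
        node_id = self.next_id
        self.next_id += 1
        bucket.append((key, node_id))
        return node_id


def weak_hash(key: Key) -> int:
    """Deliberately poor hash: depends only on the number of children."""
    return len(key[1])


def mixed_hash(key: Key) -> int:
    """Mixes operator and child ids; crc32 keeps it stable across runs."""
    op, children = key
    return hash((zlib.crc32(op.encode()), *children))


def build_chain(table: HashConsTable, steps: int) -> int:
    """Feed two 'registers' into a rotating binary gate, like edi/esi."""
    a, b = table.mk("edi"), table.mk("esi")
    for i in range(steps):
        op = ("and", "or", "xor")[i % 3]
        a, b = b, table.mk(op, a, b)
    return b


for name, fn in [("weak", weak_hash), ("mixed", mixed_hash)]:
    table = HashConsTable(1024, fn)
    build_chain(table, 2000)
    print(f"{name:>5}: {table.next_id} nodes, {table.probes} probes")
```

With the arity-only hash every binary node lands in the same bucket, so building n nodes costs about n²/2 probes (1999001 here), while the mixed hash keeps buckets short. As reported later in this thread, changing Z3's hash function did not remove the cliff, so this only illustrates what such a change targets, not the cause of this issue.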

@intrigus-lgtm Maybe try my latest attempt at changing the hash function (https://github.com/nunoplopes/z3). It improved things for me, at least.

nunoplopes, Dec 05 '23 09:12

@NikolajBjorner I've rewritten the script in C++. It should be functionally equivalent, but I don't know C++ very well.

There is still a performance cliff:

    // THE_VALUE = 400000; // 144ms    new-hash: 143ms
    // THE_VALUE = 420000; // 7427ms   new-hash: 7555ms
    // THE_VALUE = 430000; // 69206ms  new-hash: 71535ms
    // THE_VALUE = 440000; // 228012ms new-hash: 237135ms

z3_cpp_bitvector_performance_cliff.tar.gz: unpack the archive, compile solv.cpp, and place the resulting executable next to the cleaned.asm file.

@nunoplopes I also tried your attempt at changing the hash function, but it did not help, as the new-hash timings above show.

intrigus-lgtm, Dec 08 '23 01:12
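The statement that the new hash did not help can be checked directly against the C++ numbers. The sketch below copies the timings from the comment above (the helper name percent_change is hypothetical), compares the new-hash build with the baseline at each size, and shows how the baseline itself grows.

```python
# (instructions, baseline ms, new-hash ms), copied from the C++ timings above
RUNS = [
    (400000, 144, 143),
    (420000, 7427, 7555),
    (430000, 69206, 71535),
    (440000, 228012, 237135),
]


def percent_change(old_ms: float, new_ms: float) -> float:
    """Relative change of the new-hash build; positive means slower."""
    return (new_ms - old_ms) / old_ms * 100


base_ms = RUNS[0][1]
for n, old_ms, new_ms in RUNS:
    slowdown = old_ms / base_ms  # baseline time relative to 400000 instructions
    print(f"{n}: new hash {percent_change(old_ms, new_ms):+.1f}%, "
          f"baseline x{slowdown:.0f} vs. 400000")
```

The new hash moves each timing by at most about 4%, while the baseline grows roughly 1583-fold between 400000 and 440000 instructions.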

Is there anything else missing from my side or that I could/should do?

intrigus-lgtm, Dec 20 '23 00:12

> Is there anything else missing from my side or that I could/should do?

Short of financing Z3's development or contributing PRs or ideas on how to fix this, no.

nunoplopes, Dec 20 '23 15:12