[Bug] Inconcistency after meta schedule for the conv2d_hwcn operator
Actual behavior
An Inconsistency bug detected.
Mismatched elements: 4093 / 294912 (1.39%)
Max absolute difference among violations: 0.00019836
Max relative difference among violations: 0.25589287
ACTUAL: array([[[[ 0.653972, 4.73613 , 47.223606, ..., 1.366684,
-10.288207, 2.612887],
[ 8.710161, -28.032656, 5.167299, ..., -5.342846,...
DESIRED: array([[[[ 0.653983, 4.736097, 47.22358 , ..., 1.366686,
-10.288215, 2.612871],
[ 8.710163, -28.032671, 5.167283, ..., -5.342826,...
Environment
tvm-v0.21.dev0
Steps to reproduce
import tvm
from tvm import te, topi, tir
from tvm import meta_schedule as ms
import numpy as np
def compile_mod(mod, np_input_list, output_shape, output_type, opt_level=3):
with tvm.transform.PassContext(opt_level):
ref_mod = tvm.build(mod, target='llvm')
mod_output = tvm.nd.empty(output_shape, dtype=output_type, device=tvm.cpu(0))
tvm_inputs = [tvm.nd.array(x) for x in np_input_list]
ref_mod(*tvm_inputs, mod_output)
return mod_output
Input = te.placeholder([7, 7, 512, 32], dtype='float32', name='Input')
Filter = te.placeholder([3, 3, 512, 1024], dtype='float32', name='Filter')
op_config = {'Input': Input, 'Filter': Filter, 'stride': [1, 1], 'padding': [1, 1], 'dilation': [3, 3], }
op_output = topi.nn.conv2d_hwcn(**op_config)
np_inputs = [np.random.uniform(-1, 1, size=[7, 7, 512, 32]).astype('float32'),np.random.uniform(-1, 1, size=[3, 3, 512, 1024]).astype('float32')]
sch = tir.Schedule(te.create_prim_func([Input, Filter, op_output]).with_attr('target', tvm.target.Target('llvm')))
ref_output = compile_mod(sch.mod, np_inputs, op_output.shape, op_output.dtype, opt_level=0)
database = ms.tir_integration.tune_tir(mod=sch.mod, target='llvm --num-cores=16', work_dir='./tune_tmp', max_trials_global=1, num_trials_per_iter=1)
sch = ms.tir_integration.compile_tir(database, sch.mod, 'llvm --num-cores=16')
opt_mod_output = compile_mod(sch.mod, np_inputs, op_output.shape, op_output.dtype, opt_level=4)
np.testing.assert_allclose(
ref_output.numpy(), opt_mod_output.numpy(), rtol=1e-5, atol=1e-5, err_msg=f"An Inconsistency detected."
)
Triage
- needs-triage
@vacu9708 Could you please check if this inconsistency indicates a bug in TVM? Thanks a lot!
Hi, I checked and compared the values of each output and confirmed that all gaps are within 1 ULP. But in your result log, the gaps seem quite big (quite a lot of ULPs). I don't know why your result is different from mine.
Below is how I tested. (test environment: tvm-v0.21.dev0, Fri May 23, commit ada7c7c7cc8c15891f0c82e01c5a52f636302b07)
ref_np = ref_output.numpy().astype('float32')
opt_np = opt_mod_output.numpy().astype('float32')
# absolute difference
diff = np.abs(ref_np - opt_np)
# 1 ULP at each ref element
ulp = np.abs(np.spacing(ref_np))
# allow up to N ULP
N = 1
mask = diff > (N * ulp)
if mask.any():
bad = np.argwhere(mask)
print(f"{len(bad)} elements differ by more than {N} ULP:")
for idx in bad[:10]:
i = tuple(idx)
print(f" idx={i}: ref={ref_np[i]:.8e}, opt={opt_np[i]:.8e}, "
f"diff={diff[i]:.8e}, 1ULP@ref={ulp[i]:.8e}")
if len(bad) > 10:
print(f" …and {len(bad)-10} more")
else:
print(f"All values within {N} ULP of reference.")
Thanks for your investigation. It seems a flaky problem due to randomness.
In most cases, they will show a significant difference, as shown below.
272197 elements differ by more than 1 ULP:
idx=(np.int64(0), np.int64(0), np.int64(0), np.int64(1)): ref=-2.35918593e+00, opt=-2.35919356e+00, diff=7.62939453e-06, 1ULP@ref=2.38418579e-07
idx=(np.int64(0), np.int64(0), np.int64(0), np.int64(2)): ref=-1.82648659e+00, opt=-1.82645822e+00, diff=2.83718109e-05, 1ULP@ref=1.19209290e-07
idx=(np.int64(0), np.int64(0), np.int64(0), np.int64(3)): ref=6.85934210e+00, opt=6.85935402e+00, diff=1.19209290e-05, 1ULP@ref=4.76837158e-07
idx=(np.int64(0), np.int64(0), np.int64(0), np.int64(4)): ref=-1.39356594e+01, opt=-1.39356689e+01, diff=9.53674316e-06, 1ULP@ref=9.53674316e-07
idx=(np.int64(0), np.int64(0), np.int64(0), np.int64(5)): ref=1.25973148e+01, opt=1.25973377e+01, diff=2.28881836e-05, 1ULP@ref=9.53674316e-07
idx=(np.int64(0), np.int64(0), np.int64(0), np.int64(6)): ref=-1.22587121e+00, opt=-1.22587681e+00, diff=5.60283661e-06, 1ULP@ref=1.19209290e-07
idx=(np.int64(0), np.int64(0), np.int64(0), np.int64(7)): ref=1.96191578e+01, opt=1.96191769e+01, diff=1.90734863e-05, 1ULP@ref=1.90734863e-06
idx=(np.int64(0), np.int64(0), np.int64(0), np.int64(8)): ref=1.85626125e+00, opt=1.85625291e+00, diff=8.34465027e-06, 1ULP@ref=1.19209290e-07
idx=(np.int64(0), np.int64(0), np.int64(0), np.int64(9)): ref=-1.14425735e+01, opt=-1.14425793e+01, diff=5.72204590e-06, 1ULP@ref=9.53674316e-07
idx=(np.int64(0), np.int64(0), np.int64(0), np.int64(10)): ref=2.34399719e+01, opt=2.34399662e+01, diff=5.72204590e-06, 1ULP@ref=1.90734863e-06
…and 272187 more
Traceback (most recent call last):
File "/data/qshenaf/remote_pc/TirFuzz/bugs/topi.nn.conv2d_hwcn_4.py", line 56, in