Add Attention Microsoft Contrib Operator
Spec here: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
Useful resources:
- https://towardsdatascience.com/transformers-explained-visually-part-2-how-it-works-step-by-step-b49fa4a64f34/
- https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853/
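As a quick refresher on what the operator computes, here is a minimal NumPy sketch of multi-head scaled dot-product attention. The shapes and names are illustrative only, not the parser's actual variables, and the fused QKV weight layout is an assumption:

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multihead_attention(x, w_qkv, num_heads):
    # x: (batch, seq, hidden); w_qkv: (hidden, 3 * hidden) fused Q/K/V weights.
    batch, seq, hidden = x.shape
    head_dim = hidden // num_heads
    qkv = x @ w_qkv                      # one fused input GEMM
    q, k, v = np.split(qkv, 3, axis=-1)  # slice the result into Q, K, V

    def to_heads(t):
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        return t.reshape(batch, seq, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = to_heads(q), to_heads(k), to_heads(v)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)
    out = softmax(scores) @ v
    # (batch, heads, seq, head_dim) -> (batch, seq, hidden)
    return out.transpose(0, 2, 1, 3).reshape(batch, seq, hidden)
```

The `dot -> mul/add -> softmax -> dot` chain in this sketch is the sequence that shows up fused in the MLIR kernel names below.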
Adding a change that gets rid of the concat for now, based on @pfultz2's recommendation. Seeing a speedup with this when using MLIR fusion on an attention block. Mirrored a model workload with just attention in a gen_onnx.py test, and we're getting a significant speedup.
Before, with no MLIR flags:
Summary:
gpu::code_object::mul_add_reduce_max_sub_exp_reduce_sum_div_kernel: 3.86515ms / 1 = 3.86515ms, 51%
gpu::gemm: 2.92227ms / 3 = 0.97409ms, 39%
gpu::code_object::mlir_reshape_dot: 0.344482ms / 1 = 0.344482ms, 5%
gpu::code_object::mlir_reshape_reshape_transpose_dot: 0.327408ms / 1 = 0.327408ms, 5%
gpu::code_object::contiguous_kernel: 0.115116ms / 1 = 0.115116ms, 2%
gpu::code_object::not_convert_mul_kernel: 0.0205068ms / 1 = 0.0205068ms, 1%
load: 0.00884556ms / 7 = 0.00126365ms, 1%
slice: 0.0055792ms / 3 = 0.00185973ms, 1%
multibroadcast: 0.00461838ms / 3 = 0.00153946ms, 1%
@param: 0.00413152ms / 5 = 0.000826304ms, 1%
reshape_lazy: 0.00160806ms / 1 = 0.00160806ms, 1%
unsqueeze: 0.00158582ms / 1 = 0.00158582ms, 1%
check_context::migraphx::gpu::context: 0.00141454ms / 1 = 0.00141454ms, 1%
broadcast: 0.00133182ms / 1 = 0.00133182ms, 1%
hip::hip_allocate_memory: 0.00126186ms / 1 = 0.00126186ms, 1%
Batch size: 1
Rate: 133.76 inferences/sec
Total time: 7.47609ms (Min: 7.46932ms, Max: 7.48673ms, Mean: 7.47635ms, Median: 7.4761ms)
Percentiles (90%, 95%, 99%): (7.48157ms, 7.48357ms, 7.48636ms)
Total instructions time: 7.62531ms
Overhead time: 0.00697326ms, -0.149226ms
Overhead: 0%, -2%
[ MIGraphX Version: 2.13.0. ] Complete: bin/driver perf ../test/onnx/attention_multihead_bias_mask_test.onnx
After, with `MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention,~dot`:
Summary:
gpu::gemm: 2.96798ms / 3 = 0.989326ms, 83%
gpu::code_object::mlir_reshape_reshape_transpose_dot_mul_add_softmax_reshape_dot: 0.482526ms / 1 = 0.482526ms, 14%
gpu::code_object::contiguous_kernel: 0.114383ms / 1 = 0.114383ms, 4%
gpu::code_object::not_convert_mul_kernel: 0.0203474ms / 1 = 0.0203474ms, 1%
load: 0.00602034ms / 5 = 0.00120407ms, 1%
slice: 0.00533918ms / 3 = 0.00177973ms, 1%
multibroadcast: 0.00512556ms / 3 = 0.00170852ms, 1%
@param: 0.00405146ms / 5 = 0.000810292ms, 1%
unsqueeze: 0.00152746ms / 1 = 0.00152746ms, 1%
broadcast: 0.0014327ms / 1 = 0.0014327ms, 1%
check_context::migraphx::gpu::context: 0.00138322ms / 1 = 0.00138322ms, 1%
reshape_lazy: 0.00134576ms / 1 = 0.00134576ms, 1%
hip::hip_allocate_memory: 0.00116304ms / 1 = 0.00116304ms, 1%
Batch size: 1
Rate: 290.792 inferences/sec
Total time: 3.43888ms (Min: 3.42825ms, Max: 3.58141ms, Mean: 3.46037ms, Median: 3.43728ms)
Percentiles (90%, 95%, 99%): (3.55396ms, 3.56477ms, 3.57875ms)
Total instructions time: 3.61262ms
Overhead time: 0.0063449ms, -0.173741ms
Overhead: 0%, -5%
MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention,~dot \
[ MIGraphX Version: 2.13.0. ] Complete: bin/driver perf ../test/onnx/attention_multihead_bias_mask_test.onnx
For comparison, prior to removing the concat:
With flags:
Summary:
gpu::gemm: 2.92583ms / 3 = 0.975278ms, 51%
gpu::code_object::mlir_slice_reshape_transpose_slice_squeeze_dot_mul_add_softmax_slice_dot: 2.65099ms / 16 = 0.165687ms, 46%
gpu::code_object::concat_kernel: 0.162542ms / 1 = 0.162542ms, 3%
slice: 0.0289778ms / 19 = 0.00152515ms, 1%
multibroadcast: 0.0286749ms / 19 = 0.0015092ms, 1%
load: 0.0240347ms / 20 = 0.00120174ms, 1%
gpu::code_object::not_convert_mul_kernel: 0.020387ms / 1 = 0.020387ms, 1%
@param: 0.00453062ms / 5 = 0.000906124ms, 1%
unsqueeze: 0.00160364ms / 1 = 0.00160364ms, 1%
check_context::migraphx::gpu::context: 0.0013431ms / 1 = 0.0013431ms, 1%
hip::hip_allocate_memory: 0.00128176ms / 1 = 0.00128176ms, 1%
Batch size: 1
Rate: 189.493 inferences/sec
Total time: 5.27725ms (Min: 5.24866ms, Max: 5.3047ms, Mean: 5.27715ms, Median: 5.27773ms)
Percentiles (90%, 95%, 99%): (5.29325ms, 5.29882ms, 5.30175ms)
Total instructions time: 5.8502ms
Overhead time: 0.0229228ms, -0.572951ms
Overhead: 0%, -11%
MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention,~dot \
[ MIGraphX Version: 2.13.0. ] Complete: bin/driver perf ../test/onnx/attention_multihead_bias_mask_test.onnx
Without flags:
Summary:
gpu::gemm: 2.92451ms / 3 = 0.974837ms, 61%
gpu::code_object::mul_add_reduce_max_sub_exp_reduce_sum_div_kernel: 0.704009ms / 16 = 0.0440006ms, 15%
gpu::code_object::mlir_slice_reshape_transpose_slice_squeeze_dot: 0.502644ms / 16 = 0.0314153ms, 11%
gpu::code_object::mlir_slice_dot: 0.418183ms / 16 = 0.0261364ms, 9%
gpu::code_object::concat_kernel: 0.162208ms / 1 = 0.162208ms, 4%
load: 0.0630131ms / 52 = 0.00121179ms, 2%
slice: 0.028744ms / 19 = 0.00151284ms, 1%
multibroadcast: 0.0275485ms / 19 = 0.00144992ms, 1%
gpu::code_object::not_convert_mul_kernel: 0.0197071ms / 1 = 0.0197071ms, 1%
@param: 0.0045218ms / 5 = 0.00090436ms, 1%
unsqueeze: 0.00178614ms / 1 = 0.00178614ms, 1%
check_context::migraphx::gpu::context: 0.00130386ms / 1 = 0.00130386ms, 1%
hip::hip_allocate_memory: 0.0012166ms / 1 = 0.0012166ms, 1%
Batch size: 1
Rate: 257.44 inferences/sec
Total time: 3.88439ms (Min: 3.87337ms, Max: 3.90875ms, Mean: 3.8846ms, Median: 3.88467ms)
Percentiles (90%, 95%, 99%): (3.89034ms, 3.8924ms, 3.89919ms)
Total instructions time: 4.8594ms
Overhead time: 0.0366287ms, -0.975003ms
Overhead: 1%, -25%
[ MIGraphX Version: 2.13.0. ] Complete: bin/driver perf ../test/onnx/attention_multihead_bias_mask_test.onnx
Found a way to get the slice fused into the attention head as well, by handling the weight dot product before the slice block. This simplified the logic too. We get a slight speedup per attention head by doing this, leaving only the GEMM input as the remaining issue. This also lets us use the attention,dot flags for MLIR without penalty. I get about a 0.01ms speedup per attention head, and as a result the input broadcast also becomes fused.
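The reordering described above relies on an identity: doing the full projection GEMM first and slicing its output gives the same result as slicing the weights and running three smaller GEMMs. A small NumPy sketch of that equivalence (dims and names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 8))        # (batch, seq, hidden)
w_qkv = rng.standard_normal((8, 24))      # fused Q/K/V weights, hidden x 3*hidden

# Option A: slice the weights first, then three small GEMMs.
q_a = x @ w_qkv[:, 0:8]
k_a = x @ w_qkv[:, 8:16]
v_a = x @ w_qkv[:, 16:24]

# Option B: one large GEMM, then slice the output.
qkv = x @ w_qkv
q_b, k_b, v_b = qkv[..., 0:8], qkv[..., 8:16], qkv[..., 16:24]

# Both orderings produce the same Q, K, V, so the slice can sit after the
# dot and be fused into the attention kernel.
assert np.allclose(q_a, q_b) and np.allclose(k_a, k_b) and np.allclose(v_a, v_b)
```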
Summary:
gpu::code_object::mlir_broadcast_dot: 2.9135ms / 1 = 2.9135ms, 83%
gpu::code_object::mlir_slice_reshape_slice_reshape_transpose_dot_mul_add_softmax_slice_reshape_dot: 0.481105ms / 1 = 0.481105ms, 14%
gpu::code_object::contiguous_kernel: 0.115444ms / 1 = 0.115444ms, 4%
gpu::code_object::not_convert_mul_kernel: 0.0186374ms / 1 = 0.0186374ms, 1%
@param: 0.00381764ms / 5 = 0.000763528ms, 1%
load: 0.00313322ms / 3 = 0.00104441ms, 1%
reshape_lazy: 0.00130248ms / 1 = 0.00130248ms, 1%
check_context::migraphx::gpu::context: 0.00127416ms / 1 = 0.00127416ms, 1%
broadcast: 0.00114032ms / 1 = 0.00114032ms, 1%
hip::hip_allocate_memory: 0.00094864ms / 1 = 0.00094864ms, 1%
Batch size: 1
Rate: 291.754 inferences/sec
Total time: 3.42755ms (Min: 3.41509ms, Max: 3.51073ms, Mean: 3.4358ms, Median: 3.42659ms)
Percentiles (90%, 95%, 99%): (3.47342ms, 3.48969ms, 3.50983ms)
Total instructions time: 3.54031ms
Overhead time: 0.00367538ms, -0.112759ms
Overhead: 0%, -3%
MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention,dot \
[ MIGraphX Version: 2.13.0. ] Complete: bin/driver perf ../test/onnx/attention_multihead_bias_mask_test.onnx
Need to get changes sorted, as attention isn't fusing anymore when testing against this PR: https://github.com/ROCm/AMDMIGraphX/pull/3993
We want this to be supported, as it reduces the number of larger dot/GEMMs and cuts our GEMM times from 98+ms to ~75ms.
That gives us roughly a 20-25% boost, provided attention fuses correctly.
Got fusion sorted for this and benchmarked things using the customer script. Saw a 21% speedup using Paul's change.
Added past/present inputs/outputs and unidirectional support, leaving rotary input encoding out for now to get this in. I can open an issue for that.
All I have left to do is the parser/verify tests. I've reduced the sizes for attention_double_head for verify-test simplicity (batch 2, sequence 4, hidden 4, etc.). The math should work out the same once this is scaled up.
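At those reduced sizes the math can be sanity-checked against a tiny NumPy reference. A sketch, assuming two heads over hidden=4 (so head_dim=2); the helper name is hypothetical, not the verify test's actual code:

```python
import numpy as np

def attention_ref(q, k, v):
    # q, k, v: (batch, heads, seq, head_dim); plain scaled dot-product attention.
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = (q @ k.transpose(0, 1, 3, 2)) * scale
    probs = np.exp(scores - scores.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    return probs @ v, probs

rng = np.random.default_rng(0)
batch, heads, seq, head_dim = 2, 2, 4, 2   # matches the reduced verify-test sizes
q, k, v = (rng.standard_normal((batch, heads, seq, head_dim)) for _ in range(3))
out, probs = attention_ref(q, k, v)
assert out.shape == (batch, heads, seq, head_dim)
assert np.allclose(probs.sum(-1), 1.0)     # each attention row is a distribution
```

Since the computation is elementwise-identical at every size, invariants that hold here (row-stochastic attention weights, output shape) should carry over when the dims are scaled up.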
Opened a question with ONNX Runtime; it appears the bias argument may now be non-optional. Sorting that out, but proceeding with additional verify/parser tests.
I've also modified the accuracy checker to emit input/gold data as well, to generate things for larger test cases. Will include it in another PR; @kahmed10, it should be useful for when we generate things.
@kahmed10 Will handle comments after merging develop. Got parser/verify tests for this running now.
Codecov Report
:x: Patch coverage is 93.99293% with 17 lines in your changes missing coverage. Please review.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/onnx/parse_attention.cpp | 93.99% | 17 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## develop #3816 +/- ##
===========================================
+ Coverage 92.04% 92.19% +0.15%
===========================================
Files 531 544 +13
Lines 24527 25070 +543
===========================================
+ Hits 22574 23112 +538
- Misses 1953 1958 +5
| Files with missing lines | Coverage Δ | |
|---|---|---|
| src/onnx/parse_attention.cpp | 93.99% <93.99%> (ø) |
@pfultz2 @kahmed10 let me know if there's anything else to add.
May cut unidirectional/4D masks right now, since I'm seeing odd behavior from the ONNX Runtime side when verifying, and that will reduce scope (the customer doesn't use either attribute, so it's okay to defer support for now).
Same goes for attn_bias: I've added it to the complete/check inputs, but will be adding test cases for it too.
All comments handled. This should be ready for another pass.
@pfultz2 looks like the rewrite is causing issues with other test modules. The error only seems to occur when specific_ops=attention is turned on, which requires the int8 conversion.
Backed out the parser change that removes the convert for now. Seems to be something with MLIR specific-op fusion that I would need to modify.
Backing out the merge from main; seeing errors with that.
Overall, I think this looks good, but it needs another review from Alan or Shiv, and the convert issue with MLIR needs fixing.
Added --python output and described the issue I'm seeing here: https://github.com/ROCm/AMDMIGraphX/issues/4117
| Test | Batch | Rate new 6f4af0 | Rate old 42747c | Diff | Compare |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,253.79 | 3,236.74 | 0.53% | :white_check_mark: |
| torchvision-resnet50_fp16 | 64 | 6,908.88 | 6,883.85 | 0.36% | :white_check_mark: |
| torchvision-densenet121 | 32 | 2,451.45 | 2,442.77 | 0.36% | :white_check_mark: |
| torchvision-densenet121_fp16 | 32 | 4,205.13 | 4,191.66 | 0.32% | :white_check_mark: |
| torchvision-inceptionv3 | 32 | 1,627.52 | 1,626.95 | 0.04% | :white_check_mark: |
| torchvision-inceptionv3_fp16 | 32 | 2,724.12 | 2,723.47 | 0.02% | :white_check_mark: |
| cadene-inceptionv4 | 16 | 755.16 | 765.87 | -1.40% | :white_check_mark: |
| cadene-resnext64x4 | 16 | 818.95 | 814.08 | 0.60% | :white_check_mark: |
| slim-mobilenet | 64 | 7,474.80 | 7,440.99 | 0.45% | :white_check_mark: |
| slim-nasnetalarge | 64 | 209.66 | 210.06 | -0.19% | :white_check_mark: |
| slim-resnet50v2 | 64 | 3,347.21 | 3,331.82 | 0.46% | :white_check_mark: |
| bert-mrpc-onnx | 8 | 1,149.47 | 1,143.19 | 0.55% | :white_check_mark: |
| bert-mrpc-tf | 1 | 460.12 | 460.12 | -0.00% | :white_check_mark: |
| pytorch-examples-wlang-gru | 1 | 347.11 | 344.33 | 0.81% | :white_check_mark: |
| pytorch-examples-wlang-lstm | 1 | 493.12 | 484.02 | 1.88% | :white_check_mark: |
| torchvision-resnet50_1 | 1 | 791.80 | 797.02 | -0.65% | :white_check_mark: |
| cadene-dpn92_1 | 1 | 414.09 | 415.93 | -0.44% | :white_check_mark: |
| cadene-resnext101_1 | 1 | 389.05 | 388.29 | 0.19% | :white_check_mark: |
| onnx-taau-downsample | 1 | 396.17 | 395.50 | 0.17% | :white_check_mark: |
| dlrm-criteoterabyte | 1 | 32.32 | 33.64 | -3.94% | :red_circle: |
| dlrm-criteoterabyte_fp16 | 1 | 51.31 | 51.23 | 0.15% | :white_check_mark: |
| agentmodel | 1 | 10,295.91 | 10,267.67 | 0.28% | :white_check_mark: |
| unet_fp16 | 2 | 59.63 | 59.64 | -0.02% | :white_check_mark: |
| resnet50v1_fp16 | 1 | 1,032.43 | 1,038.76 | -0.61% | :white_check_mark: |
| resnet50v1_int8 | 1 | 1,053.32 | 1,062.02 | -0.82% | :white_check_mark: |
| bert_base_cased_fp16 | 64 | 1,175.62 | 1,171.22 | 0.38% | :white_check_mark: |
| bert_large_uncased_fp16 | 32 | 358.20 | 359.76 | -0.43% | :white_check_mark: |
| bert_large_fp16 | 1 | 200.01 | 203.58 | -1.75% | :white_check_mark: |
| distilgpt2_fp16 | 16 | 2,243.51 | 2,230.24 | 0.59% | :white_check_mark: |
| yolov5s | 1 | 521.06 | 540.25 | -3.55% | :red_circle: |
| tinyllama | 1 | 43.83 | 43.84 | -0.04% | :white_check_mark: |
| vicuna-fastchat | 1 | 45.00 | 45.09 | -0.21% | :white_check_mark: |
| whisper-tiny-encoder | 1 | 417.84 | 418.38 | -0.13% | :white_check_mark: |
| whisper-tiny-decoder | 1 | 404.42 | 410.69 | -1.53% | :white_check_mark: |
| llama2_7b | 1 | 19.11 | 19.16 | -0.28% | :white_check_mark: |
| qwen1.5-7b | 1 | 23.53 | 23.56 | -0.10% | :white_check_mark: |
| phi3-3.8b | 1 | 26.56 | 26.75 | -0.69% | :white_check_mark: |
| mask-rcnn | 1 | 12.81 | 12.83 | -0.18% | :white_check_mark: |
| llama3-8b | 1 | 21.70 | 21.75 | -0.21% | :white_check_mark: |
| whisper-large-encoder | 1 | 10.22 | 10.18 | 0.38% | :white_check_mark: |
| whisper-large-decoder | 1 | 100.97 | 103.74 | -2.68% | :white_check_mark: |
| mistral-7b | 1 | 23.80 | 23.76 | 0.16% | :white_check_mark: |
| FLUX.1-schnell | 1 | 766.63 | 737.77 | 3.91% | :high_brightness: |
This build is not recommended to merge :red_circle:
:x:bert-mrpc-tf: ERROR - check error output
2025-07-04 12:41:01.442954: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1751650867.133976 183203 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 62973 MB memory: -> device: 0, name: AMD Instinct MI250X/MI250, pci bus id: 0000:b3:00.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1751650868.005664 183203 mlir_graph_optimization_pass.cc:401] MLIR V1 optimization pass is not enabled
2025-07-04 12:41:16.334065: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
2025-07-04 12:41:16.334119: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
2025-07-04 12:41:16.334166: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
2025-07-04 12:41:16.334212: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
2025-07-04 12:41:16.334436: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
2025-07-04 12:41:16.334488: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
2025-07-04 12:41:16.334539: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
2025-07-04 12:41:16.334594: E external/local_xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:250] bitcode module is required by this HLO module but was not found at ./opencl.bc
error: Failure when generating HSACO
error: Failure when generating HSACO
error: Failure when generating HSACO
error: Failure when generating HSACO
error: Failure when generating HSACO
error: Failure when generating HSACO
error: Failure when generating HSACO
error: Failure when generating HSACO
2025-07-04 12:41:16.335860: E tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc:228] INTERNAL: Generating device code failed.
2025-07-04 12:41:16.336970: W tensorflow/core/framework/op_kernel.cc:1829] UNKNOWN: JIT compilation failed.
2025-07-04 12:41:16.336991: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
2025-07-04 12:41:16.337000: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
[[import/loss/output/_21]]
2025-07-04 12:41:16.337015: I tensorflow/core/framework/local_rendezvous.cc:424] Local rendezvous recv item cancelled. Key hash: 11217777527359497193
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1407, in _do_call
return fn(*args)
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1390, in _run_fn
return self._call_tf_sessionrun(options, feed_dict, fetch_list,
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1483, in _call_tf_sessionrun
return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
tensorflow.python.framework.errors_impl.UnknownError: 2 root error(s) found.
(0) UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
[[import/loss/output/_21]]
(1) UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
0 successful operations.
0 derived errors ignored.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 359, in
main()
File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 335, in main
y_out = sess.run(y, feed_dict=tf_dict)
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 977, in run
result = self._run(None, fetches, feed_dict, options_ptr,
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1220, in _run
results = self._do_run(handle, final_targets, final_fetches,
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1400, in _do_run
return self._do_call(_run_fn, feeds, fetches, targets, options,
File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1426, in _do_call
raise type(e)(node_def, op, message) # pylint: disable=no-value-for-parameter
tensorflow.python.framework.errors_impl.UnknownError: Graph execution error:
Detected at node 'import/bert/embeddings/LayerNorm/moments/SquaredDifference' defined at (most recent call last):
Node: 'import/bert/embeddings/LayerNorm/moments/SquaredDifference'
Detected at node 'import/bert/embeddings/LayerNorm/moments/SquaredDifference' defined at (most recent call last):
Node: 'import/bert/embeddings/LayerNorm/moments/SquaredDifference'
2 root error(s) found.
(0) UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
[[import/loss/output/_21]]
(1) UNKNOWN: JIT compilation failed.
[[{{node import/bert/embeddings/LayerNorm/moments/SquaredDifference}}]]
0 successful operations.
0 derived errors ignored.
Original stack trace for 'import/bert/embeddings/LayerNorm/moments/SquaredDifference':
:red_circle:unet: FAILED: MIGraphX is not within tolerance - check verbose output
:red_circle:bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
:red_circle:mask-rcnn: FAILED: MIGraphX is not within tolerance - check verbose output