[QST] MatX is around x15 slower than CuPy for the same task
Gidday.
I'm a bit of a novice with MatX and C++, and was looking to get some help with optimising my MatX code.
Basically, I'm refactoring code originally written in CuPy into (hopefully lightning-fast) MatX. However, my MatX implementation, despite looking like an identical equivalent of my CuPy code, is a lot slower. I was wondering if anybody could give me some tips as to where my code might be slowing down.
FYI, I'm working on the general assumption that MatX's operators are super lightweight, so the reshapes and repmats are all effectively free.
My MatX code looks like:
matx::tensor_t<matx::matxFp16, 2> GsDBSCAN::findDistancesMatX(matx::tensor_t<matx::matxFp16, 2> &X_t, matx::tensor_t<int, 2> &A_t, matx::tensor_t<int, 2> &B_t, float alpha, int batchSize) {
const int k = A_t.Shape()[1] / 2;
const int m = B_t.Shape()[1];
const int n = X_t.Shape()[0];
const int d = X_t.Shape()[1];
int D = B_t.Shape()[0] / 2;
batchSize = (batchSize != -1) ? batchSize: GsDBSCAN::findDistanceBatchSize(alpha, n, d, k, m);
auto AFlat_t = matx::flatten(A_t);
auto distances_t = matx::make_tensor<matx::matxFp16>({n, 2*k*m});
for (int i = 0; i < n; i += batchSize) {
int maxBatchIdx = i + batchSize - 1; // Index within X along the ROWS
auto XSubset_t_op = matx::slice(X_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd});
auto ABatchFlat_t_op = matx::slice(AFlat_t, {i * 2 * k}, {(maxBatchIdx + 1) * 2 * k});
auto BBatch_t_op = matx::remap<0>(B_t, ABatchFlat_t_op);
auto XBatch_t_op = matx::remap<0>(X_t, matx::flatten(BBatch_t_op));
auto XBatchReshaped_t_op = matx::reshape(XBatch_t_op, {batchSize, 2*k*m, d});
auto XSubsetReshaped_t_op = matx::reshape(XSubset_t_op, {batchSize, 1, d});
auto YBatch_t_op = (XBatchReshaped_t_op - matx::repmat(XSubsetReshaped_t_op, {1, 2*k*m, 1})); // repmat works around subtracting tensors with otherwise-incompatible shapes
auto YBatch_t_norm_op = matx::vector_norm(YBatch_t_op, {2}, matx::NormOrder::L2);
(matx::slice(distances_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd}) = YBatch_t_norm_op).run();
}
return distances_t;
}
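As an aside, the repmat above emulates the broadcasting that CuPy performs implicitly when subtracting a (batch, 1, d) tensor from a (batch, 2*k*m, d) tensor. A small NumPy sketch (tiny made-up shapes, not the MatX API) showing the two forms are equivalent:

```python
import numpy as np

# Tiny stand-ins for (batchSize, 2*k*m, d)
batch, km2, d = 4, 6, 3
X_batch = np.arange(batch * km2 * d, dtype=np.float32).reshape(batch, km2, d)
X_subset = np.ones((batch, 1, d), dtype=np.float32)

# Explicit tiling, as the MatX repmat does
Y_tiled = X_batch - np.tile(X_subset, (1, km2, 1))
# Implicit broadcasting, as CuPy/NumPy do natively
Y_broadcast = X_batch - X_subset

assert np.array_equal(Y_tiled, Y_broadcast)
```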
And the same CuPy code looks like:
def find_distances(X, A, B, alpha=1.2, batch_size = -1):
k = A.shape[1] // 2
m = B.shape[1]
n = X.shape[0]
d = X.shape[1]
D = B.shape[0] // 2
batch_size = batch_size if batch_size != -1 else get_batch_size(n, d, k, m, alpha=alpha)
distances = cp.empty(shape=(n, 2 * k * m),
dtype=cp.float16) # float32 causes a memory overload. float16 is fine (for eps 2DP)
for i in range(0, n, batch_size):
max_batch_idx = min(i + batch_size, X.shape[0])
Z_batch = X[B[A[i:max_batch_idx]]]
# (Edit): Changed the reshape call to be a little clearer. Z_batch_adj is equivalent to XBatchReshaped_t_op above.
Z_batch_adj = Z_batch.reshape(batch_size, 2 * k * m, d)
Y_batch = Z_batch_adj - X[i:max_batch_idx, cp.newaxis, :]
distances[i:max_batch_idx] = cp.linalg.norm(Y_batch, axis=2)
return distances
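To unpack the fancy-indexing chain `X[B[A[i:max_batch_idx]]]`: each row of A holds indices into the rows of B, and B in turn holds indices into the rows of X. A CPU-only NumPy sketch with tiny made-up sizes (same logic, no GPU needed):

```python
import numpy as np

# Tiny stand-ins for the real parameters
n, k, m, d, D = 8, 1, 3, 4, 5
rng = np.random.default_rng(0)
X = rng.standard_normal((n, d)).astype(np.float32)
A = rng.integers(0, 2 * D, size=(n, 2 * k))   # indices into rows of B
B = rng.integers(0, n, size=(2 * D, m))       # indices into rows of X

batch_size = 4
batch = X[B[A[0:batch_size]]]                 # shape (batch, 2*k, m, d)
batch = batch.reshape(batch_size, 2 * k * m, d)
Y = batch - X[0:batch_size, np.newaxis, :]    # broadcast subtract
dist = np.linalg.norm(Y, axis=2)              # shape (batch, 2*k*m)
assert dist.shape == (batch_size, 2 * k * m)
```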
The parameters used for both are:
n = 70_000
k = 5
m = 50
d = 784
D = 1024
batchSize ~= 250 (FYI it should always be a divisor of n; I found the CuPy implementation was a lot slower otherwise on the final iteration).
Regarding results, the MatX code takes around 14.5 seconds to complete, but CuPy takes 0.9 seconds (including CUDA synchronisations).
As a baseline, a multithreaded (64 threads) CPU implementation of the above code (using loops with no tensors involved) takes less than 0.7 seconds. A single threaded CPU implementation takes around 7 seconds - (this is using the same machine of course).
Sorry if the variable names are a little cryptic.
I've tested for around n = 1000 and found that the two implementations produce the same results (albeit with a small amount of floating point errors).
Thanks in advance.
Hi @HugoPhibbs , this is very interesting and unexpected. We'll take a look at the profile and get back to you.
Can the batches all run in parallel and get the same answer? We generally suggest removing batch loops like this and just send in the entire tensor. It's simpler and faster.
Hi @luitjens
I'm using batches because otherwise my GPU quickly runs out of memory (I tried no batching with CuPy and it ran out of memory). Batching is used to control the memory usage of the intermediary tensors.
I intend in the future to tune the batch size to produce optimal memory usage of the GPU, but right now, I'm focused on getting an MVP.
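For what it's worth, the big intermediate is Y_batch: batch_size * 2*k*m * d float16 elements per iteration. A quick back-of-envelope using the parameters above shows why batching is needed:

```python
# Rough size of the Y_batch intermediate (float16 = 2 bytes per element)
batch_size, k, m, d, n = 250, 5, 50, 784, 70_000

bytes_y = batch_size * 2 * k * m * d * 2
print(f"{bytes_y / 1e6:.0f} MB per batch")      # ~196 MB per iteration

bytes_full = n * 2 * k * m * d * 2              # the unbatched intermediate
print(f"{bytes_full / 1e9:.1f} GB unbatched")   # ~54.9 GB, hence the OOM
```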
Are you timing the allocation and page faults of managed memory as part of the execution time? If you switch to cuda memory instead of managed does the perf issue go away?
Hi @luitjens, thx for getting back to me.
I'm timing the complete function runtime - as in how long it takes to run the function start to finish. The timing looks a bit like this:
TEST_F(TestFindingDistances, TestLargeInputMatX) {
int k = 5;
int n = 70000;
int m = 50;
int D = 1024;
int d = 784;
auto A = tu::createMockAMatrixMatX(n, k, D);
auto B = tu::createMockBMatrixMatX(n, m, D);
auto X = tu::createMockMnistDatasetMatX(n, d);
cudaDeviceSynchronize(); // Possibly not necessary?
tu::Time start = tu::timeNow();
auto distances = GsDBSCAN::findDistancesMatX(X, A, B, 1.2, 250);
cudaDeviceSynchronize();
tu::printDurationSinceStart(start);
printf("%lld %lld", distances.Shape()[0], distances.Shape()[1]);
ASSERT_TRUE(distances.Shape()[0] == n);
ASSERT_TRUE(distances.Shape()[1] == 2*k*m);
}
As for memory options, I changed the memory space of all the tensors to matx::MATX_DEVICE_MEMORY and I'm still getting the same 14.5 second runtime. E.g. here's what I did for all the make_tensor calls:
inline auto createMockAMatrixMatX(int n = 70000, int k = 2, int D = 1024) {
auto A = matx::make_tensor<float>({n, 2*k}, matx::MATX_DEVICE_MEMORY);
auto A_i = matx::make_tensor<int32_t>({n, 2*k}, matx::MATX_DEVICE_MEMORY);
int a = 2 * (D - 1);
(A = matx::random<float>({n, 2*k}, matx::UNIFORM, 0, a)).run();
(A_i = matx::as_type<int32_t>(A)).run();
return A_i;
}
Hi again,
I've done some more testing, and I've found that the cudaDeviceSynchronize() step takes the lion's share of the runtime. I added some hacky profiling to the function that looks like this:
matx::tensor_t<matx::matxFp16, 2> GsDBSCAN::findDistancesMatX(matx::tensor_t<matx::matxFp16, 2> &X_t, matx::tensor_t<int, 2> &A_t, matx::tensor_t<int, 2> &B_t, float alpha, int batchSize) {
const int k = A_t.Shape()[1] / 2;
const int m = B_t.Shape()[1];
const int n = X_t.Shape()[0];
const int d = X_t.Shape()[1];
int D = B_t.Shape()[0] / 2;
batchSize = (batchSize != -1) ? batchSize : GsDBSCAN::findDistanceBatchSize(alpha, n, d, k, m);
auto AFlat_t = matx::flatten(A_t);
auto distances_t = matx::make_tensor<matx::matxFp16>({n, 2*k*m}, matx::MATX_DEVICE_MEMORY);
int j = 0;
std::vector<double> times;
auto start_all = std::chrono::high_resolution_clock::now();
for (int i = 0; i < n; i += batchSize) {
auto start = std::chrono::high_resolution_clock::now();
int maxBatchIdx = i + batchSize - 1; // Index within X along the ROWS
auto XSubset_t_op = matx::slice(X_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd});
auto ABatchFlat_t_op = matx::slice(AFlat_t, {i * 2 * k}, {(maxBatchIdx + 1) * 2 * k});
auto BBatch_t_op = matx::remap<0>(B_t, ABatchFlat_t_op);
auto XBatch_t_op = matx::remap<0>(X_t, matx::flatten(BBatch_t_op));
auto XBatchReshaped_t_op = matx::reshape(XBatch_t_op, {batchSize, 2*k*m, d});
auto XSubsetReshaped_t_op = matx::reshape(XSubset_t_op, {batchSize, 1, d});
auto YBatch_t_op = (XBatchReshaped_t_op - matx::repmat(XSubsetReshaped_t_op, {1, 2*k*m, 1})); // repmat works around subtracting tensors with otherwise-incompatible shapes
auto YBatch_t_norm_op = matx::vector_norm(YBatch_t_op, {2}, matx::NormOrder::L2);
(matx::slice(distances_t, {i, 0}, {maxBatchIdx + 1, matx::matxEnd}) = YBatch_t_norm_op).run();
// Record end time
auto end = std::chrono::high_resolution_clock::now();
// Calculate the duration
std::chrono::duration<double> duration = end - start;
// Cast to double and store in array
times.push_back(duration.count());
}
auto start_sync = std::chrono::high_resolution_clock::now();
cudaDeviceSynchronize();
// Record end time
auto end_sync = std::chrono::high_resolution_clock::now();
// Calculate the duration
std::chrono::duration<double> duration_sync = end_sync - start_sync;
// Output the duration
std::cout << "Time taken: " << duration_sync.count() << " seconds" << std::endl;
for (const auto& element : times) {
std::cout << element << std::endl;
}
// Record end time
auto end_all = std::chrono::high_resolution_clock::now();
// Calculate the duration
std::chrono::duration<double> duration = end_all - start_all;
// Output the duration
std::cout << "Time taken: " << duration.count() << " seconds" << std::endl;
return distances_t;
}
Which produces the output:
Time taken: 14.4069 seconds // For the synchronise
0.00528887
1.775e-05
... A bunch more times around 1.5e-5 (with approx. 300 loop iterations, the entire loop takes around 300 * 1.5e-5 ≈ 0.005 seconds, which is very quick)
1.642e-05
Time taken: 14.4189 seconds // For the overall function call
Has this got something to do with the fact that MatX appears to have an async execution style? I.e. the loop just queues a bunch of async operations on the GPU, and the synchronise then absorbs all the actual work? Just an idea.
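To illustrate what I mean, here's a CPU-only analogy (a single-worker thread pool standing in for an in-order CUDA stream; the timings are illustrative, not a claim about MatX internals):

```python
import time
from concurrent.futures import ThreadPoolExecutor

def work(_):
    time.sleep(0.01)  # stand-in for a kernel

pool = ThreadPoolExecutor(max_workers=1)  # one worker ~ one in-order stream

t0 = time.perf_counter()
futures = [pool.submit(work, i) for i in range(20)]  # "launches" return immediately
enqueue = time.perf_counter() - t0

for f in futures:  # analogous to cudaDeviceSynchronize()
    f.result()
sync = time.perf_counter() - t0 - enqueue
pool.shutdown()

print(f"enqueue {enqueue:.3f}s, sync {sync:.3f}s")  # the sync dominates
```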
Hi, can you please provide a fully buildable/runnable example in both MatX and Python that we can use to compare?
Generally speaking you don't want to include allocation time in your timings as you want to allocate once upfront and reuse.
Alternatively, if you cannot easily create a standalone reproducer, can you share an nsys profile of both the Python and MatX runs with us?
Ok thx, pls see the gist: https://gist.github.com/HugoPhibbs/a2ce2c75b70c6737f1094f32b15af3ea
It contains source files to run it, along with an nsys profile
I recreated your repro as an example and had to make a few modifications to get it to build. Once I did that, I ran it on an H100 and I see this output:
Total Time taken: 0.0242754 seconds
Total Time taken (again): 0.0248659 seconds
70000 500
Unfortunately I was not able to view your profile as it says it is corrupt. Could you get a fresh profile, put it in a zip and upload it to your example?
Hugo I created a repro with some build fixes here: https://github.com/NVIDIA/MatX/tree/688-repro
From your build directory:
%> make repro
%> ./examples/repro
Can you verify that the issue still reproduces?
On an L40 I see similar performance:
Total Time taken: 0.0205341 seconds
Total Time taken (again): 0.0211705 seconds
Ok thanks. Honestly I'm a little bit skeptical that it would take just a fraction of a second. But yes, the issue still reproduces on my machine:
make repro
[ 50%] Building CUDA object examples/CMakeFiles/repro.dir/repro.cu.o
[100%] Linking CUDA executable repro
[100%] Built target repro
./examples/repro
Sync Time taken: 14.3653 seconds
0.00239233
2.32e-05
....
1.776e-05
Total Time taken: 14.3747 seconds
Total Time taken (again): 14.375 seconds
70000 500
Please see this zip for the profile: test_profile.zip - it may have been an encoding issue.
I guess now would be a good time to show you my environment:
nvidia-smi
Wed Aug 7 09:38:27 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.28.03 Driver Version: 560.28.03 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3090 Off | 00000000:01:00.0 On | N/A |
| 53% 45C P5 63W / 390W | 1022MiB / 24576MiB | 47% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA GeForce RTX 3090 Off | 00000000:4A:00.0 On | N/A |
| 0% 48C P8 49W / 390W | 115MiB / 24576MiB | 25% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0
Thank you for the profile. When I inspect the profiles, the MatX-generated kernels seem in line with the hardware (2ms on H100 and 6.5ms on 3090). However, the reduction kernel seems way off. We use CUB for this kernel, so perhaps there is something going wrong in CUB. We will investigate this. As a workaround, can you try to materialize the inputs to the reduction kernel into a memory-backed tensor, then compute the vector norm on the memory-backed tensor: https://github.com/NVIDIA/MatX/blob/688-repro/examples/repro.cu#L96
can you also get me an ncu profile with this command on your system:
ncu --import-source --set full --metrics all --kernel-id "::regex:.*ReduceKernel.*:1" -o 3090 ./examples/repro
Then zip up 3090.ncu-rep and attach that too.
updated ncu instruction above
Also can you try updating your toolkit?
You currently have: Cuda 11.8.
I'd suggest going to 12.5.
@HugoPhibbs I ran this on both an A100 and 3090. Here are the results: A100:
Total Time taken: 0.0451949 seconds
Total Time taken (again): 0.0453394 seconds
3090:
Total Time taken: 0.0272116 seconds
Total Time taken (again): 0.0280414 seconds
This is CUDA 12.5. I will try 11.8 and report back.
@HugoPhibbs on my nsys capture I see 32 registers per thread whereas @luitjens pointed out you had 128. Here is our compilation line:
cd /repro/tmp/MatX/build/examples && /usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DMATX_DISABLE_CUB_CACHE -DMATX_ENABLE_FILEIO -DMATX_ENABLE_PYBIND11 -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_CUDA -DTHRUST_HOST_SYSTEM=THRUST_HOST_SYSTEM_CPP -I/repro/tmp/MatX/include -I/repro/tmp/MatX/include/matx/kernels -I/repro/tmp/MatX/build/_deps/cccl-src/thrust/thrust/cmake/../.. -I/repro/tmp/MatX/build/_deps/cccl-src/libcudacxx/lib/cmake/libcudacxx/../../../include -I/repro/tmp/MatX/build/_deps/cccl-src/cub/cub/cmake/../.. -isystem=/repro/tmp/MatX/build/_deps/pybind11-src/include -isystem=/usr/include/python3.10 -isystem=/usr/local/cuda/include -g --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86] -Wall -Wextra -Wcast-align -Wunused -Wshadow -Wno-unknown-pragmas -Wnon-virtual-dtor -Wconversion -Wmisleading-indentation -Wduplicated-cond -Wduplicated-branches -Wlogical-op -Wnull-dereference -Werror all-warnings --threads 0 -ftemplate-backtrace-limit=0 -lineinfo --expt-relaxed-constexpr -DMATX_ROOT=\"/repro/tmp/MatX\" -fvisibility=hidden -MD -MT examples/CMakeFiles/repro.dir/repro.cu.o -MF CMakeFiles/repro.dir/repro.cu.o.d -x cu -c /repro/tmp/MatX/examples/repro.cu -o CMakeFiles/repro.dir/repro.cu.o
Can you send what yours looks like? If you're importing the matx::matx target it should look similar
@HugoPhibbs I was able to reproduce your issue on CUDA 11.8 with everything else the same:
Total Time taken: 14.6827 seconds
Total Time taken (again): 14.6836 seconds
Is it possible for you to update? This may be an issue where nvcc had trouble with register reuse in this case, causing poor occupancy.
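To give a sense of the occupancy impact: the register file alone caps how many threads can be resident on an SM. A hypothetical back-of-envelope (Ampere SMs have a 64K-entry register file; real occupancy is also limited by the architectural max threads/SM, shared memory, etc.):

```python
# Per-SM thread limit imposed by register usage alone (Ampere: 65536 regs/SM)
REGS_PER_SM = 65536
limits = {r: REGS_PER_SM // r for r in (32, 128)}
for regs, threads in limits.items():
    print(f"{regs:3d} regs/thread -> at most {threads} resident threads/SM")
```

So at 128 registers/thread the register file supports only a quarter of the threads it would at 32 registers/thread, which is consistent with the slowdown being an occupancy problem.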
thx @cliffburdick and @luitjens
@luitjens re ncu, currently waiting for admin permissions to run sudo ncu ..., I'll send results once I can.
On the CUDA-upgrade front: I upgraded to 12.5, and it runs OK, but now my tests are broken 🙃.
Just to make sure: when I do cudaDeviceSynchronize(), this makes sure that any pending operations on the GPU are done, right? When I upgrade to 12.5, the returned distances_t tensor is now just empty (full of zeros), whereas with 11.8 it was full of values.
E.g. my simple tests look a bit like:
auto distances_t = GsDBSCAN::findDistancesMatX(X_t_16, A_t, B_t);
cudaDeviceSynchronize();
matx::matxFp16 *distances_ptr = distances_t.Data();
matx::matxFp16 expected_squared[] = {
11, 5, 14, 11, 0, 5,
9, 0, 11, 0, 14, 11,
5, 0, 5, 5, 8, 14,
9, 5, 0, 0, 9, 5,
9, 6, 5, 5, 0, 6
};
for (int i = 0; i < 5*6; i++) {
ASSERT_NEAR(std::sqrt(expected_squared[i]), distances_ptr[i], 1e-3); // distances is full of zeros with 12.5 but actually full in 11.8
}
Do you guys know a reason why this may be?
First thing I'd do is drop this macro call into your code, specifically after syncs. This will verify that no CUDA errors occurred. https://gist.github.com/jefflarkin/5390993
@luitjens yep added the macro and no errors occur
Ok, can you create a PR to modify the repro branch which asserts that there is an error? Then in the morning I will dig into it.
sure will do
Can you also verify you are not trying to dereference a device pointer on the host?
Ok, have checked, but I don't think that's the case, since I'm using managed memory?
I was getting a seg fault when using device memory, so I changed the memory to managed and it worked (on CUDA 11.8).
Ok yes managed is fine. I will review in the morning.
@HugoPhibbs we're still looking into it, but we can reproduce your issue.
Hi @HugoPhibbs we found a bug in reshape where it was using the wrong type when passed to a reduction. This has been fixed in this PR: https://github.com/NVIDIA/MatX/pull/703
Please either pull the latest after that's tested/merged (in about an hour), or grab the changes directly from that branch.