CTranslate2
GEMM operator calculates the `c` output shape incorrectly when input `a` is transposed?
Hey everyone!
I believe I have found a bug in the GEMM operator.
To the best of my knowledge, the output shape of the `c` StorageView in the GEMM operator should always be `{m, n}`.
However, `Gemm::compute()` calculates the shape of `c` incorrectly if `a` is transposed (that is, `a` has shape `{k, m}` instead of `{m, k}`).
Take a look at this piece of code: https://github.com/OpenNMT/CTranslate2/blob/4f8a4f334c59588223b6f1f24b707d7e8d5fe08c/src/ops/gemm.cc#L84-L88
- `Shape output_shape(a.shape());`: sets `output_shape` to `{k, m}`
- `output_shape[output_shape.size() - 1] = n;`: sets `output_shape` to `{k, n}`
- `c.resize(std::move(output_shape));`: resizes `c` to `{k, n}`
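To make the concern concrete, here is a minimal standalone sketch of the shape arithmetic. It does not use CTranslate2 itself: `Shape` is assumed to behave like a vector of dimensions, and both function names are hypothetical. The first function mirrors the three linked lines, and the second shows one possible transpose-aware variant, not necessarily how the library should fix it.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Stand-in for CTranslate2's Shape type (assumption: a vector of dimensions).
using Shape = std::vector<std::size_t>;

// Mirrors the linked lines: copy a's shape, then overwrite the last dimension with n.
Shape output_shape_as_linked(const Shape& a_shape, std::size_t n) {
  assert(!a_shape.empty());
  Shape output_shape(a_shape);
  output_shape[output_shape.size() - 1] = n;
  return output_shape;
}

// Hypothetical alternative: if a is transposed, swap its last two dimensions
// before overwriting the last one, so the result ends in {m, n}.
Shape output_shape_transpose_aware(const Shape& a_shape, bool transpose_a, std::size_t n) {
  assert(a_shape.size() >= 2);
  Shape output_shape(a_shape);
  const std::size_t rank = output_shape.size();
  if (transpose_a)
    std::swap(output_shape[rank - 2], output_shape[rank - 1]);
  output_shape[rank - 1] = n;
  return output_shape;
}
```

With m = 2, k = 3, n = 4 and a transposed `a` of shape `{3, 2}`, the first function returns `{3, 4}` (that is, `{k, n}`), while the second returns `{2, 4}` (that is, `{m, n}`). For a non-transposed `a` of shape `{2, 3}`, both return `{2, 4}`.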
Am I missing something here?
Bear in mind that there are currently no unit tests that would catch this.