import tvm
from tvm import te


def matmul(M, K, N, dtype):
    A = te.placeholder((M, K), name="A", dtype=dtype)
    B = te.placeholder((K, N), name="B", dtype=dtype)
    k = te.reduce_axis((0, K), name="k")
    matmul = te.compute(
        (M, N),
        lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
        name="matmul",
        # enable automatic layout transform for tensor B
        attrs={"layout_free_placeholders": [B]},
    )
    out = te.compute((M, N), lambda i, j: matmul[i, j], name="out")
    return [A, B, out]
Because A and B are int8, the output dtype is also int8, so any intermediate result that does not fit in int8 is truncated during the computation.
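To see the truncation concretely, here is a minimal pure-Python illustration (using ctypes to emulate int8 wraparound; the operand values are made up for the example, not taken from the kernel above):

```python
import ctypes

# int8 holds -128..127; an int8 * int8 product easily overflows that range.
a, b = 100, 120
full = a * b                          # 12000, fine in Python's unbounded int
clipped = ctypes.c_int8(full).value   # emulate storing into an int8 accumulator

print(full, clipped)                  # the int8 result has wrapped around
```

With these values the product 12000 wraps around to -32 when forced into int8, which is the "cut off" behavior described above.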
How can I make the out tensor int32?