I have an array a whose size is torch.Size([330, 330, 36])

The original data structure is [330, 6, 330, 6]

The meaning is:

The system has 330 atoms, and each atom has 6 orbitals.

I want to know the interactions between all atoms and all orbitals.

I want to perform these operations:

(1) b = a.reshape(330, 330, 6, 6).permute(0, 2, 1, 3).reshape(1980, 1980)

to convert the tensor to a (330 x 6) x (330 x 6) matrix

(2) torch.sum(torch.diag(b @ b)[1:6])

to perform a matmul and sum the diagonal elements 1 through 5

Is there a way to perform the matmul without reshaping the 330 x 330 x 36 tensor?

Thanks a lot.

What if I have a list of such matrices? Can the matmul operations be done in a single command?

  • the key with `A@B` is that the last dim of A sums with the 2nd to last of B. Lead dimensions (of 3+) are 'batched'. `matmul` should be clear about this. You may have to `transpose` or `reshape`. – hpaulj Aug 28 '23 at 21:16
  • `diag(b@b)[1:6]` throws out a lot of values – hpaulj Aug 29 '23 at 00:20
  • For an arbitrary order of multiplications/summations, I'd recommend [`torch.einsum`](https://pytorch.org/docs/stable/generated/torch.einsum.html). This function will automatically determine the most optimal contraction (not all contractions are equal), and once you get used to the notation it is much more intuitive/readable than reshaping and `torch.tensordot` – MPA Aug 29 '23 at 06:09
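
The `torch.einsum` suggestion in the comment above can be sketched as follows. This is only an illustration on a hypothetical small system (4 atoms, 3 orbitals instead of 330 and 6); the index letters and variable names are my own choices, not from the question.

```python
import torch

# Hypothetical small system: 4 atoms with 3 orbitals each (the question uses 330 and 6).
n_atoms, n_orb = 4, 3
torch.manual_seed(0)
a = torch.randn(n_atoms, n_atoms, n_orb * n_orb, dtype=torch.float64)

# Reference: the reshape/permute/matmul route from the question.
m = a.reshape(n_atoms, n_atoms, n_orb, n_orb).permute(0, 2, 1, 3)
m = m.reshape(n_atoms * n_orb, n_atoms * n_orb)
reference = m @ m

# einsum route: i, j, k label atoms; p, q, r label orbitals.
# The repeated letters j and q are summed over.
b = a.reshape(n_atoms, n_atoms, n_orb, n_orb)
product = torch.einsum("ijpq,jkqr->ipkr", b, b)  # [n_atoms, n_orb, n_atoms, n_orb]

# Flatten (atom, orbital) pairs to compare against the plain matrix product.
einsum_matches = torch.allclose(
    product.reshape(n_atoms * n_orb, n_atoms * n_orb), reference
)
print(einsum_matches)
```

The split 36 -> 6 * 6 is still needed, but the permute and the merge into a 1980 x 1980 matrix are expressed by the index string instead.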

1 Answer

You asked for a couple of things, and what you are doing is inefficient.

Matmul without reshape

As I will explain below, you should not do this contraction. But assume you want to. You cannot avoid the reshape that "splits" the axis 36 -> 6 * 6, but you can avoid combining the axes 6 * 330 -> 1980 by using torch.tensordot. In your case that would be

b = a.reshape(330, 330, 6, 6)
c = torch.tensordot(b, b, ([1, 3], [0, 2]))  # shape [330, 6, 330, 6]
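
As a sanity check, here is a small sketch that compares the tensordot result with the reshape-based sum from the question. The sizes (4 atoms, 3 orbitals) and variable names are assumptions chosen to keep the example fast.

```python
import torch

# Hypothetical small system; the question uses 330 atoms and 6 orbitals.
n_atoms, n_orb = 4, 3
dim = n_atoms * n_orb
torch.manual_seed(0)
a = torch.randn(n_atoms, n_atoms, n_orb * n_orb, dtype=torch.float64)

# tensordot route: contract (atom, orbital) axes [1, 3] of the first
# operand with axes [0, 2] of the second.
b = a.reshape(n_atoms, n_atoms, n_orb, n_orb)
c = torch.tensordot(b, b, dims=([1, 3], [0, 2]))  # [n_atoms, n_orb, n_atoms, n_orb]

# Question's route: build the full square matrix, then multiply.
full = b.permute(0, 2, 1, 3).reshape(dim, dim)
question_sum = torch.sum(torch.diag(full @ full)[1:6])

# c is already ordered (atom, orbital, atom, orbital), so flattening it
# gives the same matrix product.
tensordot_sum = torch.sum(torch.diag(c.reshape(dim, dim))[1:6])

tensordot_matches = torch.allclose(question_sum, tensordot_sum)
print(tensordot_matches)
```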

List of matrices

If it is a list of torch.Tensors, you cannot get around doing a loop of some kind, so no, there is no "one command" solution. If you have a single Tensor, created e.g. via torch.tensor, say of shape a_s.shape == (42, 330, 330, 36) for 42 different "matrices", you can batch the torch operations. Note that torch.tensordot has no notion of a batch axis (it would pair every matrix with every other one), so torch.einsum with a shared batch index is the better fit here:

bs = a_s.reshape(42, 330, 330, 6, 6)
cs = torch.einsum('nijpq,njkqr->nipkr', bs, bs)  # shape [42, 330, 6, 330, 6]
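
A minimal sketch of the batched case, assuming all tensors in the list share one shape so that torch.stack can combine them (this copies the data). It uses torch.einsum with an explicit batch index n and checks the result against a plain per-matrix loop; the sizes and names are hypothetical.

```python
import torch

# Hypothetical sizes: 5 "matrices", 4 atoms, 3 orbitals.
n_batch, n_atoms, n_orb = 5, 4, 3
torch.manual_seed(0)
matrices = [
    torch.randn(n_atoms, n_atoms, n_orb * n_orb, dtype=torch.float64)
    for _ in range(n_batch)
]

# Combine the list into one tensor with a leading batch axis.
stacked = torch.stack(matrices)  # [n_batch, n_atoms, n_atoms, n_orb * n_orb]
bs = stacked.reshape(n_batch, n_atoms, n_atoms, n_orb, n_orb)

# n appears in both operands and in the output, so it is kept as a
# batch axis; j and q are summed over.
cs = torch.einsum("nijpq,njkqr->nipkr", bs, bs)

# Reference: contract each matrix on its own in a Python loop.
looped = torch.stack(
    [torch.tensordot(b, b, dims=([1, 3], [0, 2])) for b in bs]
)

batched_matches = torch.allclose(cs, looped)
print(batched_matches)
```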

More efficient way to compute what you are after

It seems that you are only interested in a few diagonal entries of the matrix product. In your case that is only 5 of the 1980 * 1980 total entries. So you should compute only those entries; computing the other roughly 4,000,000 entries is not needed. For example

b = a.reshape(330, 330, 6, 6)
# transpose aligns the orbital axes of the second factor with the first
c = torch.sum(b[0, :, 1:6, :] * b[:, 0, :, 1:6].transpose(1, 2))

should give the same as you got in your snippets above. Note that due to C-style reshaping, your index 1:6 becomes atom index 0 and orbital index 1:6, e.g.

after_reshape = before_reshape.reshape(330, 6)
before_reshape[1:6] == after_reshape[0, 1:6]
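
A quick check, on a hypothetical smaller system (4 atoms, 6 orbitals), that contracting only the needed slices reproduces the diagonal sum over the [1:6] slice used in the question. The variable names are mine; einsum is used here so the axis pairing is explicit.

```python
import torch

# Hypothetical system: 4 atoms, 6 orbitals (the question uses 330 atoms).
n_atoms, n_orb = 4, 6
dim = n_atoms * n_orb
torch.manual_seed(0)
a = torch.randn(n_atoms, n_atoms, n_orb * n_orb, dtype=torch.float64)
b = a.reshape(n_atoms, n_atoms, n_orb, n_orb)

# Full route from the question: every entry of the product is computed.
full = b.permute(0, 2, 1, 3).reshape(dim, dim)
full_sum = torch.sum(torch.diag(full @ full)[1:6])

# Flat diagonal indices 1..5 correspond to atom 0, orbitals 1..5.
flat = torch.arange(dim)
index_map_ok = torch.equal(flat[1:6], flat.reshape(n_atoms, n_orb)[0, 1:6])

# Slice-only route: just the row and column blocks that touch atom 0.
left = b[0, :, 1:6, :]   # axes (j, p, q)
right = b[:, 0, :, 1:6]  # axes (j, q, p)
sliced_sum = torch.einsum("jpq,jqp->", left, right)

sums_match = torch.allclose(full_sum, sliced_sum)
print(index_map_ok, sums_match)
```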
Jakob Unfried
  • Under the covers the `numpy` `tensordot` uses reshape and transpose to reduce the calculation to a single `np.dot` call (with 2d arrays) – hpaulj Aug 29 '23 at 14:27
  • @hpaulj True. That is also true for torch. However, avoiding the explicit reshape is (a) what OP asked for and (b) IMO better style since it is more readable and easier to maintain – Jakob Unfried Aug 30 '23 at 11:25