
I have a 6-dimensional all-zero PyTorch tensor `lrel_w` that I want to fill with 1s at every position where the indices of the first three dimensions match the indices of the last three dimensions. I'm currently solving this trivially with three nested for loops:

import torch

lrel_w = torch.zeros(
  input_size[0], input_size[1], input_size[2],
  input_size[0], input_size[1], input_size[2]
)
# Set lrel_w[c, x, y, c, x, y] = 1 for every index triple (c, x, y)
for c in range(input_size[0]):
  for x in range(input_size[1]):
    for y in range(input_size[2]):
      lrel_w[c,x,y,c,x,y] = 1

I'm sure there must be a more efficient way of doing this, but I have not been able to figure it out.

ymerkli
  • Without knowing why you want to construct this, and for which sizes (e.g., what is `input_size` typically?), it is difficult to suggest a more efficient alternative. – amdex Dec 04 '20 at 14:42
  • It's a weight matrix that stores a set of weights for each pixel of an MNIST image. Each weight set corresponds to a set of neurons in a convolutional layer. So a typical shape of `lrel_w` is, for example, `[1,28,28,16,28,28]`. – ymerkli Dec 04 '20 at 15:18

1 Answer


You can try this:

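import torch

c, m, n = input_size[0], input_size[1], input_size[2]

t = torch.zeros(c, m, n, c, m, n)
# Every (i, j, k) index triple; indexing="ij" matches the default and avoids
# the warning newer PyTorch versions emit when it is omitted
i, j, k = torch.meshgrid(torch.arange(c), torch.arange(m), torch.arange(n), indexing="ij")
i = i.flatten()
j = j.flatten()
k = k.flatten()

# One advanced-indexing assignment sets all c*m*n diagonal entries at once
t[i, j, k, i, j, k] = 1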
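To check that the vectorized version fills exactly the same entries as the original loops, here is a small sketch that wraps both in functions and compares them. The function names and the sizes `(2, 3, 4)` are hypothetical, chosen only for a quick test.

```python
import torch

def identity_6d_loops(c, m, n):
    # Reference: the three nested loops from the question
    t = torch.zeros(c, m, n, c, m, n)
    for a in range(c):
        for x in range(m):
            for y in range(n):
                t[a, x, y, a, x, y] = 1
    return t

def identity_6d_meshgrid(c, m, n):
    # Vectorized version from the answer
    t = torch.zeros(c, m, n, c, m, n)
    i, j, k = torch.meshgrid(
        torch.arange(c), torch.arange(m), torch.arange(n), indexing="ij"
    )
    i, j, k = i.flatten(), j.flatten(), k.flatten()
    t[i, j, k, i, j, k] = 1
    return t

a = identity_6d_loops(2, 3, 4)
b = identity_6d_meshgrid(2, 3, 4)
print(torch.equal(a, b))  # True
print(int(b.sum()))       # 24, i.e. one 1 per (i, j, k) triple
```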

See the `torch.meshgrid` documentation for how meshgrid works, in case you need a reference.
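As a minimal illustration with two small ranges: with `indexing="ij"`, `meshgrid` repeats each input along the other axes, so flattening the outputs and zipping them enumerates every index combination in row-major order. That is what lets `t[i, j, k, i, j, k] = 1` touch each diagonal position exactly once.

```python
import torch

i, j = torch.meshgrid(torch.arange(2), torch.arange(3), indexing="ij")
print(i)
# tensor([[0, 0, 0],
#         [1, 1, 1]])
print(j)
# tensor([[0, 1, 2],
#         [0, 1, 2]])

# Flattened and zipped, the two grids list every (i, j) pair
print(list(zip(i.flatten().tolist(), j.flatten().tolist())))
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```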

swag2198
  • This is about twice as fast according to my timings. Nice! – amdex Dec 04 '20 at 15:03
  • Thanks, this works great! For the typical tensor shapes of my problem (`[1,28,28,16,28,28]`), this provides roughly a 150x speed-up on my machine compared to the three nested for loops. – ymerkli Dec 04 '20 at 15:22
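A different technique, not suggested in the thread: if you view the 6-D tensor as a `(c*m*n) x (c*m*n)` matrix, where each row and column index is the row-major flattening of a triple `(c, x, y)`, the target is simply the identity matrix. So under that view you can build it with `torch.eye` and a `reshape`, with no index tensors at all. The function name below is hypothetical, and the sketch cross-checks the result against the loop definition on a small shape.

```python
import itertools
import torch

def identity_6d_eye(c, m, n):
    # Row-major flattening gives each (c, x, y) a unique position in [0, c*m*n),
    # so "first triple == last triple" is exactly the identity matrix's diagonal.
    size = c * m * n
    return torch.eye(size).reshape(c, m, n, c, m, n)

t = identity_6d_eye(2, 3, 4)
print(t.shape)       # torch.Size([2, 3, 4, 2, 3, 4])
print(int(t.sum()))  # 24

# Cross-check against the loop-based definition from the question
ref = torch.zeros(2, 3, 4, 2, 3, 4)
for a, x, y in itertools.product(range(2), range(3), range(4)):
    ref[a, x, y, a, x, y] = 1
print(torch.equal(t, ref))  # True
```

Like the other versions, this still allocates the full dense tensor. For the `[1,28,28,16,28,28]` shape that is 784 × 12,544 ≈ 9.8 million float32 values, about 39 MB. If the two halves differ only in their leading dimension, as in that shape, `torch.eye` also accepts a rectangular size (`torch.eye(1 * 28 * 28, 16 * 28 * 28).reshape(1, 28, 28, 16, 28, 28)`). This places a 1 wherever the two triples are equal, assuming that is the intended meaning when the shapes differ.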