
I would like to take the mean along an axis of a tensor over several slices, where the slices are defined by another tensor.

This is my sample tensor; I want the mean of each slice along the first dimension:

import torch

sample = torch.arange(0,40).reshape(10,-1)
sample
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23],
        [24, 25, 26, 27],
        [28, 29, 30, 31],
        [32, 33, 34, 35],
        [36, 37, 38, 39]])

And this is the tensor that contains the start and end indices of the slices whose means I want:

mean_slices = torch.tensor([
    [3, 5],
    [1, 8],
    [4, 8],
    [6, 9],
])
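For reference, a plain Python loop gives the result any vectorized version should reproduce. Judging from the output shown further down, each pair is treated as a half-open range [start, end), so row `end` itself is excluded; that reading is an assumption, and `reference_means` is a hypothetical helper name.

```python
import torch

sample = torch.arange(0, 40).reshape(10, -1)
mean_slices = torch.tensor([[3, 5], [1, 8], [4, 8], [6, 9]])

def reference_means(x, slices):
    # Hypothetical helper: mean of rows start..end-1 for every (start, end) pair.
    # One small reduction per slice, so only rows inside each slice are read.
    rows = [x[s:e].float().mean(dim=0) for s, e in slices.tolist()]
    return torch.stack(rows)

print(reference_means(sample, mean_slices))
```

This is slow for many slices because of the Python loop, but it is handy for checking faster variants.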

Since PyTorch doesn't have ragged tensors, I use a trick described here:

https://stackoverflow.com/a/71337491/3259896

The cumulative sum is computed along the entire axis I want the mean over, with a leading row of zeros prepended. Then, for each slice, the cumulative-sum row at the end index is retrieved, and the cumulative-sum row just before the start index is subtracted from it. Finally, the result is divided by the length of the slice.
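The same idea written as a formula, assuming each pair is a half-open range $[s, e)$ (which matches the output shown below) and writing $P$ for the zero-padded cumulative sum, i.e. the `padded` tensor, over rows $x_0, \dots, x_{N-1}$:

```latex
P_0 = 0, \qquad P_k = \sum_{i=0}^{k-1} x_i \quad (k = 1, \dots, N)

\sum_{i=s}^{e-1} x_i = P_e - P_s,
\qquad
\operatorname{mean}_{[s,e)}(x) = \frac{P_e - P_s}{e - s}
```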

padded = torch.nn.functional.pad(
    sample.cumsum(dim=0), (0, 0, 1, 0)
)
padded
tensor([[  0,   0,   0,   0],
        [  0,   1,   2,   3],
        [  4,   6,   8,  10],
        [ 12,  15,  18,  21],
        [ 24,  28,  32,  36],
        [ 40,  45,  50,  55],
        [ 60,  66,  72,  78],
        [ 84,  91,  98, 105],
        [112, 120, 128, 136],
        [144, 153, 162, 171],
        [180, 190, 200, 210]])
pools = torch.diff(
    padded[mean_slices], dim=1
).squeeze(1)/torch.diff(mean_slices, dim=1)

pools
tensor([[14., 15., 16., 17.],
        [16., 17., 18., 19.],
        [22., 23., 24., 25.],
        [28., 29., 30., 31.]])

The only issue with this solution is that I originally wanted the mean of only the rows defined by the slices. My current solution gives that result, but the computation also involves all rows before each slice's start index, so I suspect the backward pass may not work as intended.

Is this guess correct?

Is there a more exact and computationally efficient way to calculate the mean for the slices defined in a tensor?

SantoshGupta7

1 Answer


Why do you think that the gradient calculation includes values of rows outside the slices?

When you compute the sum over a slice using torch.cumsum, every value before the slice is summed twice: once into the row just before the slice (padded[start]), and a second time, together with the slice itself, into the last row of the slice (padded[end]). Values after the slice never enter either row. The most important thing is that you subtract the row before the slice from the last row of the slice: that is, you eliminate the sum of all values before the slice from the equation. Thus these values have no effect on the result or on the gradients.
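To make the cancellation explicit: with the padded prefix sums $P_k = \sum_{i<k} x_i$ from the question (so $P_0 = 0$) and a half-open slice $[s, e)$, the derivative of one slice mean with respect to row $x_j$ is shown below. Rows before the slice appear in both $P_e$ and $P_s$, so the two indicator terms cancel for them.

```latex
\frac{\partial}{\partial x_j}\,\frac{P_e - P_s}{e - s}
  = \frac{\mathbb{1}[j < e] - \mathbb{1}[j < s]}{e - s}
  = \begin{cases}
      \dfrac{1}{e - s} & \text{if } s \le j < e,\\[4pt]
      0 & \text{otherwise.}
    \end{cases}
```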

Here's a simple example:
Consider the function f(x,y,z) = x + y + z - z. What is the gradient of f w.r.t. z? Once z is eliminated, it has no effect on the value of f or on its gradient.
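The same toy example can be checked with autograd. A minimal sketch; the concrete values of x, y and z are arbitrary:

```python
import torch

# Scalar leaves that require gradients (values chosen arbitrarily)
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = torch.tensor(3.0, requires_grad=True)

f = x + y + z - z   # z is added and then subtracted again
f.backward()

# z.grad is zero: the +1 and -1 contributions to it cancel
print(x.grad, y.grad, z.grad)
```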

Bottom line: your backward pass is correct, and is not affected by values outside the slices.
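A quick empirical check of this, as a sketch: backpropagate the sum of the slice means once through the cumsum trick and once through a direct per-slice mean, then compare the gradients. It assumes a float copy of `sample` (integer tensors cannot require gradients) and the half-open slice convention; `cumsum_means` and `direct_means` are hypothetical helper names.

```python
import torch
import torch.nn.functional as F

def cumsum_means(x, slices):
    # Prefix sums with a leading zero row, so padded[k] == x[:k].sum(0)
    padded = F.pad(x.cumsum(dim=0), (0, 0, 1, 0))
    sums = padded[slices[:, 1]] - padded[slices[:, 0]]      # (K, D)
    lengths = (slices[:, 1] - slices[:, 0]).unsqueeze(1)     # (K, 1)
    return sums / lengths

def direct_means(x, slices):
    # Reads only rows start..end-1 of each slice
    return torch.stack([x[s:e].mean(dim=0) for s, e in slices.tolist()])

mean_slices = torch.tensor([[3, 5], [1, 8], [4, 8], [6, 9]])

# Float copies of sample that track gradients
x1 = torch.arange(0, 40, dtype=torch.float32).reshape(10, -1).requires_grad_()
x2 = x1.detach().clone().requires_grad_()

cumsum_means(x1, mean_slices).sum().backward()
direct_means(x2, mean_slices).sum().backward()

# The gradients agree up to floating-point rounding
print(torch.allclose(x1.grad, x2.grad, atol=1e-6))
# Rows 0 and 9 lie outside every slice: their gradient is zero up to rounding
print(x1.grad[[0, 9]].abs().max())
```

The cancellation is exact mathematically; in floating point the +1/L and -1/L terms are accumulated in some order, so a rounding-level residue is possible, which is why the comparison uses a small tolerance.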


Regarding a more efficient implementation:
If the smallest start index is large, or the largest end index is well below the number of rows (that is, a large portion of sample is ignored by all slices), you can drop that portion completely:

mn, mx = mean_slices.min(), mean_slices.max()  # only the relevant part of sample
padded_ef = torch.nn.functional.pad(
    sample[mn:mx, :].cumsum(dim=0), (0, 0, 1, 0)
)
# sum the slices; the indices must be shifted by mn
pools_ef = torch.diff(
    padded_ef[mean_slices-mn], dim=1
).squeeze(1)/torch.diff(mean_slices, dim=1)

This gives the same pools, but potentially involves fewer elements of sample if the slices are "packed" into a narrow range of rows.
However, unless sample is very large relative to the range of rows covered by the slices, I don't believe this will give you a significant boost in run time.
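If the goal is to read only the rows that actually fall inside some slice, one alternative (not part of the answer above, and only a sketch) is to gather those rows explicitly and scatter-add them into per-slice sums with `torch.repeat_interleave` and `Tensor.index_add_`. It avoids subtracting large prefix sums, at the cost of materializing one row per (slice, row) pair, so rows shared by overlapping slices are copied several times. `slice_means_gather` is a hypothetical name, and the half-open [start, end) convention is assumed.

```python
import torch

def slice_means_gather(x, slices):
    starts, ends = slices[:, 0], slices[:, 1]
    lengths = ends - starts                                  # rows per slice, (K,)
    # Slice id for every gathered row, e.g. [0, 0, 1, 1, 1, ...]
    seg = torch.repeat_interleave(torch.arange(len(lengths)), lengths)
    # Position where each slice begins in the gathered order (exclusive prefix sum)
    first = lengths.cumsum(0) - lengths
    # Offset of each gathered row inside its own slice: 0, 1, ..., length-1
    offsets = torch.arange(seg.numel()) - first[seg]
    rows = starts[seg] + offsets                             # row indices into x
    # Sum the gathered rows per slice, then divide by the slice lengths
    sums = torch.zeros(len(lengths), x.shape[1], dtype=torch.float32)
    sums.index_add_(0, seg, x[rows].to(torch.float32))
    return sums / lengths.unsqueeze(1)

sample = torch.arange(0, 40).reshape(10, -1)
mean_slices = torch.tensor([[3, 5], [1, 8], [4, 8], [6, 9]])
print(slice_means_gather(sample, mean_slices))
```

Whether this is faster than the cumsum trick depends on how many slices there are and how much they overlap, so it is worth timing on realistic data before switching.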

Shai