
I am trying to solve a complicated problem.

For example, I have a batch of 2D predicted images (softmax output, values between 0 and 1) of size Batch x H x W, and the corresponding ground truth of size Batch x H x W.

[Image: light-gray background pixels (0), dark-gray foreground pixels (1), mass center C in red, pixel A in yellow, points B1, B2, B3 in blue]

The light gray pixels are the background with value 0, and the dark gray pixels are the foreground with value 1. I compute the mass-center coordinates with scipy.ndimage.center_of_mass on each ground-truth image, which gives one center location point C (red) per image, so the C point set is Batch x 1.
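That center-of-mass step does not need gradients (it runs on the ground truth only), so a plain loop over the batch is fine; a minimal sketch with made-up masks:

```python
import numpy as np
from scipy.ndimage import center_of_mass

# Hypothetical batch of binary ground-truth masks, shape (Batch, H, W):
gt = np.zeros((2, 9, 9))
gt[0, 2:5, 3:6] = 1  # a 3x3 foreground square
gt[1, 5:8, 1:4] = 1

# One (row, col) center per image -> shape (Batch, 2):
C = np.array([center_of_mass(mask) for mask in gt])
print(C.tolist())  # [[3.0, 4.0], [6.0, 2.0]]
```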

Now, for each pixel A (yellow) in a predicted image, I want to get the three pixels B1, B2, B3 (blue) that are closest to A on the line AC (where C is the corresponding mass center from the ground truth).

I used the following code to get the three closest points B1, B2, B3.

import numpy as np

def connect(ends, m=3):
    # Sample m + 1 evenly spaced integer points between ends[0] and ends[1],
    # stepping exactly along the longer axis and rounding along the shorter one.
    d0, d1 = np.abs(np.diff(ends, axis=0))[0]
    if d0 > d1:
        return np.c_[np.linspace(ends[0, 0], ends[1, 0], m + 1, dtype=np.int32),
                     np.round(np.linspace(ends[0, 1], ends[1, 1], m + 1)).astype(np.int32)]
    else:
        return np.c_[np.round(np.linspace(ends[0, 0], ends[1, 0], m + 1)).astype(np.int32),
                     np.linspace(ends[0, 1], ends[1, 1], m + 1, dtype=np.int32)]
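For instance, with hypothetical endpoints A = (0, 1) and C = (6, 4) (the function is repeated here so the snippet runs standalone), the m + 1 samples are evenly spaced along AC and include both endpoints:

```python
import numpy as np

def connect(ends, m=3):
    # Sample m + 1 evenly spaced integer points between ends[0] and ends[1].
    d0, d1 = np.abs(np.diff(ends, axis=0))[0]
    if d0 > d1:
        return np.c_[np.linspace(ends[0, 0], ends[1, 0], m + 1, dtype=np.int32),
                     np.round(np.linspace(ends[0, 1], ends[1, 1], m + 1)).astype(np.int32)]
    else:
        return np.c_[np.round(np.linspace(ends[0, 0], ends[1, 0], m + 1)).astype(np.int32),
                     np.linspace(ends[0, 1], ends[1, 1], m + 1, dtype=np.int32)]

pts = connect(np.array([[0, 1], [6, 4]]))
print(pts.tolist())  # [[0, 1], [2, 2], [4, 3], [6, 4]]
```

Note that linspace spreads the samples evenly over the whole segment, so for a long AC these are not necessarily the three pixels nearest A; the Bresenham-based approach in the answer walks outward from A instead.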

So the B point set is Batch x 3 x H x W.

Then, I want to compute |Value(A)-Value(B1)| + |Value(A)-Value(B2)| + |Value(A)-Value(B3)| for every pixel. The result should have size Batch x H x W.
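Assuming the B coordinates are already available as integer index maps, that per-pixel sum itself vectorizes with NumPy fancy indexing; a toy sketch (all names and shapes are made up):

```python
import numpy as np

batch, H, W, k = 2, 4, 5, 3
rng = np.random.default_rng(0)
pred = rng.random((batch, H, W))  # stand-in for the softmax output, values in [0, 1]

# Hypothetical B coordinates: for each pixel A, k points on the line AC,
# stored as row/col index maps of shape (batch, k, H, W):
B_rows = rng.integers(0, H, (batch, k, H, W))
B_cols = rng.integers(0, W, (batch, k, H, W))

b_idx = np.arange(batch)[:, None, None, None]  # broadcasts over (k, H, W)
values_B = pred[b_idx, B_rows, B_cols]         # (batch, k, H, W)

# |V(A) - V(B1)| + ... + |V(A) - V(Bk)| for every pixel A at once:
res = np.abs(pred[:, None] - values_B).sum(axis=1)  # (batch, H, W)
print(res.shape)  # (2, 4, 5)
```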

Is there any numpy vectorization trick that can be used to update the value of each pixel in the predicted images? Or can this be solved with PyTorch functions? I need a method that updates the whole image at once: the predicted image is the softmax output, and I cannot compute each value in a for loop since that would make it non-differentiable. Thanks a lot.

N.Z
  • Check this: https://stackoverflow.com/a/47704298/3540982 – Matin H Aug 16 '18 at 09:42
  • Might want to try to create a [mcve], you have a bit too many independent parts in your question to give a good answer. – Daniel F Aug 16 '18 at 10:03
  • @DanielF Hi, I have updated the question. Now it only needs to update the value of pixel A. Can you help solve this? Thank you. – N.Z Aug 16 '18 at 11:12
  • @Matin Thank you Matin. I can compute the B points, but I don't know how to compute |Value(A)-Value(B1)|+|Value(A)-Value(B2)|+|Value(A)-Value(B3)| without using a for loop. – N.Z Aug 16 '18 at 11:39

1 Answer


As suggested by @Matin, you could consider Bresenham's algorithm to get your points on the AC line.

A simplistic PyTorch implementation could be as follows (directly adapted from the pseudo-code here; it could be optimized):

import torch

def get_points_from_low(x0, y0, x1, y1, num_points=3):
    # Bresenham "low-slope" case (|dy| <= |dx|): step along x, adjust y.
    dx = x1 - x0
    dy = y1 - y0
    xi = torch.sign(dx)
    yi = torch.sign(dy)
    dy = dy * yi
    D = 2 * dy - dx

    y = y0
    x = x0

    points = []
    for n in range(num_points):
        x = x + xi
        is_D_gt_0 = (D > 0).long()
        y = y + is_D_gt_0 * yi
        D = D + 2 * dy - is_D_gt_0 * 2 * dx

        points.append(torch.stack((x, y), dim=-1))

    return torch.stack(points, dim=len(x0.shape))

def get_points_from_high(x0, y0, x1, y1, num_points=3):
    # Bresenham "high-slope" case (|dy| > |dx|): step along y, adjust x.
    dx = x1 - x0
    dy = y1 - y0
    xi = torch.sign(dx)
    yi = torch.sign(dy)
    dx = dx * xi
    D = 2 * dx - dy

    y = y0
    x = x0

    points = []
    for n in range(num_points):
        y = y + yi
        is_D_gt_0 = (D > 0).long()
        x = x + is_D_gt_0 * xi
        D = D + 2 * dx - is_D_gt_0 * 2 * dy

        points.append(torch.stack((x, y), dim=-1))

    return torch.stack(points, dim=len(x0.shape))

def get_points_from(x0, y0, x1, y1, num_points=3):
    # Dispatch between the low/high-slope cases, flipping signs so that the
    # coordinates are increasing (Bresenham assumes x0 <= x1 / y0 <= y1).
    is_dy_lt_dx = (torch.abs(y1 - y0) < torch.abs(x1 - x0)).long()
    is_x0_gt_x1 = (x0 > x1).long()
    is_y0_gt_y1 = (y0 > y1).long()

    sign = 1 - 2 * is_x0_gt_x1
    x0_comp, x1_comp, y0_comp, y1_comp = x0 * sign, x1 * sign, y0 * sign, y1 * sign
    points_low = get_points_from_low(x0_comp, y0_comp, x1_comp, y1_comp, num_points=num_points)
    points_low *= sign.view(-1, 1, 1).expand_as(points_low)

    sign = 1 - 2 * is_y0_gt_y1
    x0_comp, x1_comp, y0_comp, y1_comp = x0 * sign, x1 * sign, y0 * sign, y1 * sign
    points_high = get_points_from_high(x0_comp, y0_comp, x1_comp, y1_comp, num_points=num_points)
    points_high *= sign.view(-1, 1, 1).expand_as(points_high)

    is_dy_lt_dx = is_dy_lt_dx.view(-1, 1, 1).expand(-1, num_points, 2)
    points = points_low * is_dy_lt_dx + points_high * (1 - is_dy_lt_dx)

    return points

# Inputs:
# (@todo: extend A to cover all points in maps):
A = torch.LongTensor([[0, 1], [8, 6]])
C = torch.LongTensor([[6, 4], [2, 3]])
num_points = 3

# Getting points between A and C:
# (@todo: what if there's less than `num_points` between A-C?)
Bs = get_points_from(A[:, 0], A[:, 1], C[:, 0], C[:, 1], num_points=num_points)
print(Bs)
# tensor([[[1, 1],
#          [2, 2],
#          [3, 2]],
#         [[7, 6],
#          [6, 5],
#          [5, 5]]])

Once you have your points, you can retrieve their "values" (Value(A), Value(B1), etc.) using torch.index_select() (note that, as of now, this method only accepts 1D indices, so you need to unravel your data). Put together, it looks like the following (extending A from shape (Batch, 2) to (Batch, H, W, 2) is left as an exercise...)

# Inputs:
# (@todo: extend A to cover all points in maps):
A = torch.LongTensor([[0, 1], [8, 6]])
C = torch.LongTensor([[6, 4], [2, 3]])
batch_size = A.shape[0]
num_points = 3
map_size = (9, 9)
map_num_elements = map_size[0] * map_size[1]
map_values = torch.stack((torch.arange(0, map_num_elements).view(*map_size),
                          torch.arange(0, -map_num_elements, -1).view(*map_size)))

# Getting points between A and C:
# (@todo: what if there's less than `num_points` between A-C?)
Bs = get_points_from(A[:, 0], A[:, 1], C[:, 0], C[:, 1], num_points=num_points)

# Get map values at positions A:
A_unravel = torch.arange(0, batch_size) * map_num_elements
A_unravel = A_unravel + A[:, 0] * map_size[1] + A[:, 1]
values_A = torch.index_select(map_values.view(-1), dim=0, index=A_unravel)
print(values_A)
# tensor([  1, -78])

# Get map values at positions B:
Bs_flatten = Bs.view(-1, 2)
Bs_unravel = (torch.arange(0, batch_size)
              .unsqueeze(1)
              .repeat(1, num_points)
              .view(num_points * batch_size) * map_num_elements)
Bs_unravel = Bs_unravel + Bs_flatten[:, 0] * map_size[1] + Bs_flatten[:, 1]
values_B = torch.index_select(map_values.view(-1), dim=0, index=Bs_unravel)
values_B = values_B.view(batch_size, num_points)
print(values_B)
# tensor([[ 10,  20,  29],
#         [-69, -59, -50]])

# Compute result:
res = torch.abs(values_A.unsqueeze(-1).expand_as(values_B) - values_B)
print(res)
# tensor([[ 9, 19, 28],
#         [ 9, 19, 28]])
res = torch.sum(res, dim=1)
print(res)
# tensor([56, 56])
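Since index_select/gather only read from the tensor (no in-place writes), gradients flow through the gathered values, which addresses the differentiability concern; a quick check on a float map with requires_grad (shapes and indices are made up):

```python
import torch

pred = torch.rand(2, 9, 9, requires_grad=True)  # stand-in for the softmax output
flat = pred.view(2, -1)                         # unravel to 1D per image

# Hypothetical flattened B indices, 3 per image:
idx = torch.tensor([[10, 20, 29], [12, 21, 30]])
values_B = torch.gather(flat, 1, idx)           # (2, 3), still differentiable

loss = values_B.sum()
loss.backward()
print(pred.grad.abs().sum())  # tensor(6.): one unit of gradient per selected pixel
```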
benjaminplanche
  • Hi Aldream, thank you for the answer. When I get the coordinates of the three points for each pixel A, do you know how I can compute |Value(A)-Value(B1)|+|Value(A)-Value(B2)|+|Value(A)-Value(B3)|? The values of each pixel A and its corresponding three B pixels come from the same softmax output, which is a PyTorch Variable with requires_grad=True, so I need to compute this over the whole image. – N.Z Aug 16 '18 at 11:30
  • I updated my answer to provide further directions, though your problem is maybe a bit too vast to be explored in a single post... – benjaminplanche Aug 16 '18 at 12:52
  • I found a bug. If the coordinates of A are greater than those of C, the returned results should be reversed. For example, A is [8, 6] and C is [2, 3]. Can you help fix this? Thank you. – N.Z Aug 17 '18 at 14:17
  • Indeed, Bresenham's algorithm switches A and C in such cases, not really caring for the points order / direction. I've corrected the suggested method. It works now e.g. with `A = [8,6]` and `C = [2, 3]` (see answer), though I didn't check all other cases... – benjaminplanche Aug 17 '18 at 15:12
  • The code above already works for `N` pixels (`A` of shape `(N, 2)`). You could use `arange`/`meshgrid` functions to obtain all your image `A`s (shape `(H, W, 2)`) then flatten the tensor so that `N = H * W`. – benjaminplanche Aug 18 '18 at 17:58
  • Yes, I used arange and meshgrid the single image as (H, W, 2). However, there is another dimension, the batch size. The A pixels should be (Batch, H, W, 2). – N.Z Aug 18 '18 at 19:48
  • Assuming the `A`s are all the same for each image in the batch, this would mean just tiling/expanding your tensor to reach `(B, H, W, 2)`, before flattening to `(B*H*W, 2)` for the points computation. Similarly, assuming you already have `C` of shape `(B, 2)`, you could just expand it to `(B, H, W, 2)` before flattening to the same shape as `A`. – benjaminplanche Aug 19 '18 at 10:46
  • I finished it. Thank you! – N.Z Aug 22 '18 at 08:12
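The tiling described in the last comments could look roughly like this sketch (shapes only, no line computation; all sizes are made up, and a recent PyTorch with the meshgrid `indexing` argument is assumed):

```python
import torch

B, H, W = 2, 4, 5

# All pixel coordinates of one image, as (row, col) pairs:
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
A = torch.stack((ys, xs), dim=-1)                     # (H, W, 2)
A = A.unsqueeze(0).expand(B, H, W, 2).reshape(-1, 2)  # (B*H*W, 2)

# One center per image, expanded to match A pixel-for-pixel:
C = torch.tensor([[1, 2], [3, 0]])                    # (B, 2), made-up centers
C = C.view(B, 1, 1, 2).expand(B, H, W, 2).reshape(-1, 2)

print(A.shape, C.shape)  # torch.Size([40, 2]) torch.Size([40, 2])
```

Both flattened tensors can then be fed to get_points_from with N = B*H*W.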