
I need to insert elements of tensor new into a tensor old with a certain probability, let's say 0.8 for simplicity. Essentially this is what masked_fill would do, but masked_fill only fills with a single scalar value, not with the corresponding elements of another tensor. Currently I am doing

    import torch

    prob = torch.rand(trgs.shape, dtype=torch.float32, device=trgs.device)
    mask = prob < 0.8  # True with probability 0.8

    dim1, dim2, dim3, dim4 = new.shape
    for a in range(dim1):
        for b in range(dim2):
            for c in range(dim3):
                for d in range(dim4):
                    # keep old where mask is True, otherwise take new
                    old[a, b, c, d] = old[a, b, c, d] if mask[a, b, c, d] else new[a, b, c, d]

which works, but is awful and very slow because it loops over every element in Python. I would like something like

    prob = torch.rand(trgs.shape, dtype=torch.float32).to(trgs.device)
    mask = prob < 0.8

    old = trgs.multidimensional_masked_fill(mask, new)
Stefano Berti

1 Answer


I am not sure what some of your objects are, but this should do what you need in a single vectorized call:

old is your existing data.

mask is the boolean mask you generated with probability p.

new is the tensor whose elements you want to insert.

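    import torch

    # torch.where: keep old where mask is True, take new where mask is False
    result = old.where(mask, new)  # same as torch.where(mask, old, new)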
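A self-contained sketch to check that the one-liner produces the same result as the nested loop from the question. The tensors, shapes, and seed below are hypothetical stand-ins (the question does not show the real ones); old is all zeros and new is all ones so the result is easy to inspect.

```python
import torch

# Hypothetical stand-ins for the question's tensors (shapes chosen arbitrarily).
torch.manual_seed(0)
old = torch.zeros(2, 3, 4, 5)
new = torch.ones(2, 3, 4, 5)

# Boolean mask that is True with probability ~0.8 per element.
mask = torch.rand(old.shape, device=old.device) < 0.8

# Vectorized: keep old where mask is True, take new where it is False.
result = old.where(mask, new)
assert torch.equal(result, torch.where(mask, old, new))

# Reference: the question's element-wise loop, applied to a copy of old.
looped = old.clone()
dim1, dim2, dim3, dim4 = new.shape
for a in range(dim1):
    for b in range(dim2):
        for c in range(dim3):
            for d in range(dim4):
                looped[a, b, c, d] = old[a, b, c, d] if mask[a, b, c, d] else new[a, b, c, d]
assert torch.equal(result, looped)

# Inverted mask: now new is inserted where mask is True (~80% of elements).
inserted = old.where(~mask, new)
print(result.mean().item(), inserted.mean().item())
```

Note that with mask = prob < 0.8, old is kept about 80% of the time, so new is only inserted about 20% of the time. If new should be inserted with probability 0.8, pass ~mask (or build the mask with prob >= 0.8) as in the last lines above.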
John Stud