Dimension shift
Here's a puzzle posed by Francois Fleuret:
Probably dumb @PyTorch question : how to [elegantly] write?
x = x.permute(0, ..., d-1, d+1, ..., x.dim()-1, d)
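Spelled out, the permutation in the tweet lists every axis except d in its original order and then appends d. The sketch below builds that axis order in plain Python. The helper name shifted_order is hypothetical and not part of PyTorch, and it assumes d is a non-negative axis index, as in the tweet.

```python
from typing import List


def shifted_order(ndim: int, d: int) -> List[int]:
    """Hypothetical helper: the order (0, ..., d-1, d+1, ..., ndim-1, d)."""
    if not 0 <= d < ndim:
        raise ValueError(f"d must be in [0, {ndim}), got {d}")
    # Every axis except d, in its original order...
    rest = [i for i in range(ndim) if i != d]
    # ...followed by d itself.
    return rest + [d]


print(shifted_order(5, 2))  # [0, 1, 3, 4, 2]
print(shifted_order(5, 4))  # [0, 1, 2, 3, 4]
```

The resulting list can be passed directly to x.permute. The puzzle asks whether there is a more elegant way to express this.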
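In other words, write a function shift(x, d) that moves axis d to the end while keeping the other axes in their original order. For example, given the x below, both of the final checks should return True: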
x = torch.tensor(
[[[[[1, 1], [1, 1]], [[2, 2], [2, 2]], [[3, 3], [3, 3]], [[4, 4], [4, 4]]]]]
)
shift(x, d=2).equal(x.permute(0, 1, 3, 4, 2))
shift(x, d=4).equal(x.permute(0, 1, 2, 3, 4))
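One possible answer, not necessarily the one the puzzle has in mind, uses PyTorch's Tensor.movedim(source, destination), also available as torch.movedim. It moves a single axis to a new position and keeps the remaining axes in their relative order. Moving axis d to position -1 therefore produces the permutation in the tweet. In the sketch below, the function name shift comes from the checks above. The example x has shape (1, 1, 4, 2, 2).

```python
import torch


def shift(x: torch.Tensor, d: int) -> torch.Tensor:
    # Move axis d to the last position; the other axes keep their order.
    return x.movedim(d, -1)


x = torch.tensor(
    [[[[[1, 1], [1, 1]], [[2, 2], [2, 2]], [[3, 3], [3, 3]], [[4, 4], [4, 4]]]]]
)
print(shift(x, d=2).equal(x.permute(0, 1, 3, 4, 2)))  # True
print(shift(x, d=4).equal(x.permute(0, 1, 2, 3, 4)))  # True (d is already last)
```

Like permute, movedim returns a view rather than a copy, so shift does not duplicate the tensor's data.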