[Relay][Pytorch] Add support for aten::unflatten (apache#16131)
* add support for `aten::unflatten`

* Add a check that `dshape[dim]` is divisible by the product of the dimensions in `unflattened_size` (see the usage sketch after these notes)

* Update shape check

* handle `dim=-1`

* formatting
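
For context, a minimal usage sketch of the `torch.unflatten` semantics this commit maps to Relay; the tensor shapes and values below are illustrative and not taken from the commit:

```python
import torch

x = torch.rand(5, 12, 3)

# Split dim 1 (size 12) into (2, 6): [5, 12, 3] -> [5, 2, 6, 3].
y = torch.unflatten(x, 1, (2, 6))
assert y.shape == (5, 2, 6, 3)

# At most one -1 entry is inferred, so the size at `dim` only needs to be
# divisible by the product of the remaining sizes:
# 12 % (2 * 3) == 0, so -1 resolves to 2.
z = torch.unflatten(x, -2, (2, -1, 3))
assert z.shape == (5, 2, 2, 3)
```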
mshr-h authored Nov 22, 2023
1 parent 3190f28 commit bce8243
Showing 2 changed files with 57 additions and 1 deletion.
22 changes: 22 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -1546,6 +1546,27 @@ def flatten(self, inputs, input_types):
        out = _op.squeeze(out, axis=squeeze_axes)
        return out

    def unflatten(self, inputs, input_types):
        data = inputs[0]
        dim = int(inputs[1])
        unflattened_size = tuple(inputs[2])
        dshape = get_const_tuple(self.infer_shape_with_prelude(data))

        # Normalize a negative dim and make sure it indexes into the input shape.
        dim = dim if dim >= 0 else len(dshape) + dim
        assert len(dshape) > dim >= 0

        # At most one entry of unflattened_size may be -1 (inferred).
        assert unflattened_size.count(-1) <= 1

        # With a -1 entry the product is negative, so dshape[dim] only has to be
        # divisible by the remaining sizes; otherwise it must match exactly.
        mult = np.multiply.reduce(unflattened_size)
        if mult < 0:
            assert dshape[dim] % mult == 0
        else:
            assert dshape[dim] == mult

        new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
        out = _op.reshape(data, new_shape)
        return out

    def addmm(self, inputs, input_types):
        input_mat = inputs[0]
        mat1 = inputs[1]
@@ -3945,6 +3966,7 @@ def create_convert_map(self):
"aten::t": self.transpose,
"aten::numpy_T": self.numpy_T,
"aten::flatten": self.flatten,
"aten::unflatten": self.unflatten,
"aten::addmm": self.addmm,
"aten::size": self.size,
"aten::view": self.view,
36 changes: 35 additions & 1 deletion tests/python/frontend/pytorch/test_forward.py
@@ -1544,6 +1544,40 @@ def _test_flatten(start_dim, end_dim):
    verify_model(_test_flatten(-3, -2), inp)


@tvm.testing.uses_gpu
def test_unflatten():
    """test_unflatten"""

    def _test_unflatten(dim, unflattened_size):
        return lambda inp: torch.unflatten(inp, dim, unflattened_size)

    inp = torch.rand(60)

    # [60] -> [3, 5, 2, 2]
    verify_model(_test_unflatten(0, (3, 5, 2, 2)), inp)
    verify_model(_test_unflatten(0, (-1, 5, 2, 2)), inp)
    verify_model(_test_unflatten(0, (3, -1, 2, 2)), inp)
    verify_model(_test_unflatten(0, (3, 5, -1, 2)), inp)
    verify_model(_test_unflatten(0, (3, 5, 2, -1)), inp)
    verify_model(_test_unflatten(-1, (3, 5, 2, 2)), inp)
    verify_model(_test_unflatten(-1, (-1, 5, 2, 2)), inp)
    verify_model(_test_unflatten(-1, (3, -1, 2, 2)), inp)
    verify_model(_test_unflatten(-1, (3, 5, -1, 2)), inp)
    verify_model(_test_unflatten(-1, (3, 5, 2, -1)), inp)

    inp = torch.rand(3, 4, 1)

    # [3, 4, 1] -> [3, 2, 2, 1]
    verify_model(_test_unflatten(1, (2, 2)), inp)
    verify_model(_test_unflatten(1, (-1, 2)), inp)

    inp = torch.rand(5, 12, 3)

    # [5, 12, 3] -> [5, 2, 2, 3, 1, 1, 3]
    verify_model(_test_unflatten(1, (2, 2, 3, 1, 1)), inp)
    verify_model(_test_unflatten(-2, (2, 2, 3, 1, 1)), inp)


@tvm.testing.uses_gpu
def test_forward_transpose():
"""test_forward_transpose"""
@@ -4744,7 +4778,7 @@ def test_fn(x, mask):
    verify_model(test_fn, [inp.to(torch.float64), inp > 0.5])


@pytest.mark.skip(reason="unsupported op: 'aten::scaled_dot_product_attention', 'aten::unflatten'")
@pytest.mark.skip(reason="unsupported op: 'aten::scaled_dot_product_attention'")
def test_transformer():
"""test_transformer"""
model = torch.nn.Transformer(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
