Skip to content

Commit

Permalink
[ORTModule] ATen Support for upsample_bilinear (#14519)
Browse files Browse the repository at this point in the history
It's required by model MobileViT.
  • Loading branch information
centwang authored Feb 4, 2023
1 parent c1a0fc5 commit 3d75187
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 9 deletions.
11 changes: 6 additions & 5 deletions onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,10 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"_adaptive_avg_pool2d": self._infer_aten_pool2d,
"numpy_T": self._infer_Transpose,
"native_group_norm": self._infer_aten_group_norm,
"upsample_nearest1d": self._infer_aten_upsample_nearest,
"upsample_nearest2d": self._infer_aten_upsample_nearest,
"upsample_nearest3d": self._infer_aten_upsample_nearest,
"upsample_nearest1d": self._infer_aten_upsample,
"upsample_nearest2d": self._infer_aten_upsample,
"upsample_nearest3d": self._infer_aten_upsample,
"upsample_bilinear2d": self._infer_aten_upsample,
}
self.run_ = True
self.suggested_merge_ = {}
Expand Down Expand Up @@ -1389,14 +1390,14 @@ def _infer_aten_group_norm(self, node):
)
)

def _infer_aten_upsample_nearest(self, node):
def _infer_aten_upsample(self, node):
new_shape = None
input_shape = self._get_shape(node, 0)
if input_shape is not None:
new_shape = input_shape[:2]
output_size = self._try_get_value(node, 1)
if output_size is not None:
new_shape += [dim_size.item() for dim_size in output_size]
new_shape += [dim_size.item() if type(dim_size) == np.int64 else dim_size for dim_size in output_size]
else:
rank = len(input_shape)
new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,10 @@ def native_group_norm_gradient():

# PyTorch removed related backward functions with "vec" overload name since 1.13. The functions with no overload name
# are available for all versions, though they are not that convienent to use.
def _upsample_nearest_gradient(backward_fn, dims):
def _upsample_gradient(backward_fn, dims):
scales = ["" for _ in range(dims)]
if "bilinear" in backward_fn:
scales = ["I(2)"] + scales
return [
("Shape", ["I(0)"], ["Shape_X"]),
("Shape", ["O(0)"], ["Shape_Y"]),
Expand All @@ -258,14 +260,19 @@ def _upsample_nearest_gradient(backward_fn, dims):

@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest1d", "vec")
def upsample_nearest1d_gradient():
return _upsample_nearest_gradient("upsample_nearest1d_backward", 1)
return _upsample_gradient("upsample_nearest1d_backward", 1)


@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest2d", "vec")
def upsample_nearest2d_gradient():
return _upsample_nearest_gradient("upsample_nearest2d_backward", 2)
return _upsample_gradient("upsample_nearest2d_backward", 2)


@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec")
def upsample_nearest3d_gradient():
return _upsample_nearest_gradient("upsample_nearest3d_backward", 3)
return _upsample_gradient("upsample_nearest3d_backward", 3)


@register_gradient("org.pytorch.aten", "ATen", "upsample_bilinear2d", "vec")
def upsample_bilinear2d_gradient():
return _upsample_gradient("upsample_bilinear2d_backward", 2)
Original file line number Diff line number Diff line change
Expand Up @@ -799,3 +799,16 @@ def upsample_nearest2d(g, input, output_size, scale_factors):
@register_symbolic("upsample_nearest3d")
def upsample_nearest3d(g, input, output_size, scale_factors):
return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d")


@register_symbolic("upsample_bilinear2d")
def upsample_bilinear2d(g, input, output_size, align_corners, scale_factors):
return g.op(
"org.pytorch.aten::ATen",
input,
output_size,
align_corners,
scale_factors,
operator_s="upsample_bilinear2d",
overload_name_s="vec",
)
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,34 @@ def run_step(model, input):
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)


def test_aten_upsample_bilinear():
class _NeuralNetUpsampleBilinear(torch.nn.Module):
def __init__(self):
super(_NeuralNetUpsampleBilinear, self).__init__()

def forward(self, input):
return torch.nn.functional.interpolate(input, size=(8, 12), mode="bilinear")

device = "cuda"
pt_model = _NeuralNetUpsampleBilinear().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))

def run_step(model, input):
prediction = model(input)
prediction.sum().backward()
return prediction

# reset manual seed to reset the generator
torch.manual_seed(2333)
pt_input = torch.randn([2, 4, 6, 8], dtype=torch.float, device=device, requires_grad=True)
ort_input = copy.deepcopy(pt_input)
pt_prediction = run_step(pt_model, pt_input)
ort_prediction = run_step(ort_model, ort_input)

_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)


def test_gradient_correctness_cast_chain():
class NeuralNetCast(torch.nn.Module):
def __init__(self, D):
Expand Down

0 comments on commit 3d75187

Please sign in to comment.