diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py
index 7baece2ae2c..58512753ef7 100644
--- a/torchvision/_meta_registrations.py
+++ b/torchvision/_meta_registrations.py
@@ -33,7 +33,7 @@ def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, samp
         ),
     )
     num_rois = rois.size(0)
-    _, channels, height, width = input.size()
+    channels = input.size(1)
     return input.new_empty((num_rois, channels, pooled_height, pooled_width))
 
 
@@ -51,6 +51,53 @@ def meta_roi_align_backward(
     return grad.new_empty((batch_size, channels, height, width))
 
 
+@register_meta("ps_roi_align")
+def meta_ps_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
+    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
+    torch._check(
+        input.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for input to have the same type as tensor for rois; "
+            f"but type {input.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    channels = input.size(1)
+    # torch._check expects its message as a zero-argument callable, so the
+    # text is wrapped in a lambda like the other checks in this function.
+    torch._check(
+        channels % (pooled_height * pooled_width) == 0,
+        lambda: "input channels must be a multiple of pooling height * pooling width",
+    )
+
+    num_rois = rois.size(0)
+    out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
+    return input.new_empty(out_size), torch.empty(out_size, dtype=torch.int32, device="meta")
+
+
+@register_meta("_ps_roi_align_backward")
+def meta_ps_roi_align_backward(
+    grad,
+    rois,
+    channel_mapping,
+    spatial_scale,
+    pooled_height,
+    pooled_width,
+    sampling_ratio,
+    batch_size,
+    channels,
+    height,
+    width,
+):
+    torch._check(
+        grad.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for grad to have the same type as tensor for rois; "
+            f"but type {grad.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    return grad.new_empty((batch_size, channels, height, width))
+
+
 @torch._custom_ops.impl_abstract("torchvision::nms")
 def meta_nms(dets, scores, iou_threshold):
     torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
diff --git a/torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp b/torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp
index 47e51ce9ca2..7205e9b15db 100644
--- a/torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp
+++ b/torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp
@@ -16,16 +16,16 @@ class PSROIAlignFunction
       const torch::autograd::Variable& input,
       const torch::autograd::Variable& rois,
       double spatial_scale,
-      int64_t pooled_height,
-      int64_t pooled_width,
+      c10::SymInt pooled_height,
+      c10::SymInt pooled_width,
       int64_t sampling_ratio) {
     ctx->saved_data["spatial_scale"] = spatial_scale;
     ctx->saved_data["pooled_height"] = pooled_height;
     ctx->saved_data["pooled_width"] = pooled_width;
     ctx->saved_data["sampling_ratio"] = sampling_ratio;
-    ctx->saved_data["input_shape"] = input.sizes();
+    ctx->saved_data["input_shape"] = input.sym_sizes();
     at::AutoDispatchBelowADInplaceOrView g;
-    auto result = ps_roi_align(
+    auto result = ps_roi_align_symint(
         input,
         rois,
         spatial_scale,
@@ -48,19 +48,19 @@ class PSROIAlignFunction
     auto saved = ctx->get_saved_variables();
     auto rois = saved[0];
     auto channel_mapping = saved[1];
-    auto input_shape = ctx->saved_data["input_shape"].toIntList();
-    auto grad_in = detail::_ps_roi_align_backward(
+    auto input_shape = ctx->saved_data["input_shape"].toList();
+    auto grad_in = detail::_ps_roi_align_backward_symint(
         grad_output[0],
         rois,
         channel_mapping,
         ctx->saved_data["spatial_scale"].toDouble(),
-        ctx->saved_data["pooled_height"].toInt(),
-        ctx->saved_data["pooled_width"].toInt(),
+        ctx->saved_data["pooled_height"].toSymInt(),
+        ctx->saved_data["pooled_width"].toSymInt(),
         ctx->saved_data["sampling_ratio"].toInt(),
-        input_shape[0],
-        input_shape[1],
-        input_shape[2],
-        input_shape[3]);
+        input_shape[0].get().toSymInt(),
+        input_shape[1].get().toSymInt(),
+        input_shape[2].get().toSymInt(),
+        input_shape[3].get().toSymInt());
 
     return {
         grad_in,
@@ -82,15 +82,15 @@ class PSROIAlignBackwardFunction
       const torch::autograd::Variable& rois,
       const torch::autograd::Variable& channel_mapping,
       double spatial_scale,
-      int64_t pooled_height,
-      int64_t pooled_width,
+      c10::SymInt pooled_height,
+      c10::SymInt pooled_width,
       int64_t sampling_ratio,
-      int64_t batch_size,
-      int64_t channels,
-      int64_t height,
-      int64_t width) {
+      c10::SymInt batch_size,
+      c10::SymInt channels,
+      c10::SymInt height,
+      c10::SymInt width) {
     at::AutoDispatchBelowADInplaceOrView g;
-    auto grad_in = detail::_ps_roi_align_backward(
+    auto grad_in = detail::_ps_roi_align_backward_symint(
         grad,
         rois,
         channel_mapping,
@@ -117,8 +117,8 @@ std::tuple ps_roi_align_autograd(
     const at::Tensor& input,
     const at::Tensor& rois,
     double spatial_scale,
-    int64_t pooled_height,
-    int64_t pooled_width,
+    c10::SymInt pooled_height,
+    c10::SymInt pooled_width,
     int64_t sampling_ratio) {
   auto result = PSROIAlignFunction::apply(
       input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
@@ -131,13 +131,13 @@ at::Tensor ps_roi_align_backward_autograd(
     const at::Tensor& rois,
     const at::Tensor& channel_mapping,
     double spatial_scale,
-    int64_t pooled_height,
-    int64_t pooled_width,
+    c10::SymInt pooled_height,
+    c10::SymInt pooled_width,
     int64_t sampling_ratio,
-    int64_t batch_size,
-    int64_t channels,
-    int64_t height,
-    int64_t width) {
+    c10::SymInt batch_size,
+    c10::SymInt channels,
+    c10::SymInt height,
+    c10::SymInt width) {
   return PSROIAlignBackwardFunction::apply(
       grad,
       rois,
diff --git a/torchvision/csrc/ops/ps_roi_align.cpp b/torchvision/csrc/ops/ps_roi_align.cpp
index 6d091b3c695..de458c0d62d 100644
--- a/torchvision/csrc/ops/ps_roi_align.cpp
+++ b/torchvision/csrc/ops/ps_roi_align.cpp
@@ -22,6 +22,21 @@ std::tuple ps_roi_align(
       input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
 }
 
+std::tuple<at::Tensor, at::Tensor> ps_roi_align_symint(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    c10::SymInt pooled_height,
+    c10::SymInt pooled_width,
+    int64_t sampling_ratio) {
+  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align");
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchvision::ps_roi_align", "")
+                       .typed<decltype(ps_roi_align_symint)>();
+  return op.call(
+      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
+}
+
 namespace detail {
 
 at::Tensor _ps_roi_align_backward(
@@ -54,13 +69,43 @@ at::Tensor _ps_roi_align_backward(
       width);
 }
 
+at::Tensor _ps_roi_align_backward_symint(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    c10::SymInt pooled_height,
+    c10::SymInt pooled_width,
+    int64_t sampling_ratio,
+    c10::SymInt batch_size,
+    c10::SymInt channels,
+    c10::SymInt height,
+    c10::SymInt width) {
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "")
+          .typed<decltype(_ps_roi_align_backward_symint)>();
+  return op.call(
+      grad,
+      rois,
+      channel_mapping,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      sampling_ratio,
+      batch_size,
+      channels,
+      height,
+      width);
+}
+
 } // namespace detail
 
 TORCH_LIBRARY_FRAGMENT(torchvision, m) {
   m.def(TORCH_SELECTIVE_SCHEMA(
-      "torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)"));
+      "torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio) -> (Tensor, Tensor)"));
   m.def(TORCH_SELECTIVE_SCHEMA(
-      "torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"));
+      "torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor"));
 }
 
 } // namespace ops
diff --git a/torchvision/csrc/ops/ps_roi_align.h b/torchvision/csrc/ops/ps_roi_align.h
index c5ed865982c..75650586bc6 100644
--- a/torchvision/csrc/ops/ps_roi_align.h
+++ b/torchvision/csrc/ops/ps_roi_align.h
@@ -14,6 +14,14 @@ VISION_API std::tuple ps_roi_align(
     int64_t pooled_width,
     int64_t sampling_ratio);
 
+VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_align_symint(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    c10::SymInt pooled_height,
+    c10::SymInt pooled_width,
+    int64_t sampling_ratio);
+
 namespace detail {
 
 at::Tensor _ps_roi_align_backward(
@@ -29,6 +37,19 @@ at::Tensor _ps_roi_align_backward(
     int64_t height,
     int64_t width);
 
+at::Tensor _ps_roi_align_backward_symint(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    c10::SymInt pooled_height,
+    c10::SymInt pooled_width,
+    int64_t sampling_ratio,
+    c10::SymInt batch_size,
+    c10::SymInt channels,
+    c10::SymInt height,
+    c10::SymInt width);
+
 } // namespace detail
 
 } // namespace ops