Skip to content

Commit

Permalink
Add RAFT model for optical flow (#5022)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Dec 6, 2021
1 parent 9b57de6 commit 01ffb3a
Show file tree
Hide file tree
Showing 8 changed files with 752 additions and 4 deletions.
15 changes: 14 additions & 1 deletion docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Models and pre-trained weights
The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection and video classification.
keypoint detection, video classification, and optical flow.

.. note ::
Backward compatibility is guaranteed for loading a serialized
Expand Down Expand Up @@ -798,3 +798,16 @@ ResNet (2+1)D
:template: function.rst

torchvision.models.video.r2plus1d_18

Optical flow
============

Raft
----

.. autosummary::
:toctree: generated/
:template: function.rst

torchvision.models.optical_flow.raft_large
torchvision.models.optical_flow.raft_small
Binary file not shown.
Binary file not shown.
35 changes: 32 additions & 3 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _get_expected_file(name=None):
return expected_file


def _assert_expected(output, name, prec):
def _assert_expected(output, name, prec=None, atol=None, rtol=None):
"""Test that a python value matches the recorded contents of a file
based on a "check" name. The value must be
pickable with `torch.save`. This file
Expand All @@ -110,10 +110,11 @@ def _assert_expected(output, name, prec):
MAX_PICKLE_SIZE = 50 * 1000 # 50 KB
binary_size = os.path.getsize(expected_file)
if binary_size > MAX_PICKLE_SIZE:
raise RuntimeError(f"The output for {filename}, is larger than 50kb")
raise RuntimeError(f"The output for {filename}, is larger than 50kb - got {binary_size}kb")
else:
expected = torch.load(expected_file)
rtol = atol = prec
rtol = rtol or prec # keeping prec param for legacy reason, but could be removed ideally
atol = atol or prec
torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False)


Expand Down Expand Up @@ -818,5 +819,33 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load
assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"]


@needs_cuda
@pytest.mark.parametrize("model_builder", (models.optical_flow.raft_large, models.optical_flow.raft_small))
@pytest.mark.parametrize("scripted", (False, True))
def test_raft(model_builder, scripted):

torch.manual_seed(0)

# We need very small images, otherwise the pickle size would exceed the 50KB
# As a resut we need to override the correlation pyramid to not downsample
# too much, otherwise we would get nan values (effective H and W would be
# reduced to 1)
corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2)

model = model_builder(corr_block=corr_block).eval().to("cuda")
if scripted:
model = torch.jit.script(model)

bs = 1
img1 = torch.rand(bs, 3, 80, 72).cuda()
img2 = torch.rand(bs, 3, 80, 72).cuda()

preds = model(img1, img2)
flow_pred = preds[-1]
# Tolerance is fairly high, but there are 2 * H * W outputs to check
# The .pkl were generated on the AWS cluter, on the CI it looks like the resuts are slightly different
_assert_expected(flow_pred, name=model_builder.__name__, atol=1e-2, rtol=1)


if __name__ == "__main__":
pytest.main([__file__])
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .regnet import *
from . import detection
from . import feature_extraction
from . import optical_flow
from . import quantization
from . import segmentation
from . import video
1 change: 1 addition & 0 deletions torchvision/models/optical_flow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .raft import RAFT, raft_large, raft_small
45 changes: 45 additions & 0 deletions torchvision/models/optical_flow/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor


def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", align_corners: Optional[bool] = None):
"""Same as torch's grid_sample, with absolute pixel coordinates instead of normalized coordinates."""
h, w = img.shape[-2:]

xgrid, ygrid = absolute_grid.split([1, 1], dim=-1)
xgrid = 2 * xgrid / (w - 1) - 1
ygrid = 2 * ygrid / (h - 1) - 1
normalized_grid = torch.cat([xgrid, ygrid], dim=-1)

return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners)


def make_coords_grid(batch_size: int, h: int, w: int):
coords = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch_size, 1, 1, 1)


def upsample_flow(flow, up_mask: Optional[Tensor] = None):
"""Upsample flow by a factor of 8.
If up_mask is None we just interpolate.
If up_mask is specified, we upsample using a convex combination of its weights. See paper page 8 and appendix B.
Note that in appendix B the picture assumes a downsample factor of 4 instead of 8.
"""
batch_size, _, h, w = flow.shape
new_h, new_w = h * 8, w * 8

if up_mask is None:
return 8 * F.interpolate(flow, size=(new_h, new_w), mode="bilinear", align_corners=True)

up_mask = up_mask.view(batch_size, 1, 9, 8, 8, h, w)
up_mask = torch.softmax(up_mask, dim=2) # "convex" == weights sum to 1

upsampled_flow = F.unfold(8 * flow, kernel_size=3, padding=1).view(batch_size, 2, 9, 1, 1, h, w)
upsampled_flow = torch.sum(up_mask * upsampled_flow, dim=2)

return upsampled_flow.permute(0, 1, 4, 2, 5, 3).reshape(batch_size, 2, new_h, new_w)
Loading

0 comments on commit 01ffb3a

Please sign in to comment.