Skip to content

Commit

Permalink
Add visibility parameter to draw_keypoints() (#8225)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
  • Loading branch information
3 people authored Feb 6, 2024
1 parent 36d0e3e commit ae14789
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 11 deletions.
62 changes: 61 additions & 1 deletion gallery/others/plot_visualization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def show(imgs):
show(res)

# %%
# As we see the keypoints appear as colored circles over the image.
# As we see, the keypoints appear as colored circles over the image.
# The coco keypoints for a person are ordered and represent the following list.\

coco_keypoints = [
Expand Down Expand Up @@ -460,3 +460,63 @@ def show(imgs):

res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
show(res)

# %%
# That looks pretty good.
#
# .. _draw_keypoints_with_visibility:
#
# Drawing Keypoints with Visibility
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Let's have a look at the results, another keypoint prediction module produced, and show the connectivity:

prediction = torch.tensor(
[[[208.0176, 214.2409, 1.0000],
[000.0000, 000.0000, 0.0000],
[197.8246, 210.6392, 1.0000],
[000.0000, 000.0000, 0.0000],
[178.6378, 217.8425, 1.0000],
[221.2086, 253.8591, 1.0000],
[160.6502, 269.4662, 1.0000],
[243.9929, 304.2822, 1.0000],
[138.4654, 328.8935, 1.0000],
[277.5698, 340.8990, 1.0000],
[153.4551, 374.5145, 1.0000],
[000.0000, 000.0000, 0.0000],
[226.0053, 370.3125, 1.0000],
[221.8081, 455.5516, 1.0000],
[273.9723, 448.9486, 1.0000],
[193.6275, 546.1933, 1.0000],
[273.3727, 545.5930, 1.0000]]]
)

res = draw_keypoints(person_int, prediction, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
show(res)

# %%
# What happened there?
# The model, which predicted the new keypoints,
# can't detect the three points that are hidden on the upper left body of the skateboarder.
# More precisely, the model predicted that `(x, y, vis) = (0, 0, 0)` for the left_eye, left_ear, and left_hip.
# So we definitely don't want to display those keypoints and connections, and you don't have to.
# Looking at the parameters of :func:`~torchvision.utils.draw_keypoints`,
# we can see that we can pass a visibility tensor as an additional argument.
# Given the models' prediction, we have the visibility as the third keypoint dimension, we just need to extract it.
# Let's split the ``prediction`` into the keypoint coordinates and their respective visibility,
# and pass both of them as arguments to :func:`~torchvision.utils.draw_keypoints`.

coordinates, visibility = prediction.split([2, 1], dim=-1)
visibility = visibility.bool()

res = draw_keypoints(
person_int, coordinates, visibility=visibility, connectivity=connect_skeleton, colors="blue", radius=4, width=3
)
show(res)

# %%
# We can see that the undetected keypoints are not draw and the invisible keypoint connections were skipped.
# This can reduce the noise on images with multiple detections, or in cases like ours,
# when the keypoint-prediction model missed some detections.
# Most torch keypoint-prediction models return the visibility for every prediction, ready for you to use it.
# The :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn` model,
# which we used in the first case, does so too.
Binary file added test/assets/fakedata/draw_keypoints_visibility.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
83 changes: 83 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,77 @@ def test_draw_keypoints_colored(colors):
assert_equal(img, img_cp)


@pytest.mark.parametrize("connectivity", [[(0, 1)], [(0, 1), (1, 2)]])
@pytest.mark.parametrize(
"vis",
[
torch.tensor([[1, 1, 0], [1, 1, 0]], dtype=torch.bool),
torch.tensor([[1, 1, 0], [1, 1, 0]], dtype=torch.float).unsqueeze_(-1),
],
)
def test_draw_keypoints_visibility(connectivity, vis):
# Keypoints is declared on top as global variable
keypoints_cp = keypoints.clone()

img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()

vis_cp = vis if vis is None else vis.clone()

result = utils.draw_keypoints(
image=img,
keypoints=keypoints,
connectivity=connectivity,
colors="red",
visibility=vis,
)
assert result.size(0) == 3
assert_equal(keypoints, keypoints_cp)
assert_equal(img, img_cp)

# compare with a fakedata image
# connect the key points 0 to 1 for both skeletons and do not show the other key points
path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoints_visibility.png"
)
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)

expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected)

if vis_cp is None:
assert vis is None
else:
assert_equal(vis, vis_cp)
assert vis.dtype == vis_cp.dtype


def test_draw_keypoints_visibility_default():
# Keypoints is declared on top as global variable
keypoints_cp = keypoints.clone()

img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()

result = utils.draw_keypoints(
image=img,
keypoints=keypoints,
connectivity=[(0, 1)],
colors="red",
visibility=None,
)
assert result.size(0) == 3
assert_equal(keypoints, keypoints_cp)
assert_equal(img, img_cp)

# compare against fakedata image, which connects 0->1 for both key-point skeletons
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected)


def test_draw_keypoints_errors():
h, w = 10, 10
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
Expand All @@ -379,6 +450,18 @@ def test_draw_keypoints_errors():
with pytest.raises(ValueError, match="keypoints must be of shape"):
invalid_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float)
utils.draw_keypoints(image=img, keypoints=invalid_keypoints)
with pytest.raises(ValueError, match=re.escape("visibility must be of shape (num_instances, K)")):
one_dim_visibility = torch.tensor([True, True, True], dtype=torch.bool)
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=one_dim_visibility)
with pytest.raises(ValueError, match=re.escape("visibility must be of shape (num_instances, K)")):
three_dim_visibility = torch.ones((2, 3, 4), dtype=torch.bool)
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=three_dim_visibility)
with pytest.raises(ValueError, match="keypoints and visibility must have the same dimensionality"):
vis_wrong_n = torch.ones((3, 3), dtype=torch.bool)
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=vis_wrong_n)
with pytest.raises(ValueError, match="keypoints and visibility must have the same dimensionality"):
vis_wrong_k = torch.ones((2, 4), dtype=torch.bool)
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=vis_wrong_k)


@pytest.mark.parametrize("batch", (True, False))
Expand Down
56 changes: 46 additions & 10 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,29 +331,44 @@ def draw_keypoints(
colors: Optional[Union[str, Tuple[int, int, int]]] = None,
radius: int = 2,
width: int = 3,
visibility: Optional[torch.Tensor] = None,
) -> torch.Tensor:

"""
Draws Keypoints on given RGB image.
The values of the input image should be uint8 between 0 and 255.
Keypoints can be drawn for multiple instances at a time.
This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint.
Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances,
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances,
in the format [x, y].
connectivity (List[Tuple[int, int]]]): A List of tuple where,
each tuple contains pair of keypoints to be connected.
connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints
to be connected.
If at least one of the two connected keypoints has a ``visibility`` of False,
this specific connection is not drawn.
Exclusions due to invisibility are computed per-instance.
colors (str, Tuple): The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
radius (int): Integer denoting radius of keypoint.
width (int): Integer denoting width of line connecting keypoints.
visibility (Tensor): Tensor of shape (num_instances, K) specifying the visibility of the K
keypoints for each of the N instances.
True means that the respective keypoint is visible and should be drawn.
False means invisible, so neither the point nor possible connections containing it are drawn.
The input tensor will be cast to bool.
Default ``None`` means that all the keypoints are visible.
For more details, see :ref:`draw_keypoints_with_visibility`.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
"""

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_keypoints)
# validate image
if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8:
Expand All @@ -363,24 +378,45 @@ def draw_keypoints(
elif image.size()[0] != 3:
raise ValueError("Pass an RGB image. Other Image formats are not supported")

# validate keypoints
if keypoints.ndim != 3:
raise ValueError("keypoints must be of shape (num_instances, K, 2)")

# validate visibility
if visibility is None: # set default
visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool)
# If the last dimension is 1, e.g., after calling split([2, 1], dim=-1) on the output of a keypoint-prediction
# model, make sure visibility has shape (num_instances, K).
# Iff K = 1, this has unwanted behavior, but K=1 does not really make sense in the first place.
visibility = visibility.squeeze(-1)
if visibility.ndim != 2:
raise ValueError(f"visibility must be of shape (num_instances, K). Got ndim={visibility.ndim}")
if visibility.shape != keypoints.shape[:-1]:
raise ValueError(
"keypoints and visibility must have the same dimensionality for num_instances and K. "
f"Got {visibility.shape = } and {keypoints.shape = }"
)

ndarr = image.permute(1, 2, 0).cpu().numpy()
img_to_draw = Image.fromarray(ndarr)
draw = ImageDraw.Draw(img_to_draw)
img_kpts = keypoints.to(torch.int64).tolist()

for kpt_id, kpt_inst in enumerate(img_kpts):
for inst_id, kpt in enumerate(kpt_inst):
x1 = kpt[0] - radius
x2 = kpt[0] + radius
y1 = kpt[1] - radius
y2 = kpt[1] + radius
img_vis = visibility.cpu().bool().tolist()

for kpt_inst, vis_inst in zip(img_kpts, img_vis):
for kpt_coord, kp_vis in zip(kpt_inst, vis_inst):
if not kp_vis:
continue
x1 = kpt_coord[0] - radius
x2 = kpt_coord[0] + radius
y1 = kpt_coord[1] - radius
y2 = kpt_coord[1] + radius
draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)

if connectivity:
for connection in connectivity:
if (not vis_inst[connection[0]]) or (not vis_inst[connection[1]]):
continue
start_pt_x = kpt_inst[connection[0]][0]
start_pt_y = kpt_inst[connection[0]][1]

Expand Down

0 comments on commit ae14789

Please sign in to comment.