Skip to content

Commit

Permalink
cleanup legacy transforms tests (#8013)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Oct 5, 2023
1 parent af3077e commit 67f3ce2
Show file tree
Hide file tree
Showing 7 changed files with 5,371 additions and 5,757 deletions.
9 changes: 4 additions & 5 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def make_bounding_boxes(
canvas_size=DEFAULT_SIZE,
*,
format=tv_tensors.BoundingBoxFormat.XYXY,
num_boxes=1,
dtype=None,
device="cpu",
):
Expand All @@ -419,8 +420,7 @@ def sample_position(values, max_value):

dtype = dtype or torch.float32

num_objects = 1
h, w = [torch.randint(1, s, (num_objects,)) for s in canvas_size]
h, w = [torch.randint(1, s, (num_boxes,)) for s in canvas_size]
y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1])

Expand All @@ -443,12 +443,11 @@ def sample_position(values, max_value):
)


def make_detection_mask(size=DEFAULT_SIZE, *, dtype=None, device="cpu"):
def make_detection_masks(size=DEFAULT_SIZE, *, num_masks=1, dtype=None, device="cpu"):
"""Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
num_objects = 1
return tv_tensors.Mask(
torch.testing.make_tensor(
(num_objects, *size),
(num_masks, *size),
low=0,
high=2,
dtype=dtype or torch.bool,
Expand Down
Loading

0 comments on commit 67f3ce2

Please sign in to comment.