diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index 0865547a6..8d50c52f1 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -90,7 +90,9 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor: return torch.flip(pixel_colors, [1]) -def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: +def softmax_rgb_blend( + colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100 +) -> torch.Tensor: """ RGB and alpha channel blending to return an RGBA image based on the method proposed in [0] @@ -118,6 +120,8 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: exponential function used to control the opacity of the color. - background_color: (3) element list/tuple/torch.Tensor specifying the RGB values for the background color. + znear: float, near clipping plane in the z direction + zfar: float, far clipping plane in the z direction Returns: RGBA pixel_colors: (N, H, W, 4) @@ -125,6 +129,7 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: [0] Shichen Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based 3D Reasoning' """ + N, H, W, K = fragments.pix_to_face.shape device = fragments.pix_to_face.device pix_colors = torch.ones( @@ -140,11 +145,6 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: delta = np.exp(1e-10 / blend_params.gamma) * 1e-10 delta = torch.tensor(delta, device=device) - # Near and far clipping planes. - # TODO: add zfar/znear as input params. - zfar = 100.0 - znear = 1.0 - # Mask for padded pixels. mask = fragments.pix_to_face >= 0 @@ -164,6 +164,7 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: # Weights for each face. Adjust the exponential by the max z to prevent # overflow. zbuf shape (N, H, W, K), find max over K. # TODO: there may still be some instability in the exponent calculation. + z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask z_inv_max = torch.max(z_inv, dim=-1).values[..., None] weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma) diff --git a/pytorch3d/renderer/mesh/__init__.py b/pytorch3d/renderer/mesh/__init__.py index 32e431e5f..3ac0e00a2 100644 --- a/pytorch3d/renderer/mesh/__init__.py +++ b/pytorch3d/renderer/mesh/__init__.py @@ -1,5 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from .texturing import ( # isort:skip + interpolate_texture_map, + interpolate_vertex_colors, +) from .rasterize_meshes import rasterize_meshes from .rasterizer import MeshRasterizer, RasterizationSettings from .renderer import MeshRenderer @@ -13,10 +18,6 @@ TexturedSoftPhongShader, ) from .shading import gouraud_shading, phong_shading -from .texturing import ( # isort: skip - interpolate_face_attributes, - interpolate_texture_map, - interpolate_vertex_colors, -) +from .utils import interpolate_face_attributes __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/renderer/mesh/renderer.py b/pytorch3d/renderer/mesh/renderer.py index aeac515d7..18da3d34a 100644 --- a/pytorch3d/renderer/mesh/renderer.py +++ b/pytorch3d/renderer/mesh/renderer.py @@ -5,6 +5,9 @@ import torch import torch.nn as nn +from .rasterizer import Fragments +from .utils import _clip_barycentric_coordinates, _interpolate_zbuf + # A renderer class should be initialized with a # function for rasterization and a function for shading. # The rasterizer should: @@ -34,6 +37,34 @@ def __init__(self, rasterizer, shader): self.shader = shader def forward(self, meshes_world, **kwargs) -> torch.Tensor: + """ + Render a batch of images from a batch of meshes by rasterizing and then shading. + + NOTE: If the blur radius for rasterization is > 0.0, some pixels can have one or + more barycentric coordinates lying outside the range [0, 1]. For a pixel with + out of bounds barycentric coordinates with respect to a face f, clipping is required + before interpolating the texture uv coordinates and z buffer so that the colors and + depths are limited to the range for the corresponding face. + """ fragments = self.rasterizer(meshes_world, **kwargs) + raster_settings = kwargs.get( + "raster_settings", self.rasterizer.raster_settings + ) + if raster_settings.blur_radius > 0.0: + # TODO: potentially move barycentric clipping to the rasterizer + # if no downstream functions requires unclipped values. + # This will avoid unnecssary re-interpolation of the z buffer. + clipped_bary_coords = _clip_barycentric_coordinates( + fragments.bary_coords + ) + clipped_zbuf = _interpolate_zbuf( + fragments.pix_to_face, clipped_bary_coords, meshes_world + ) + fragments = Fragments( + bary_coords=clipped_bary_coords, + zbuf=clipped_zbuf, + dists=fragments.dists, + pix_to_face=fragments.pix_to_face, + ) images = self.shader(fragments, meshes_world, **kwargs) return images diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index 1fbae6a63..efeb4792c 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -270,6 +270,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) + blend_params = kwargs.get("blend_params", self.blend_params) colors = phong_shading( meshes=meshes, fragments=fragments, @@ -278,7 +279,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: cameras=cameras, materials=materials, ) - images = softmax_rgb_blend(colors, fragments, self.blend_params) + images = softmax_rgb_blend(colors, fragments, blend_params) return images diff --git a/pytorch3d/renderer/mesh/shading.py b/pytorch3d/renderer/mesh/shading.py index 69d1e305a..1b9effebc 100644 --- a/pytorch3d/renderer/mesh/shading.py +++ b/pytorch3d/renderer/mesh/shading.py @@ -70,8 +70,12 @@ def phong_shading( vertex_normals = meshes.verts_normals_packed() # (V, 3) faces_verts = verts[faces] faces_normals = vertex_normals[faces] - pixel_coords = interpolate_face_attributes(fragments, faces_verts) - pixel_normals = interpolate_face_attributes(fragments, faces_normals) + pixel_coords = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_verts + ) + pixel_normals = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_normals + ) ambient, diffuse, specular = _apply_lighting( pixel_coords, pixel_normals, lights, cameras, materials ) @@ -122,7 +126,9 @@ def gouraud_shading( ) verts_colors_shaded = vertex_colors * (ambient + diffuse) + specular face_colors = verts_colors_shaded[faces] - colors = interpolate_face_attributes(fragments, face_colors) + colors = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, face_colors + ) return colors diff --git a/pytorch3d/renderer/mesh/texturing.py b/pytorch3d/renderer/mesh/texturing.py index b79dbe3e7..86f3a4424 100644 --- a/pytorch3d/renderer/mesh/texturing.py +++ b/pytorch3d/renderer/mesh/texturing.py @@ -7,75 +7,7 @@ from pytorch3d.structures.textures import Textures - -def _clip_barycentric_coordinates(bary) -> torch.Tensor: - """ - Args: - bary: barycentric coordinates of shape (...., 3) where `...` represents - an arbitrary number of dimensions - - Returns: - bary: All barycentric coordinate values clipped to the range [0, 1] - and renormalized. The output is the same shape as the input. - """ - if bary.shape[-1] != 3: - msg = "Expected barycentric coords to have last dim = 3; got %r" - raise ValueError(msg % bary.shape) - clipped = bary.clamp(min=0, max=1) - clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5) - clipped = clipped / clipped_sum - return clipped - - -def interpolate_face_attributes( - fragments, face_attributes: torch.Tensor, bary_clip: bool = False -) -> torch.Tensor: - """ - Interpolate arbitrary face attributes using the barycentric coordinates - for each pixel in the rasterized output. - - Args: - fragments: - The outputs of rasterization. From this we use - - - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices - of the faces (in the packed representation) which - overlap each pixel in the image. - - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying - the barycentric coordianates of each pixel - relative to the faces (in the packed - representation) which overlap the pixel. - face_attributes: packed attributes of shape (total_faces, 3, D), - specifying the value of the attribute for each - vertex in the face. - bary_clip: Bool to indicate if barycentric_coords should be clipped - before being used for interpolation. - - Returns: - pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated - value of the face attribute for each pixel. - """ - pix_to_face = fragments.pix_to_face - barycentric_coords = fragments.bary_coords - F, FV, D = face_attributes.shape - if FV != 3: - raise ValueError("Faces can only have three vertices; got %r" % FV) - N, H, W, K, _ = barycentric_coords.shape - if pix_to_face.shape != (N, H, W, K): - msg = "pix_to_face must have shape (batch_size, H, W, K); got %r" - raise ValueError(msg % pix_to_face.shape) - if bary_clip: - barycentric_coords = _clip_barycentric_coordinates(barycentric_coords) - - # Replace empty pixels in pix_to_face with 0 in order to interpolate. - mask = pix_to_face == -1 - pix_to_face = pix_to_face.clone() - pix_to_face[mask] = 0 - idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) - pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D) - pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2) - pixel_vals[mask] = 0 # Replace masked values in output. - return pixel_vals +from .utils import interpolate_face_attributes def interpolate_texture_map(fragments, meshes) -> torch.Tensor: @@ -97,8 +29,8 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor: relative to the faces (in the packed representation) which overlap the pixel. meshes: Meshes representing a batch of meshes. It is expected that - meshes has a textures attribute which is an instance of the - Textures class. + meshes has a textures attribute which is an instance of the + Textures class. Returns: texels: tensor of shape (N, H, W, K, C) giving the interpolated @@ -114,7 +46,9 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor: texture_maps = meshes.textures.maps_padded() # pixel_uvs: (N, H, W, K, 2) - pixel_uvs = interpolate_face_attributes(fragments, faces_verts_uvs) + pixel_uvs = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs + ) N, H_out, W_out, K = fragments.pix_to_face.shape N, H_in, W_in, C = texture_maps.shape # 3 for RGB @@ -178,5 +112,7 @@ def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor: vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :] faces_packed = meshes.faces_packed() faces_textures = vertex_textures[faces_packed] # (F, 3, C) - texels = interpolate_face_attributes(fragments, faces_textures) + texels = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_textures + ) return texels diff --git a/pytorch3d/renderer/mesh/utils.py b/pytorch3d/renderer/mesh/utils.py new file mode 100644 index 000000000..a82f10ca3 --- /dev/null +++ b/pytorch3d/renderer/mesh/utils.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + + +import torch + + +def _clip_barycentric_coordinates(bary) -> torch.Tensor: + """ + Args: + bary: barycentric coordinates of shape (...., 3) where `...` represents + an arbitrary number of dimensions + + Returns: + bary: Barycentric coordinates clipped (i.e any values < 0 are set to 0) + and renormalized. We only clip the negative values. Values > 1 will fall + into the [0, 1] range after renormalization. + The output is the same shape as the input. + """ + if bary.shape[-1] != 3: + msg = "Expected barycentric coords to have last dim = 3; got %r" + raise ValueError(msg % bary.shape) + clipped = bary.clamp(min=0.0) + clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5) + clipped = clipped / clipped_sum + return clipped + + +def interpolate_face_attributes( + pix_to_face: torch.Tensor, + barycentric_coords: torch.Tensor, + face_attributes: torch.Tensor, +) -> torch.Tensor: + """ + Interpolate arbitrary face attributes using the barycentric coordinates + for each pixel in the rasterized output. + + Args: + pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices + of the faces (in the packed representation) which + overlap each pixel in the image. + barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying + the barycentric coordianates of each pixel + relative to the faces (in the packed + representation) which overlap the pixel. + face_attributes: packed attributes of shape (total_faces, 3, D), + specifying the value of the attribute for each + vertex in the face. + + Returns: + pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated + value of the face attribute for each pixel. + """ + F, FV, D = face_attributes.shape + if FV != 3: + raise ValueError("Faces can only have three vertices; got %r" % FV) + N, H, W, K, _ = barycentric_coords.shape + if pix_to_face.shape != (N, H, W, K): + msg = "pix_to_face must have shape (batch_size, H, W, K); got %r" + raise ValueError(msg % pix_to_face.shape) + + # Replace empty pixels in pix_to_face with 0 in order to interpolate. + mask = pix_to_face == -1 + pix_to_face = pix_to_face.clone() + pix_to_face[mask] = 0 + idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) + pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D) + pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2) + pixel_vals[mask] = 0 # Replace masked values in output. + return pixel_vals + + +def _interpolate_zbuf( + pix_to_face: torch.Tensor, barycentric_coords: torch.Tensor, meshes +) -> torch.Tensor: + """ + A helper function to calculate the z buffer for each pixel in the + rasterized output. + + Args: + pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices + of the faces (in the packed representation) which + overlap each pixel in the image. + barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying + the barycentric coordianates of each pixel + relative to the faces (in the packed + representation) which overlap the pixel. + meshes: Meshes object representing a batch of meshes. + + Returns: + zbuffer: (N, H, W, K) FloatTensor + """ + verts = meshes.verts_packed() + faces = meshes.faces_packed() + faces_verts_z = verts[faces][..., 2][..., None] # (F, 3, 1) + return interpolate_face_attributes( + pix_to_face, barycentric_coords, faces_verts_z + )[ + ..., 0 + ] # (1, H, W, K) diff --git a/tests/data/test_blurry_textured_rendering.png b/tests/data/test_blurry_textured_rendering.png new file mode 100644 index 000000000..5ab0e4b63 Binary files /dev/null and b/tests/data/test_blurry_textured_rendering.png differ diff --git a/tests/data/test_simple_sphere_light_flat.png b/tests/data/test_simple_sphere_light_flat.png new file mode 100644 index 000000000..f8573d6df Binary files /dev/null and b/tests/data/test_simple_sphere_light_flat.png differ diff --git a/tests/data/test_simple_sphere_light_flat_elevated_camera.png b/tests/data/test_simple_sphere_light_flat_elevated_camera.png new file mode 100644 index 000000000..e2a19c837 Binary files /dev/null and b/tests/data/test_simple_sphere_light_flat_elevated_camera.png differ diff --git a/tests/test_mesh_rendering_utils.py b/tests/test_mesh_rendering_utils.py new file mode 100644 index 000000000..91c61b8e1 --- /dev/null +++ b/tests/test_mesh_rendering_utils.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + + +import unittest +import torch + +from pytorch3d.renderer.mesh.utils import _clip_barycentric_coordinates + + +class TestMeshRenderingUtils(unittest.TestCase): + def test_bary_clip(self): + N = 10 + bary = torch.randn(size=(N, 3)) + # randomly make some values negative + bary[bary < 0.3] *= -1.0 + # randomly make some values be greater than 1 + bary[bary > 0.8] *= 2.0 + negative_mask = bary < 0.0 + positive_mask = bary > 1.0 + clipped = _clip_barycentric_coordinates(bary) + self.assertTrue(clipped[negative_mask].sum() == 0) + self.assertTrue(clipped[positive_mask].gt(1.0).sum() == 0) + self.assertTrue(torch.allclose(clipped.sum(dim=-1), torch.ones(N))) diff --git a/tests/test_rendering_meshes.py b/tests/test_rendering_meshes.py index e6c782d2b..93e7ebf75 100644 --- a/tests/test_rendering_meshes.py +++ b/tests/test_rendering_meshes.py @@ -25,6 +25,7 @@ from pytorch3d.renderer.mesh.renderer import MeshRenderer from pytorch3d.renderer.mesh.shader import ( BlendParams, + HardFlatShader, HardGouraudShader, HardPhongShader, SoftSilhouetteShader, @@ -99,8 +100,9 @@ def test_simple_sphere(self, elevated_camera=False): images = renderer(sphere_mesh) rgb = images[0, ..., :3].squeeze().cpu() if DEBUG: + filename = "DEBUG_simple_sphere_light%s.png" % postfix Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( - DATA_DIR / "DEBUG_simple_sphere_light%s.png" % postfix + DATA_DIR / filename ) # Load reference image @@ -117,8 +119,9 @@ def test_simple_sphere(self, elevated_camera=False): images = renderer(sphere_mesh, lights=lights) rgb = images[0, ..., :3].squeeze().cpu() if DEBUG: + filename = "DEBUG_simple_sphere_dark%s.png" % postfix Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( - DATA_DIR / "DEBUG_simple_sphere_dark%s.png" % postfix + DATA_DIR / filename ) # Load reference image @@ -140,8 +143,9 @@ def test_simple_sphere(self, elevated_camera=False): images = renderer(sphere_mesh) rgb = images[0, ..., :3].squeeze().cpu() if DEBUG: + filename = "DEBUG_simple_sphere_light_gourad%s.png" % postfix Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( - DATA_DIR / "DEBUG_simple_sphere_light_gouraud%s.png" % postfix + DATA_DIR / filename ) # Load reference image @@ -149,7 +153,30 @@ def test_simple_sphere(self, elevated_camera=False): "test_simple_sphere_light_gouraud%s.png" % postfix ) self.assertTrue(torch.allclose(rgb, image_ref_gouraud, atol=0.005)) - self.assertFalse(torch.allclose(rgb, image_ref_phong, atol=0.005)) + + ###################################### + # Change the shader to a HardFlatShader + ###################################### + lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None] + renderer = MeshRenderer( + rasterizer=rasterizer, + shader=HardFlatShader( + lights=lights, cameras=cameras, materials=materials + ), + ) + images = renderer(sphere_mesh) + rgb = images[0, ..., :3].squeeze().cpu() + if DEBUG: + filename = "DEBUG_simple_sphere_light_flat%s.png" % postfix + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / filename + ) + + # Load reference image + image_ref_flat = load_rgb_image( + "test_simple_sphere_light_flat%s.png" % postfix + ) + self.assertTrue(torch.allclose(rgb, image_ref_flat, atol=0.005)) def test_simple_sphere_elevated_camera(self): """ @@ -287,9 +314,6 @@ def test_texture_map(self): materials = Materials(device=device) lights = PointLights(device=device) lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None] - raster_settings = RasterizationSettings( - image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0 - ) # Init renderer renderer = MeshRenderer( @@ -327,3 +351,32 @@ def test_texture_map(self): images = renderer(mesh2) images[0, ...].sum().backward() self.assertIsNotNone(verts.grad) + + ################################# + # Add blurring to rasterization + ################################# + + blend_params = BlendParams(sigma=5e-4, gamma=1e-4) + raster_settings = RasterizationSettings( + image_size=512, + blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma, + faces_per_pixel=100, + bin_size=0, + ) + + images = renderer( + mesh.clone(), + raster_settings=raster_settings, + blend_params=blend_params, + ) + rgb = images[0, ..., :3].squeeze().cpu() + + # Load reference image + image_ref = load_rgb_image("test_blurry_textured_rendering.png") + + if DEBUG: + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / "DEBUG_blurry_textured_rendering.png" + ) + + self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05)) diff --git a/tests/test_utils.py b/tests/test_rendering_utils.py similarity index 100% rename from tests/test_utils.py rename to tests/test_rendering_utils.py diff --git a/tests/test_texturing.py b/tests/test_texturing.py index cd3399db9..aea04553b 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -8,7 +8,6 @@ from pytorch3d.renderer.mesh.rasterizer import Fragments from pytorch3d.renderer.mesh.texturing import ( - _clip_barycentric_coordinates, interpolate_face_attributes, interpolate_texture_map, interpolate_vertex_colors, @@ -94,7 +93,9 @@ def test_interpolate_face_attributes_fail(self): dists=pix_to_face, ) with self.assertRaises(ValueError): - interpolate_face_attributes(fragments, face_attributes) + interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, face_attributes + ) # 2. pix_to_face must have shape (N, H, W, K) pix_to_face = torch.ones((1, 1, 1, 1, 3)) @@ -105,7 +106,9 @@ def test_interpolate_face_attributes_fail(self): dists=pix_to_face, ) with self.assertRaises(ValueError): - interpolate_face_attributes(fragments, face_attributes) + interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, face_attributes + ) def test_interpolate_texture_map(self): barycentric_coords = torch.tensor( @@ -220,13 +223,3 @@ def test_extend(self): ) with self.assertRaises(ValueError): tex_mesh.extend(N=-1) - - def test_clip_barycentric_coords(self): - barycentric_coords = torch.tensor( - [[1.5, -0.3, -0.2], [1.2, 0.3, -0.5]], dtype=torch.float32 - ) - expected_out = torch.tensor( - [[1.0, 0.0, 0.0], [1.0 / 1.3, 0.3 / 1.3, 0.0]], dtype=torch.float32 - ) - clipped = _clip_barycentric_coordinates(barycentric_coords) - self.assertTrue(torch.allclose(clipped, expected_out))