diff --git a/discussion/robust_inverse_graphics/nerf/rendering.py b/discussion/robust_inverse_graphics/nerf/rendering.py new file mode 100644 index 0000000000..7c48c9d7e5 --- /dev/null +++ b/discussion/robust_inverse_graphics/nerf/rendering.py @@ -0,0 +1,397 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Radiance field rendering.""" + +from collections.abc import Sequence +import functools +import operator +from typing import Any, Callable, NamedTuple, Optional + +import chex +from flax import struct +import jax +import jax.numpy as jnp +import numpy as np +from camp_zipnerf.internal import coord as mipnerf_coord +from camp_zipnerf.internal import math as mipnerf_math +from camp_zipnerf.internal import render as mipnerf_render +from camp_zipnerf.internal import stepfun as mipnerf_stepfun + +__all__ = [ + 'Ray', + 'RaySample', + 'ray_trace_mip_radiance_field', + 'RayTraceMipExtra', +] + +# World-space coordinates. Shape: [3] +WorldSpace = jnp.ndarray +# Density at point (positive real). Shape: [] +Density = jnp.ndarray +# Color at a point (elements in [0, 1]). Shape: [3] +RGB = jnp.ndarray + + +class RaySample(NamedTuple): + """A sample from a ray. + + Attributes: + position: Sample position. Shape: [3] + viewdir: View directions (from camera). Shape: [3] + covariance: Covariance of a Gaussian fit to the truncated cone associated + with this ray sample. Shape: [3, 3]. + """ + position: WorldSpace + viewdir: jnp.ndarray + covariance: Optional[jnp.ndarray] = None + + @classmethod + def test_sample(cls) -> 'RaySample': + return cls( + position=jnp.zeros(3), viewdir=jnp.ones(3), covariance=jnp.eye(3) + ) + + +class Ray(NamedTuple): + """A ray. + + Attributes: + origin: Origin of the ray. + direction: Un-normalized direction of the ray. The magnitude encodes how + fast to move along the ray. + viewdir: The normalized direction of the ray. + radius: Radius of the truncated cone associated with the ray at the origin. + """ + origin: WorldSpace + direction: jnp.ndarray + viewdir: jnp.ndarray + radius: Optional[jnp.ndarray] = None + + +@struct.dataclass +class RayTraceMipLevelExtra: + """Extra-outputs from ray_trace_mip_radiance_field for each sampling level. + + `num_rfs` is the number of radiance fields, and could be an empty dimension. + + Attributes: + z_vals: Values between the `near` and `far` which span the truncated cones + that are then fit with gaussians. Shape: [num_samples + 1] + means: Means of the gaussians where `rf_fn` was evaluated. Shape: + [num_samples, 3] + covariances: Covariances of the gaussians where `rf_fn` was evaluated. + Shape: [num_samples, 3, 3] + density: `rf_fn` density at each sample point: Shape: [num_samples, num_rfs] + composed_density: Composed density. Shape: [num_samples] + rgb: RGB outputs for this level. Shape: [3] + acc: Accumulated alpha, pre-composed. Shape: [num_rfs] + composed_acc: Accumulated alpha. Shape: [] + mean_distance: Mean distance from the camera to the scattering point along + the ray. Shape: [] + p95_distance: 95th percentile distance from the camera to the scattering + point along the ray. Shape: [] + extra: Extra outputs from `rf_fn`, for debugging and other uses. + """ + z_vals: jnp.ndarray + means: WorldSpace + covariances: jnp.ndarray + + density: Density + composed_density: Density + + rgb: RGB + acc: jnp.ndarray + composed_acc: jnp.ndarray + mean_distance: jnp.ndarray + p95_distance: jnp.ndarray + + extra: Any + + +@struct.dataclass +class RayTraceMipExtra(RayTraceMipLevelExtra): + """Extra-outputs from ray_trace_mip_radiance_field. + + `num_rfs` is the number of radiance fields, and could be an empty dimension. + + Attributes: + levels: Extra-outputs for each sampling level. The remaining fields + correspond to the last sampling level. + """ + + levels: Sequence[RayTraceMipLevelExtra] + + +def ray_trace_mip_radiance_field( + rf_fn: Callable[[RaySample], tuple[tuple[Density, RGB], Any]], + ray: Ray, + num_samples: Sequence[int], + near: float, + far: float, + jitter: str = 'correlated', + ray_shape: str = 'cone', + background_color: jnp.ndarray = np.ones(3, np.float32), + weight_bias: float = 0.01, + epsilon: float = 1e-10, + ray_warp_fn: str | Callable[[jax.Array], jax.Array] | None = None, + seed: Optional[jax.Array] = None, +) -> tuple[RGB, RayTraceMipExtra]: + """Given a single ray, and a radiance field, computes the color. + + `rf_fn` is a mapping `(ray_sample) -> ((density, rgb), extra)`. + + `rf_fn` is allowed to return values with a leading dimension, which is taken + to indicate multiple RFs. Their colors and densities will be composed by + summing the densities and blending the colors appropriately. + + The ray samples are generated by sampling a `z_val` from a distribution over + `[near, far]` and computing a truncated cone boundaries as: + `position = z_val * ray_direction + ray_origin`. + + This function uses multi-level sampling, where the `rf_fn` is first evaluated + on a set of coarse positions and the densities computed at those locations are + used to construct an improved sampling distribution. The number of levels is + determined by the length of the `num_samples` argument. + + This function is intended to be used with Integrated Positional Encoding from + [1]. It computes both positions(means) and covariances of gaussians that + approximate the truncated cones along the ray. + + Args: + rf_fn: Radiance field function. + ray: The ray. + num_samples: A sequence of numbers of samples for each level. + near: Minimum z-val. + far: Maximum z-val. + jitter: Type of jitter to use. Can be `stratified`, `correlated` or `none`. + This corresponds to generating independent samples within each bucket, + generating the same sample for each bucket and no jitter at all. + ray_shape: Can be either `cone` or `cylinder`. + background_color: Background color to use. + weight_bias: Bias to add to the sampling weights, so as to sample from the + empty regions implied by `rf_fn` to make sure they stay empty. + epsilon: Used to stabilize composition of multiple radiance fields. + ray_warp_fn: Ray warp function. See Equation 11 in + https://arxiv.org/abs/2111.12077. + seed: Optional random seed used for jitter. Can be omitted if using the + `none` jitter. + + Returns: + rgb: Composed color. + extra: An array of `RayTraceMipExtra`, one per level. + + #### References + + 1. Barron, J. T., Mildenhall, B., Tancik, M., Hedman, P., Martin-Brualla, R., + & Srinivasan, P. P. (2021). Mip-NeRF: A Multiscale Representation for + Anti-Aliasing Neural Radiance Fields. In arXiv [cs.CV]. arXiv. + http://arxiv.org/abs/2103.13415 + """ + # We warp the [near, far] range into [0, 1], with the linear re-scaling by + # default. + _, s_to_z = mipnerf_coord.construct_ray_warps(ray_warp_fn, near, far) + s_vals = jnp.array([0.0, 1.0]) + weights = jnp.ones(1) + + extras = [] + + for level_num_samples in num_samples: + # Fix logits for samples which have equal values. + logits_resample = jnp.where( + s_vals[..., 1:] > s_vals[..., :-1], + mipnerf_math.safe_log(weights + weight_bias), + -jnp.inf, + ) + + match jitter: + case 'none': + sample_seed = None + single_jitter = False + case 'correlated': + sample_seed, seed = jax.random.split(seed) + single_jitter = True + case 'stratified': + sample_seed, seed = jax.random.split(seed) + single_jitter = False + case _: + raise ValueError(f'Unknown jitter type: {jitter}') + + # Sample values along the ray. + s_vals = mipnerf_stepfun.sample_intervals( + sample_seed, + s_vals, + w_logits=logits_resample, + num_samples=level_num_samples, + single_jitter=single_jitter, + domain=(0.0, 1.0), + ) + + # For training stability. + # TODO(siege): Make this configurable? + s_vals = jax.lax.stop_gradient(s_vals) + + z_vals = s_to_z(s_vals) + + # Compute the means/covariances of the approximating gaussians. + means, covs = mipnerf_render.cast_rays( + tdist=z_vals, + origins=ray.origin, + directions=ray.direction, + radii=ray.radius, + ray_shape=ray_shape, + diag=False, + ) + + # Evaluate `rf_fn` at the gaussian. + (density, rgb), extra = jax.vmap( + lambda pos, cov: rf_fn( # pylint: disable=g-long-lambda + RaySample(position=pos, covariance=cov, viewdir=ray.viewdir) + ) + )(means, covs) + + # Compose across RFs, if necessary. + if len(density.shape) == 1: + rf_weights = density + composed_density = density + composed_rgb = rgb + else: + chex.assert_rank(density, 2) + + rf_weights = density / (density.sum(1, keepdims=True) + epsilon) + composed_density = density.sum(1) + composed_rgb = (rgb * rf_weights[..., jnp.newaxis]).sum(1) + + # Perform volumetric rendering. The weights computed here are used to inform + # future levels. + weights = mipnerf_render.compute_alpha_weights( + density=composed_density, + tdist=z_vals, + dirs=ray.direction, + ) + + rendering = mipnerf_render.volumetric_rendering( + rgbs=composed_rgb, + weights=weights, + tdist=z_vals, + bg_rgbs=jnp.asarray(background_color), + compute_extras=True, + ) + + composed_acc = rendering['acc'] + if len(density.shape) == 1: + acc = composed_acc + else: + acc = (weights[..., jnp.newaxis] * rf_weights).sum(0) + + rt_extra = RayTraceMipLevelExtra( + z_vals=z_vals, + means=means, + covariances=covs, + composed_density=composed_density, + density=density, + rgb=rendering['rgb'], + acc=acc, + composed_acc=composed_acc, + mean_distance=rendering['distance_mean'], + p95_distance=rendering['distance_percentile_95'], + extra=extra, + ) + extras.append(rt_extra) + + return extras[-1].rgb, RayTraceMipExtra( + levels=extras, + z_vals=extras[-1].z_vals, + means=extras[-1].means, + covariances=extras[-1].covariances, + composed_density=extras[-1].composed_density, + density=extras[-1].density, + rgb=extras[-1].rgb, + acc=extras[-1].acc, + composed_acc=extras[-1].composed_acc, + mean_distance=extras[-1].mean_distance, + p95_distance=extras[-1].p95_distance, + extra=extras[-1].extra, + ) + + +@functools.partial( + jax.jit, static_argnames=('rf_fn', 'num_samples', 'ray_warp_fn') +) +def render_rf( + rf_fn: Callable[ + [RaySample, Any], + tuple[tuple[Density, RGB], Any], + ], + rays: Ray, + near: float, + far: float, + num_samples: Sequence[int], + seed: jax.Array, + ray_warp_fn: str | Callable[[jax.Array], jax.Array] | None = None, + cond_kwargs: Any | None = None, + **kwargs: Any, +) -> tuple[RGB, Any]: + """Render a radiance field using Mip rendering. + + This is a thin wrapper around `ray_trace_mip_radiance_field` to handle + arbitrary batch shapes and do proper seed handling. + + Args: + rf_fn: Radiance field. + rays: Rays. + near: Near plane. + far: Far plane. + num_samples: See `ray_trace_mip_radiance_field`. + seed: Random seed. + ray_warp_fn: Ray warp function. See Equation 11 in + https://arxiv.org/abs/2111.12077. + cond_kwargs: If given, conditional data from which to derive rf_fn. + **kwargs: Passed to `ray_trace_mip_radiance_field`. + + Returns: + Same as `ray_trace_mip_radiance_field`. + """ + if cond_kwargs is None: + cond_kwargs = {} + batch_shape = rays.origin.shape[:-1] + batch_ndims = len(batch_shape) + batch_size = functools.reduce(operator.mul, batch_shape) + + tree_flatten = functools.partial( + jax.tree.map, lambda x: x.reshape((-1,) + x.shape[batch_ndims:]) + ) + + @jax.vmap + def _render(ray, cond_kwargs, seed): + return ray_trace_mip_radiance_field( + rf_fn=functools.partial(rf_fn, **cond_kwargs), + ray=ray, + num_samples=num_samples, + near=near, + far=far, + seed=seed, + ray_warp_fn=ray_warp_fn, + **kwargs, + ) + + rgb, extra = _render( + tree_flatten(rays), + tree_flatten(cond_kwargs), + jax.random.split(seed, batch_size), + ) + + return jax.tree.map( + lambda x: x.reshape(batch_shape + x.shape[1:]), (rgb, extra) + ) diff --git a/discussion/robust_inverse_graphics/nerf/rendering_test.py b/discussion/robust_inverse_graphics/nerf/rendering_test.py new file mode 100644 index 0000000000..fd2232f06e --- /dev/null +++ b/discussion/robust_inverse_graphics/nerf/rendering_test.py @@ -0,0 +1,183 @@ +# Copyright 2024 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for rendering.""" + +import jax +import jax.numpy as jnp +import numpy as np +from discussion.robust_inverse_graphics.nerf import rendering +from discussion.robust_inverse_graphics.util import test_util + + +class RenderingTest(test_util.TestCase): + + def testRenderMip(self): + num_samples = [6, 4] + + def rf_fn(ray_sample): + self.assertEqual(ray_sample.covariance.shape, [3, 3]) + return (0.0, np.zeros(3)), () + + ray = rendering.Ray( + origin=np.zeros(3), + direction=np.ones(3), + viewdir=np.ones(3) / np.sqrt(3), + radius=np.array([0.1]), + ) + + @jax.jit + def run(seed): + return rendering.ray_trace_mip_radiance_field( + rf_fn=rf_fn, + ray=ray, + num_samples=num_samples, + near=1.0, + far=2.0, + seed=seed, + ) + + rgb, extra = run(jax.random.PRNGKey(0)) + + self.assertEqual([3], rgb.shape) + + for level, (ns, extra_l) in enumerate(zip(num_samples, extra.levels)): + self.assertEqual([ns + 1], extra_l.z_vals.shape, msg=f'{level=}') + self.assertEqual([ns, 3], extra_l.means.shape, msg=f'{level=}') + self.assertEqual([ns, 3, 3], extra_l.covariances.shape, msg=f'{level=}') + self.assertEqual([ns], extra_l.density.shape, msg=f'{level=}') + self.assertEqual([ns], extra_l.composed_density.shape, msg=f'{level=}') + + self.assertEqual([], extra_l.acc.shape) + self.assertEqual([], extra_l.mean_distance.shape) + self.assertEqual([], extra_l.p95_distance.shape) + self.assertEqual([], extra_l.composed_acc.shape) + + ns = num_samples[-1] + self.assertEqual([ns + 1], extra.z_vals.shape, msg='last level') + self.assertEqual([ns, 3], extra.means.shape, msg='last level') + self.assertEqual([ns, 3, 3], extra.covariances.shape, msg='last level') + self.assertEqual([ns], extra.density.shape, msg='last level') + self.assertEqual([ns], extra.composed_density.shape, msg='last level') + + self.assertEqual([], extra.acc.shape) + self.assertEqual([], extra.mean_distance.shape) + self.assertEqual([], extra.p95_distance.shape) + self.assertEqual([], extra.composed_acc.shape) + + def testRenderMipComposed(self): + num_nerfs = 3 + num_samples = [6, 4] + + def rf_fn(ray_sample): + self.assertEqual(ray_sample.covariance.shape, [3, 3]) + return (np.zeros(num_nerfs), np.zeros((num_nerfs, 3))), () + + ray = rendering.Ray( + origin=np.zeros(3), + direction=np.ones(3), + viewdir=np.ones(3) / np.sqrt(3), + radius=np.array([0.1]), + ) + + @jax.jit + def run(seed): + return rendering.ray_trace_mip_radiance_field( + rf_fn=rf_fn, + ray=ray, + num_samples=num_samples, + near=1.0, + far=2.0, + seed=seed, + ) + + rgb, extra = run(jax.random.PRNGKey(0)) + + self.assertEqual([3], rgb.shape) + + for level, (ns, extra_l) in enumerate(zip(num_samples, extra.levels)): + self.assertEqual([ns + 1], extra_l.z_vals.shape, msg=f'{level=}') + self.assertEqual([ns, 3], extra_l.means.shape, msg=f'{level=}') + self.assertEqual([ns, 3, 3], extra_l.covariances.shape, msg=f'{level=}') + self.assertEqual([ns, num_nerfs], extra_l.density.shape, msg=f'{level=}') + self.assertEqual([ns], extra_l.composed_density.shape, msg=f'{level=}') + + self.assertEqual([num_nerfs], extra_l.acc.shape) + self.assertEqual([], extra_l.mean_distance.shape) + self.assertEqual([], extra_l.p95_distance.shape) + self.assertEqual([], extra_l.composed_acc.shape) + + ns = num_samples[-1] + self.assertEqual([ns + 1], extra.z_vals.shape, msg='last level') + self.assertEqual([ns, 3], extra.means.shape, msg='last level') + self.assertEqual([ns, 3, 3], extra.covariances.shape, msg='last level') + self.assertEqual([ns, num_nerfs], extra.density.shape, msg='last level') + self.assertEqual([ns], extra.composed_density.shape, msg='last level') + + self.assertEqual([num_nerfs], extra.acc.shape) + self.assertEqual([], extra.mean_distance.shape) + self.assertEqual([], extra.p95_distance.shape) + self.assertEqual([], extra.composed_acc.shape) + + def test_render_fn(self): + rays = rendering.Ray( + origin=jnp.zeros([32, 32, 3]), + radius=jnp.ones((32, 32)), + direction=jnp.ones([32, 32, 3]), + viewdir=jnp.ones([32, 32, 3]), + ) + rf_fn = lambda ray_sample: ((jnp.zeros([]), jnp.zeros(3)), ()) + + rgb, extra = rendering.render_rf( + rf_fn, + rays=rays, + near=1.0, + far=2.0, + num_samples=(3, 5), + seed=jax.random.PRNGKey(0), + ) + + self.assertEqual(rgb.shape, [32, 32, 3]) + self.assertEqual(extra.composed_acc.shape, [32, 32]) + + def test_conditional_render_fn(self): + rays = rendering.Ray( + origin=jnp.zeros([32, 32, 3]), + radius=jnp.ones((32, 32)), + direction=jnp.ones([32, 32, 3]), + viewdir=jnp.ones([32, 32, 3]), + ) + + def cond_rf_fn(ray_sample, cond): + del ray_sample + return (jnp.zeros([]), jnp.zeros(3)), ({'cond_sum': cond.sum()},) + + cond_dim = 7 + rgb, extra = rendering.render_rf( + cond_rf_fn, + rays=rays, + near=1.0, + far=2.0, + num_samples=(3, 5), + seed=jax.random.PRNGKey(0), + cond_kwargs={'cond': jnp.ones([32, 32, cond_dim])}, + ) + + self.assertEqual(rgb.shape, [32, 32, 3]) + self.assertEqual(extra.composed_acc.shape, [32, 32]) + self.assertAllEqual(extra.extra[0]['cond_sum'], cond_dim) + + +if __name__ == '__main__': + test_util.main() diff --git a/discussion/robust_inverse_graphics/util/BUILD b/discussion/robust_inverse_graphics/util/BUILD deleted file mode 100644 index a6ece023f1..0000000000 --- a/discussion/robust_inverse_graphics/util/BUILD +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2024 The TensorFlow Probability Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -# Utilities. - -# [internal] load pytype.bzl (pytype_strict_library, pytype_strict_test) -# [internal] load strict.bzl -# Placeholder: py_library - -package( - # default_applicable_licenses - default_visibility = ["//discussion/robust_inverse_graphics:__subpackages__"], -) - -licenses(["notice"]) - -# pytype_strict -py_library( - name = "array_util", - srcs = ["array_util.py"], - deps = [ - # jax dep, - ], -) - -# pytype_strict -py_library( - name = "camera_util", - srcs = ["camera_util.py"], - deps = [ - # numpy dep, - # pyquaternion dep, - ], -) - -# pytype_strict -py_library( - name = "gin_utils", - srcs = ["gin_utils.py"], - deps = [ - # absl/flags dep, - # gin dep, - # yaml dep, - ], -) - -# pytype_strict -py_library( - name = "math_util", - srcs = ["math_util.py"], - deps = [ - # jax dep, - ], -) - -# pytype_strict -py_test( - name = "math_util_test", - srcs = ["math_util_test.py"], - deps = [ - ":math_util", - ":test_util", - # google/protobuf:use_fast_cpp_protos dep, - # jax dep, - ], -) - -# pytype_strict -py_library( - name = "plot_util", - srcs = ["plot_util.py"], - deps = [ - # jax dep, - # matplotlib dep, - # numpy dep, - "//fun_mc:using_jax", - ], -) - -# Not strict or pytype due to the test_util.jax dep. -py_library( - name = "test_util", - testonly = 1, - srcs = ["test_util.py"], - deps = [ - # absl/testing:absltest dep, - # jax dep, - "//tensorflow_probability/python/internal:test_util.jax", - ], -) - -# py_strict -py_test( - name = "test_util_test", - srcs = ["test_util_test.py"], - deps = [ - ":test_util", - # flax dep, - # google/protobuf:use_fast_cpp_protos dep, - # numpy dep, - ], -) - -# pytype_strict -py_library( - name = "tree2", - srcs = ["tree2.py"], - deps = [ - # etils/epath dep, - # immutabledict dep, - # numpy dep, - ], -) - -# py_strict -py_test( - name = "tree2_test", - srcs = ["tree2_test.py"], - deps = [ - ":test_util", - ":tree2", - # absl/testing:parameterized dep, - # flax:core dep, - # google/protobuf:use_fast_cpp_protos dep, - # immutabledict dep, - # jax dep, - # numpy dep, - "//tensorflow_probability:jax", - ], -) - -# pytype_strict -py_library( - name = "tree_util", - srcs = ["tree_util.py"], - deps = [ - # jax dep, - ], -) - -# py_strict -py_test( - name = "tree_util_test", - srcs = ["tree_util_test.py"], - deps = [ - ":test_util", - ":tree_util", - # flax:core dep, - # google/protobuf:use_fast_cpp_protos dep, - # jax dep, - ], -)