Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New datasets framework #111

Merged
merged 1 commit into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ For an overview on neural fields, we recommend you check out the EG STAR report:

## Latest Updates

* _01/02/23_ `attrdict` dependency added as part of the new datasets framework. If you pull latest, make sure to `pip install attrdict`.
* _17/01/23_ `pycuda` replaced with `cuda-python`. Wisp can be installed from pip now (If you pull, run **pip install -r requirements_app.txt**)
* _05/01/23_ Mains are now introduced as standalone apps, for easier support of new pipelines (**breaking change**)
* _21/12/22_ Most modules have been cleaned, reorganized and documented.
Expand Down
53 changes: 29 additions & 24 deletions app/nerf/main_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import logging
import numpy as np
import torch
import wisp
from wisp.app_utils import default_log_setup, args_to_log_format
import wisp.config_parser as config_parser
from wisp.framework import WispState
from wisp.datasets import MultiviewDataset
from wisp.datasets.transforms import SampleRays
from wisp.datasets import MultiviewDataset, SampleRays
from wisp.models.grids import BLASGrid, OctreeGrid, CodebookOctreeGrid, TriplanarGrid, HashGrid
from wisp.tracers import BaseTracer, PackedRFTracer
from wisp.models.nefs import BaseNeuralField, NeuralRadianceField
Expand Down Expand Up @@ -240,7 +240,7 @@ def parse_args():
return args, args_dict


def load_dataset(args) -> torch.utils.data.Dataset:
def load_dataset(args) -> MultiviewDataset:
""" Loads a multiview dataset comprising of pairs of images and calibrated cameras.
The types of supported datasets are defined by multiview_dataset_format:
'standard' - refers to the standard NeRF format popularized by Mildenhall et al. 2020,
Expand All @@ -250,16 +250,19 @@ def load_dataset(args) -> torch.utils.data.Dataset:
This dataset includes depth information which allows for performance improving optimizations in some cases.
"""
transform = SampleRays(num_samples=args.num_rays_sampled_per_img)
train_dataset = MultiviewDataset(dataset_path=args.dataset_path,
multiview_dataset_format=args.multiview_dataset_format,
mip=args.mip,
bg_color=args.bg_color,
dataset_num_workers=args.dataset_num_workers,
transform=transform)
return train_dataset


def load_grid(args, dataset: torch.utils.data.Dataset) -> BLASGrid:
train_dataset = wisp.datasets.load_multiview_dataset(dataset_path=args.dataset_path,
split='train',
mip=args.mip,
bg_color=args.bg_color,
dataset_num_workers=args.dataset_num_workers,
transform=transform)
validation_dataset = None
if args.valid_every > -1 or args.valid_only:
validation_dataset = train_dataset.create_split(split='val', transform=None)
return train_dataset, validation_dataset


def load_grid(args, dataset: MultiviewDataset) -> BLASGrid:
""" Wisp's implementation of NeRF uses feature grids to improve the performance and quality (allowing therefore,
interactivity).
This function loads the feature grid to use within the neural pipeline.
Expand All @@ -269,13 +272,12 @@ def load_grid(args, dataset: torch.utils.data.Dataset) -> BLASGrid:
See corresponding grid constructors for each of their arg details.
"""
grid = None
# Optimization: For octrees based grids, if dataset contains depth info, initialize only cells known to be occupied
has_depth_supervision = getattr(dataset, "coords", None) is not None

# Optimization: For octrees based grids, if dataset contains depth info, initialize only cells known to be occupied
if args.grid_type == "OctreeGrid":
if has_depth_supervision:
if dataset.supports_depth():
grid = OctreeGrid.from_pointcloud(
pointcloud=dataset.coords,
pointcloud=dataset.as_pointcloud(),
feature_dim=args.feature_dim,
base_lod=args.base_lod,
num_lods=args.num_lods,
Expand All @@ -295,9 +297,9 @@ def load_grid(args, dataset: torch.utils.data.Dataset) -> BLASGrid:
feature_bias=args.feature_bias,
)
elif args.grid_type == "CodebookOctreeGrid":
if has_depth_supervision:
if dataset.supports_depth:
grid = CodebookOctreeGrid.from_pointcloud(
pointcloud=dataset.coords,
pointcloud=dataset.as_pointcloud(),
feature_dim=args.feature_dim,
base_lod=args.base_lod,
num_lods=args.num_lods,
Expand Down Expand Up @@ -359,7 +361,7 @@ def load_grid(args, dataset: torch.utils.data.Dataset) -> BLASGrid:
return grid


def load_neural_field(args, dataset: torch.utils.data.Dataset) -> BaseNeuralField:
def load_neural_field(args, dataset: MultiviewDataset) -> BaseNeuralField:
""" Creates a "Neural Field" instance which converts input coordinates to some output signal.
Here a NeuralRadianceField is created, which maps 3D coordinates (+ 2D view direction) -> RGB + density.
The NeuralRadianceField uses spatial feature grids internally for faster feature interpolation and raymarching.
Expand Down Expand Up @@ -416,7 +418,7 @@ def load_neural_pipeline(args, dataset, device) -> Pipeline:
return pipeline


def load_trainer(pipeline, train_dataset, device, scene_state, args, args_dict) -> BaseTrainer:
def load_trainer(pipeline, train_dataset, validation_dataset, device, scene_state, args, args_dict) -> BaseTrainer:
""" Loads the NeRF trainer.
The trainer is responsible for managing the optimization life-cycles and can be operated in 2 modes:
- Headless, which will run the train() function until all training steps are exhausted.
Expand All @@ -431,7 +433,8 @@ def load_trainer(pipeline, train_dataset, device, scene_state, args, args_dict)
optimizer_params = config_parser.get_args_for_function(args, optimizer_cls)

trainer = MultiviewTrainer(pipeline=pipeline,
dataset=train_dataset,
train_dataset=train_dataset,
validation_dataset=validation_dataset,
num_epochs=args.epochs,
batch_size=args.batch_size,
optim_cls=optimizer_cls,
Expand Down Expand Up @@ -478,10 +481,12 @@ def is_interactive() -> bool:
default_log_setup(args.log_level)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_dataset = load_dataset(args=args)
train_dataset, validation_dataset = load_dataset(args=args)
pipeline = load_neural_pipeline(args=args, dataset=train_dataset, device=device)
scene_state = WispState() # Joint trainer / app state
trainer = load_trainer(pipeline=pipeline, train_dataset=train_dataset, device=device, scene_state=scene_state,
trainer = load_trainer(pipeline=pipeline,
train_dataset=train_dataset, validation_dataset=validation_dataset,
device=device, scene_state=scene_state,
args=args, args_dict=args_dict)
app = load_app(args=args, scene_state=scene_state, trainer=trainer)

Expand Down
30 changes: 19 additions & 11 deletions app/nglod/main_nglod.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from wisp.app_utils import default_log_setup, args_to_log_format
import wisp.config_parser as config_parser
from wisp.framework import WispState
from wisp.datasets import SDFDataset
from wisp.datasets import SDFDataset, MeshSampledSDFDataset, OctreeSampledSDFDataset
from wisp.accelstructs import OctreeAS
from wisp.models.grids import BLASGrid, OctreeGrid, CodebookOctreeGrid, TriplanarGrid, HashGrid
from wisp.tracers import BaseTracer, PackedSDFTracer
from wisp.models.nefs import BaseNeuralField, NeuralSDF
Expand Down Expand Up @@ -237,24 +238,31 @@ def parse_args():
return args, args_dict


def load_dataset(args, pipeline: Pipeline) -> torch.utils.data.Dataset:
def load_dataset(args, pipeline: Pipeline) -> SDFDataset:
""" Loads a dataset of SDF samples generated over the surface of a mesh. """
if isinstance(pipeline.nef.grid, OctreeGrid):
train_dataset = SDFDataset.from_grid(
if OctreeSampledSDFDataset.supports_blas(pipeline.nef.grid.blas):
# The current grid representation uses a bottom-level acceleration structure which can be used for faster
# and more precise resampling of sdf values
train_dataset = OctreeSampledSDFDataset(
occupancy_struct=pipeline.nef.grid.blas,
split='train',
transform=None,
sample_mode=args.sample_mode,
num_samples=args.num_samples,
get_normals=args.get_normals,
sample_tex=args.sample_tex,
grid=pipeline.nef.grid,
samples_per_voxel=args.samples_per_voxel)
samples_per_voxel=args.samples_per_voxel
)
else:
train_dataset = SDFDataset.from_mesh(
train_dataset = MeshSampledSDFDataset(
dataset_path=args.dataset_path,
split='train',
transform=None,
sample_mode=args.sample_mode,
num_samples=args.num_samples,
get_normals=args.get_normals,
sample_tex=args.sample_tex,
dataset_path=args.dataset_path,
mode_norm=args.mode_mesh_norm)
mode_norm=args.mode_mesh_norm
)
return train_dataset


Expand Down Expand Up @@ -380,7 +388,7 @@ def load_trainer(pipeline, train_dataset, device, scene_state, args, args_dict)
optimizer_cls = config_parser.get_module(name=args.optimizer_type)
optimizer_params = config_parser.get_args_for_function(args, optimizer_cls)
trainer = SDFTrainer(pipeline=pipeline,
dataset=train_dataset,
train_dataset=train_dataset,
num_epochs=args.epochs,
batch_size=args.batch_size,
optim_cls=optimizer_cls,
Expand Down
11 changes: 5 additions & 6 deletions examples/latent_nerf/main_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
import torch
from wisp.app_utils import default_log_setup, args_to_log_format
from wisp.framework import WispState
from wisp.datasets import MultiviewDataset
from wisp.datasets.transforms import SampleRays
from wisp.datasets import NeRFSyntheticDataset, SampleRays
from wisp.trainers import MultiviewTrainer
from wisp.models.grids import OctreeGrid
from wisp.models.pipeline import Pipeline
Expand All @@ -37,12 +36,12 @@
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NeRF is trained with a MultiviewDataset, which knows how to generate RGB rays from a set of images + cameras
train_dataset = MultiviewDataset(
train_dataset = NeRFSyntheticDataset(
dataset_path=args.dataset_path,
multiview_dataset_format='standard',
split='train',
mip=0,
bg_color='black',
dataset_num_workers=-1,
dataset_num_workers=0,
transform=SampleRays(
num_samples=4096
)
Expand Down Expand Up @@ -74,7 +73,7 @@
weight_decay = 0
exp_name = 'siggraph_2022_demo'
trainer = MultiviewTrainer(pipeline=pipeline,
dataset=train_dataset,
train_dataset=train_dataset,
num_epochs=args.epochs,
batch_size=1, # 1 image per batch
optim_cls=torch.optim.RMSprop,
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ PyDispatcher
pynvml
setuptools==59.5.0
wandb>=0.13.5
pytest>=7.1.0
pytest>=7.1.0
attrdict
2 changes: 1 addition & 1 deletion tests/apps/test_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_triplanar_lego(self, lego_path, dataset_num_workers):
out = run_wisp_script(cmd, cli_args)
metrics = collect_metrics_from_log(out, ['PSNR'])

assert float(metrics[100]['PSNR']) > 30.5, 'PSNR is too low.'
assert float(metrics[100]['PSNR']) > 30.4, 'PSNR is too low.'
report_metrics(metrics) # Prints to log

def test_codebook_V8(self, V8_path, dataset_num_workers):
Expand Down
6 changes: 4 additions & 2 deletions wisp/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,14 @@ def get_module(name, module_type=None):
def get_args_for_function(args, func):
""" Given a func (for example an __init__(..) function or from_X(..)), and also the parsed args,
return the subset of args that func expects and args contains. """
if isinstance(args, argparse.Namespace):
args = vars(args) # Namespace -> dict
has_kwargs = inspect.getfullargspec(func).varkw != None
if has_kwargs:
collected_args = vars(args)
collected_args = args
else:
parameters = dict(inspect.signature(func).parameters)
collected_args = {a: getattr(args, a) for a in parameters if hasattr(args, a)}
collected_args = {a: args[a] for a in parameters if a in args}
return collected_args


Expand Down
8 changes: 4 additions & 4 deletions wisp/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

from .sdf_dataset import SDFDataset
from .multiview_dataset import MultiviewDataset
from .random_view_dataset import RandomViewDataset
from .utils import default_collate
from .base_datasets import *
from .utils import *
from .formats import *
from .transforms import *
Loading