Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added substation segementation dataset #2352

Open
wants to merge 73 commits into
base: main
Choose a base branch
from
Open

Conversation

rijuld
Copy link

@rijuld rijuld commented Oct 17, 2024

No description provided.

@github-actions github-actions bot added documentation Improvements or additions to documentation datasets Geospatial or benchmark datasets testing Continuous integration testing labels Oct 17, 2024
@adamjstewart adamjstewart added this to the 0.7.0 milestone Oct 17, 2024
@adamjstewart
Copy link
Collaborator

Hi @rijuld, thanks for the contribution! If you're new to creating PyTorch datasets, I highly recommend reading the following tutorials:

The only difference between datasets in torchvision and NonGeoDatasets in TorchGeo is that our __getitem__ returns a dictionary instead of a tuple. Other than that, they share all the same basic components.

Most of your issues seem to be due to the use of args. I think you just need to remove this and explicitly list all parameters in the function signature. This will also simplify your testing code. Take a look at other existing datasets, we have about 75 examples to choose from. If you find one that is similar to your dataset, it shouldn't actually require that many changes to get them working.

@rijuld
Copy link
Author

rijuld commented Oct 22, 2024

Hi @adamjstewart , thanks a ton for the feedback! I will go through this tutorial.

@rijuld
Copy link
Author

rijuld commented Jan 9, 2025

@adamjstewart I'm unsure about how to test the configuration file locally. Additionally, would testing the conf file cover the datamodule as well, or would I need to write separate test cases for the datamodule to ensure adequate code coverage?

@adamjstewart
Copy link
Collaborator

I'm unsure about how to test the configuration file locally.

You can get the test name from CI and then run it like so:

> pytest tests/trainers/test_segmentation.py::TestSemanticSegmentationTask::test_trainer[True-substation]

Additionally, would testing the conf file cover the datamodule as well

Yes

@rijuld rijuld requested a review from adamjstewart January 22, 2025 11:57
@@ -53,6 +53,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`SSL4EO`_-S12,T,Sentinel-1/2,"CC-BY-4.0",1M,-,264x264,10,"SAR, MSI"
`SSL4EO-L Benchmark`_,S,Lansat & CDL,"CC0-1.0",25K,134,264x264,30,MSI
`SSL4EO-L Benchmark`_,S,Lansat & NLCD,"CC0-1.0",25K,17,264x264,30,MSI
`Substation`_,S,OpenStreetMap & Sentinel-2, "CC-BY-SA 2.0", 27K, 2, 228x228, 10, MSI
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where did you find the license? I don't see anything on GitHub or in the paper.

"""Fixture for the Substation."""
root = os.path.join(os.getcwd(), 'tests', 'data', 'substation')

yield Substation(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't normally use yield when creating our fixtures, any reason you're using a generator?

num_of_timepoints=4,
)

@pytest.mark.parametrize(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't you parametrize the fixture instead of the unit test? Then all other unit tests will also be parametrized.

from collections.abc import Generator
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't use unittest, we use pytest. Can you convert these to MonkeyPatch which is builtin to pytest?


import torch
from torch.utils.data import Subset, random_split
from tqdm import tqdm
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tqdm isn't a dependency and we don't really want to add new deps unless we have to

def __init__(
self,
root: Path,
bands: list[int],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are tuples also allowed? In this case you would use Sequence instead of list

Comment on lines +82 to +86
self.load_image_filenames()

def load_image_filenames(self) -> None:
"""Load image filenames from the image directory."""
self.image_filenames = os.listdir(self.image_dir)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.load_image_filenames()
def load_image_filenames(self) -> None:
"""Load image filenames from the image directory."""
self.image_filenames = os.listdir(self.image_dir)
self.image_filenames = sorted(os.listdir(self.image_dir))

No need for a helper function for a single line of code. Also, sorting this list makes the dataset reproducible.

timepoint_aggregation: How to aggregate multiple timepoints.
num_of_timepoints: Number of timepoints to use for each image.
download: Whether to download the dataset if it is not found.
checksum: Whether to verify the dataset after downloading.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are in a different order than the parameters. Also, can you add a transforms parameter to match all of our other datasets? It should be used at the end of __getitem__.

print(prediction.shape)
ncols = 3

print(mask.shape, image.shape)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can remove debugging print statements

self.url_for_images,
self.root,
filename=self.filename_images,
md5='INSERT_IMAGES_MD5_HASH' if self.checksum else None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: add MD5s

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
datamodules PyTorch Lightning datamodules datasets Geospatial or benchmark datasets documentation Improvements or additions to documentation testing Continuous integration testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants