Added substation segmentation dataset #2352
base: main
Hi @rijuld, thanks for the contribution! If you're new to creating PyTorch datasets, I highly recommend reading the following tutorials:
The only difference between datasets in torchvision and NonGeoDatasets in TorchGeo is that our [...] Most of your issues seem to be due to the use of [...]
Hi @adamjstewart, thanks a ton for the feedback! I will go through this tutorial.
@adamjstewart I'm unsure about how to test the configuration file locally. Additionally, would testing the conf file cover the datamodule as well, or would I need to write separate test cases for the datamodule to ensure adequate code coverage?
You can get the test name from CI and then run it like so:
> pytest tests/trainers/test_segmentation.py::TestSemanticSegmentationTask::test_trainer[True-substation]
Yes
@@ -53,6 +53,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`SSL4EO`_-S12,T,Sentinel-1/2,"CC-BY-4.0",1M,-,264x264,10,"SAR, MSI"
`SSL4EO-L Benchmark`_,S,Lansat & CDL,"CC0-1.0",25K,134,264x264,30,MSI
`SSL4EO-L Benchmark`_,S,Lansat & NLCD,"CC0-1.0",25K,17,264x264,30,MSI
`Substation`_,S,OpenStreetMap & Sentinel-2, "CC-BY-SA 2.0", 27K, 2, 228x228, 10, MSI
Where did you find the license? I don't see anything on GitHub or in the paper.
"""Fixture for the Substation."""
root = os.path.join(os.getcwd(), 'tests', 'data', 'substation')

yield Substation(
We don't normally use yield when creating our fixtures, any reason you're using a generator?
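For reference, a minimal sketch of a fixture that returns its value instead of yielding it. `ToyDataset` is a hypothetical stand-in for the real dataset class, and its arguments are assumptions chosen for illustration. A generator fixture is mainly useful when there is teardown code to run after the `yield`; without teardown, a plain `return` does the same job.

```python
from pathlib import Path

import pytest


class ToyDataset:
    """Hypothetical stand-in for the dataset under test."""

    def __init__(self, root: str, num_of_timepoints: int) -> None:
        self.root = root
        self.num_of_timepoints = num_of_timepoints


@pytest.fixture
def dataset(tmp_path: Path) -> ToyDataset:
    # No teardown is required here, so the fixture simply returns the object.
    return ToyDataset(root=str(tmp_path), num_of_timepoints=4)


def test_num_of_timepoints(dataset: ToyDataset) -> None:
    assert dataset.num_of_timepoints == 4
```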
    num_of_timepoints=4,
)

@pytest.mark.parametrize(
Shouldn't you parametrize the fixture instead of the unit test? Then all other unit tests will also be parametrized.
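A rough sketch of what parametrizing the fixture could look like. `ToyDataset` and the parameter values `1` and `4` are hypothetical, not taken from the PR. Every test that requests the fixture then runs once per parameter, with no decorator needed on the individual tests.

```python
import pytest


class ToyDataset:
    """Hypothetical stand-in for the dataset under test."""

    def __init__(self, num_of_timepoints: int) -> None:
        self.num_of_timepoints = num_of_timepoints

    def __len__(self) -> int:
        return 5


# Each value in params produces one run of every test that uses the fixture.
@pytest.fixture(params=[1, 4])
def dataset(request: pytest.FixtureRequest) -> ToyDataset:
    return ToyDataset(num_of_timepoints=request.param)


def test_len(dataset: ToyDataset) -> None:
    assert len(dataset) == 5


def test_timepoints(dataset: ToyDataset) -> None:
    assert dataset.num_of_timepoints in (1, 4)
```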
from collections.abc import Generator
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock
We don't use unittest, we use pytest. Can you convert these to MonkeyPatch, which is built into pytest?
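As an illustration of the conversion, here is a small sketch that replaces a function with the `monkeypatch` fixture rather than a `MagicMock`. The helper `list_images` and the fake file names are hypothetical; what the PR actually mocks is not shown in this excerpt.

```python
import os

import pytest


def list_images(image_dir: str) -> list[str]:
    """Hypothetical helper standing in for the code under test."""
    return sorted(os.listdir(image_dir))


def test_list_images(monkeypatch: pytest.MonkeyPatch) -> None:
    # Replace os.listdir for the duration of this test only;
    # pytest restores the original automatically afterwards.
    monkeypatch.setattr(os, 'listdir', lambda path: ['b.npz', 'a.npz'])
    assert list_images('unused') == ['a.npz', 'b.npz']
```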
import torch
from torch.utils.data import Subset, random_split
from tqdm import tqdm
tqdm isn't a dependency and we don't really want to add new deps unless we have to
def __init__(
    self,
    root: Path,
    bands: list[int],
Are tuples also allowed? In this case you would use Sequence instead of list
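To illustrate the annotation, a short sketch of a function that accepts any sequence of band indices. `validate_bands` and the default of 13 bands are assumptions for the example, not part of the PR.

```python
from collections.abc import Sequence


def validate_bands(bands: Sequence[int], num_bands: int = 13) -> tuple[int, ...]:
    """Check that every band index is in range and return the bands as a tuple.

    num_bands=13 is an assumed default for illustration only.
    """
    for band in bands:
        if not 0 <= band < num_bands:
            raise ValueError(f'band index {band} is out of range')
    return tuple(bands)


# Lists and tuples both satisfy Sequence[int].
print(validate_bands([0, 1, 2]))
print(validate_bands((3, 2, 1)))
```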
    self.load_image_filenames()

def load_image_filenames(self) -> None:
    """Load image filenames from the image directory."""
    self.image_filenames = os.listdir(self.image_dir)
Suggested change:
- self.load_image_filenames()
- def load_image_filenames(self) -> None:
- """Load image filenames from the image directory."""
- self.image_filenames = os.listdir(self.image_dir)
+ self.image_filenames = sorted(os.listdir(self.image_dir))
No need for a helper function for a single line of code. Also, sorting this list makes the dataset reproducible.
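A quick demonstration of why sorting matters: `os.listdir` returns entries in arbitrary order, so index `i` may refer to different files on different machines. The file names below are made up for the example.

```python
import os
import tempfile

# Create a few empty files in a scratch directory (names are placeholders).
with tempfile.TemporaryDirectory() as image_dir:
    for name in ['image_2.npz', 'image_0.npz', 'image_1.npz']:
        open(os.path.join(image_dir, name), 'w').close()

    # Sorting fixes the order regardless of what the filesystem returns.
    image_filenames = sorted(os.listdir(image_dir))

print(image_filenames)  # ['image_0.npz', 'image_1.npz', 'image_2.npz']
```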
    timepoint_aggregation: How to aggregate multiple timepoints.
    num_of_timepoints: Number of timepoints to use for each image.
    download: Whether to download the dataset if it is not found.
    checksum: Whether to verify the dataset after downloading.
These are in a different order than the parameters. Also, can you add a transforms parameter to match all of our other datasets? It should be used at the end of __getitem__.
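A minimal sketch of the pattern being requested, using plain Python lists in place of tensors so it runs without extra dependencies. `ToyDataset`, its hard-coded pixel values, and `double_image` are hypothetical; the point is only that `transforms` is an optional callable applied to the sample dictionary as the last step of `__getitem__`.

```python
from collections.abc import Callable
from typing import Any, Optional

Sample = dict[str, Any]


class ToyDataset:
    """Hypothetical dataset returning a dict with 'image' and 'mask' keys."""

    def __init__(self, transforms: Optional[Callable[[Sample], Sample]] = None) -> None:
        self.transforms = transforms
        self.pixels = [[0.0, 0.5], [1.0, 0.25]]

    def __len__(self) -> int:
        return len(self.pixels)

    def __getitem__(self, index: int) -> Sample:
        sample: Sample = {'image': self.pixels[index], 'mask': [0, 1]}
        # Transforms run last, after the sample has been fully assembled.
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample


def double_image(sample: Sample) -> Sample:
    """Example transform: scale the image values by two."""
    sample['image'] = [2 * value for value in sample['image']]
    return sample


ds = ToyDataset(transforms=double_image)
print(ds[0]['image'])  # [0.0, 1.0]
```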
print(prediction.shape)
ncols = 3

print(mask.shape, image.shape)
Can remove debugging print statements
self.url_for_images,
self.root,
filename=self.filename_images,
md5='INSERT_IMAGES_MD5_HASH' if self.checksum else None,
TODO: add MD5s
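One way the missing hashes could be computed, sketched with only the standard library. `calculate_md5` is a hypothetical helper, and the temporary file with its contents is a placeholder for the real downloaded archive.

```python
import hashlib
import os
import tempfile


def calculate_md5(path: str, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex MD5 digest of a file, reading it in chunks."""
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        while chunk := f.read(chunk_size):
            md5.update(chunk)
    return md5.hexdigest()


# Placeholder file standing in for the real archive.
with tempfile.TemporaryDirectory() as root:
    archive = os.path.join(root, 'archive.bin')
    with open(archive, 'wb') as f:
        f.write(b'hello world')
    digest = calculate_md5(archive)

print(digest)
```

The printed string is what would replace a placeholder such as 'INSERT_IMAGES_MD5_HASH' once it has been computed on the actual file.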
No description provided.