
K-fold v2 #315

Merged 26 commits into main on Nov 7, 2021
Conversation

@jorshi (Contributor) commented Nov 4, 2021

Updates the hear eval pipeline to support k-fold datasets. This depends on the preprocessed data updates proposed in hearbenchmark/hear-preprocess#100 -- review those first. The small datasets used for tests here will need to be updated in order for the tests to run properly.

This builds off of @khumairraj's PR #310

@@ -700,7 +705,7 @@ def label_vocab_nlabels(embedding_path: Path) -> Tuple[pd.DataFrame, int]:


def dataloader_from_split_name(
Contributor:
Great :)

Contributor Author:

All @khumairraj :)

fold as the train split.
Folds will be sorted before applying the above strategy
Total data splits will be equal to n, n being the total number of folds.
Each fold will be tested by training on the remaining folds.
Contributor:

I'd just write an explicit example of what is returned, key and values

Contributor:

or say:

"""
With 5-fold, for example, we would have:
test=fold1, val=fold2, train=fold3..5,
test=fold2, val=fold3, train=fold4..5,1,
...
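The 5-fold example above pins down the rotation: sorted folds, each fold takes a turn as test, the next fold (cyclically) is validation, and the rest are training. A minimal sketch of generating those splits; the helper name `kfold_splits` and the train/valid/test keys are illustrative, not the PR's actual code:

```python
from typing import Dict, List


def kfold_splits(folds: List[str]) -> List[Dict[str, List[str]]]:
    """Rotate sorted folds into test/valid/train assignments.

    For each of the n folds: that fold is test, the next fold
    (cyclically) is validation, and the remaining n - 2 folds train.
    """
    folds = sorted(folds)
    n = len(folds)
    splits = []
    for i in range(n):
        splits.append(
            {
                "test": [folds[i]],
                "valid": [folds[(i + 1) % n]],
                "train": [folds[(i + j) % n] for j in range(2, n)],
            }
        )
    return splits
```

With five folds this reproduces the example above: the second split is test=fold2, valid=fold3, train=[fold4, fold5, fold1].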

@@ -925,6 +1028,8 @@ def task_predictions(
seed_everything(42, workers=False)

metadata = json.load(embedding_path.joinpath("task_metadata.json").open())
metadata["mode"] = "folds" # remove me, only for testing
Contributor:

What is the convention that was defined? Now I'm forgetting

Contributor Author:

I don't think we have a convention defined yet -- need to check this, because metadata["mode"] might not actually be defined for the existing open tasks.

Contributor Author:

Updated this so that if "fold" is in the metadata (as a list of fold string names) it will perform k-fold. Otherwise it works as before with pre-defined splits.
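A sketch of that dispatch, assuming the metadata key is literally "fold" as the comment says; the function name and the fixed-split fallback shape are illustrative assumptions, not the PR's actual code:

```python
from typing import Any, Dict, List


def data_splits_from_metadata(metadata: Dict[str, Any]) -> List[Dict[str, List[str]]]:
    """Return one split assignment per evaluation run.

    If the task metadata carries a "fold" list of fold names, rotate
    them k-fold style; otherwise fall back to the usual fixed
    train/valid/test partition.
    """
    folds = metadata.get("fold")
    if isinstance(folds, list) and all(isinstance(f, str) for f in folds):
        folds = sorted(folds)
        n = len(folds)
        return [
            {
                "test": [folds[i]],
                "valid": [folds[(i + 1) % n]],
                "train": [folds[(i + j) % n] for j in range(2, n)],
            }
            for i in range(n)
        ]
    # Pre-defined splits: a single run over the canonical partition
    return [{"train": ["train"], "valid": ["valid"], "test": ["test"]}]
```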

"embedding_path": str(embedding_path),
}
)

Contributor:

The function is getting a bit long, can we break it down?

Can we save the scores of the different folds?
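On saving the scores of the different folds: one simple shape for this, not necessarily what the PR ended up doing, is to keep per-fold test metrics and summarize them as mean and standard deviation across folds:

```python
import statistics
from typing import Dict, List


def aggregate_fold_scores(
    fold_scores: Dict[str, Dict[str, float]]
) -> Dict[str, Dict[str, float]]:
    """Summarize per-fold metrics: {fold: {metric: value}} -> {metric: {mean, stddev}}."""
    by_metric: Dict[str, List[float]] = {}
    for scores in fold_scores.values():
        for name, value in scores.items():
            by_metric.setdefault(name, []).append(value)
    return {
        name: {
            "mean": statistics.mean(values),
            "stddev": statistics.stdev(values) if len(values) > 1 else 0.0,
        }
        for name, values in by_metric.items()
    }
```

Keeping the raw per-fold dict alongside the summary makes it easy to dump both to the results JSON.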

Contributor Author:

Did my best to pull some things out without going crazy.

Contributor:

I'm still kinda concerned tho, it's really long. Any way we can go further?

@jorshi jorshi changed the title [WIP] K-fold v2 K-fold v2 Nov 5, 2021
@khumairraj (Contributor) left a comment:

Thanks for completing this @jorshi. It looks much cleaner.
Also, there was some code duplication before, mostly around finding the best grid point and then retraining on all the data (train + val, without validation steps like early stopping) using the characteristics of that best grid point (including its early-stopping epoch). I see that this PR removes that behavior. It was introducing lots of conditions, and it is good that we removed it.

!= json.load(open(task_path.joinpath("test.embedding-dimensions.json")))[1]
):
# Ensure all embedding sizes are the same across splits/folds
embedding_size = embedding_sizes[0]
Contributor:

Do we even need embedding_size anymore, then, if we have embedding_sizes?

metadata: Dict[str, Any],
data_splits: Dict[str, List[str]],
Contributor:

What is data_splits? Maybe it's worth having a docstring here? Or is this called by another function with the same set of parameters?

I'm also wondering whether it even makes sense to have a class take such a long list of parameters, or whether it could be divided up in a sensible way.


json.load(embedding_path.joinpath(f"{split_name}.json").open())
)
test_target_events = {}
for split_name in data_splits["test"]:
Contributor:

Oh man this seems really confusing, particularly if you understand the other pattern from hear-preprocess that splits is a list, not a dict of lists. How do we fix this and make it clearer?
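One way to contain that confusion, as a sketch: normalize both conventions into a single shape at one entry point, so the rest of the file only ever sees a dict of split name to fold names. The helper name is hypothetical:

```python
from typing import Dict, List, Union


def normalize_data_splits(
    splits: Union[List[str], Dict[str, List[str]]]
) -> Dict[str, List[str]]:
    """Map both split conventions onto one shape: {split: [fold names]}.

    The hear-preprocess convention passes a flat list such as
    ["train", "valid", "test"]; the k-fold path already passes a dict
    like {"test": ["fold1"], ...}. After this call, everything is a dict.
    """
    if isinstance(splits, dict):
        return splits
    # A plain split name doubles as its own single "fold"
    return {name: [name] for name in splits}
```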

@@ -845,7 +885,7 @@ def task_predictions_train(
logger=logger,
)
train_dataloader = dataloader_from_split_name(
"train",
data_splits["train"],
Contributor:

Again this stuff is so confusing. Can we put all this weird logic in one place with a clear explanation, so this weird pattern doesn't occur throughout this file? The fewer patterns you have to remember, the better
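A sketch of putting that logic in one place: a single helper that accepts either the old single split name or a list of fold names, so the str-vs-list handling lives in exactly one spot. Names and data shapes here are illustrative, not the actual refactor:

```python
from typing import Dict, List, Sequence, Union


def gather_split(
    split: Union[str, Sequence[str]], data_by_fold: Dict[str, List]
) -> List:
    """Concatenate the data for a logical split.

    Accepts either a single split name ("train") or a list of fold
    names (data_splits["train"]), so callers never branch on the type.
    """
    names = [split] if isinstance(split, str) else list(split)
    out: List = []
    for name in names:
        out.extend(data_by_fold[name])
    return out
```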

@turian turian merged commit 9547ddb into main Nov 7, 2021
@turian turian deleted the add_kfold_3 branch November 7, 2021 21:46