K-fold v2 #315
Conversation
@@ -700,7 +705,7 @@ def label_vocab_nlabels(embedding_path: Path) -> Tuple[pd.DataFrame, int]:

def dataloader_from_split_name(
Great :)
All @khumairraj :)
fold as the train split.
Folds will be sorted before applying the above strategy.
Total data splits will be equal to n, n being the total number of folds.
Each fold will be tested by training on the remaining folds.
I'd just write an explicit example of what is returned, keys and values
or say:
"""
With 5-fold, for example, we would have:
test=fold1, val=fold2, train=fold3..5,
test=fold2, val=fold3, train=fold4..5,1,
...
"""
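The rotation described in that suggested docstring can be sketched in Python. This is a hypothetical illustration only: the function name `rotate_folds`, the split-role keys, and the fold labels are assumptions, not the PR's actual API.

```python
def rotate_folds(folds):
    """Rotate sorted folds so each fold is the test split exactly once,
    the next fold is validation, and the remaining folds form train."""
    folds = sorted(folds)
    n = len(folds)
    splits = []
    for i in range(n):
        splits.append(
            {
                "test": [folds[i]],
                "valid": [folds[(i + 1) % n]],
                # everything except the test and validation folds
                "train": [folds[(i + j) % n] for j in range(2, n)],
            }
        )
    return splits
```

With five folds this yields test=fold1/val=fold2/train=fold3..5, then test=fold2/val=fold3/train=fold4,fold5,fold1, and so on, matching the wrap-around in the example above.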
@@ -925,6 +1028,8 @@ def task_predictions(
seed_everything(42, workers=False)

metadata = json.load(embedding_path.joinpath("task_metadata.json").open())
metadata["mode"] = "folds"  # remove me, only for testing
What is the convention that was defined? Now I'm forgetting
I don't think we have a convention defined yet - need to check this b/c metadata["mode"] might not actually be defined for the existing open tasks
Updated this so that if "fold" is in metadata (and is a list of fold string names) then it will perform k-fold. Otherwise it will work as normal with pre-defined splits.
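A minimal sketch of that detection logic, under stated assumptions: the helper name `uses_kfold` and the exact metadata key are hypothetical, not the code in this PR.

```python
def uses_kfold(metadata):
    # K-fold mode only when the metadata carries a non-empty list of
    # fold string names; anything else falls back to pre-defined splits.
    folds = metadata.get("fold")  # assumed key name, see comment above
    return isinstance(folds, list) and bool(folds) and all(
        isinstance(f, str) for f in folds
    )
```

Existing open tasks without that key would then take the pre-defined-splits path unchanged, which sidesteps the `metadata["mode"]` question above.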
        "embedding_path": str(embedding_path),
    }
)
The function is getting a bit long, can we break it down?
Can we save the scores of the different folds?
Did my best to pull some stuff out without going crazy.
I'm still kinda concerned tho, it's really long. Any way we can go further?
Thanks for completing this @jorshi. It looks much cleaner.
Also, there was some code duplication before, mostly around finding the best grid point and then retraining on all the data (train + val), without validation steps like early stopping, using the characteristics of the best grid point (including the early-stopping epoch and others). I see in this PR that we have removed that behavior. It was introducing lots of conditions, and it is good that we removed it.
    != json.load(open(task_path.joinpath("test.embedding-dimensions.json")))[1]
):
# Ensure all embedding sizes are the same across splits/folds
embedding_size = embedding_sizes[0]
Do we even need embedding_size anymore then, if we have embedding_sizes?
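One way to answer both points at once would be a small helper that validates the sizes and returns the single value, so only `embedding_sizes` travels around. A hypothetical sketch, not the PR's implementation (the name `consistent_embedding_size` is invented here):

```python
def consistent_embedding_size(embedding_sizes):
    """Ensure all embedding sizes agree across splits/folds and return it."""
    sizes = set(embedding_sizes)
    if len(sizes) != 1:
        raise ValueError(
            f"Embedding sizes differ across splits/folds: {sorted(sizes)}"
        )
    return embedding_sizes[0]
```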
metadata: Dict[str, Any],
data_splits: Dict[str, List[str]],
What is data_splits? Maybe it's worth having a doc-string here? Or is this called by another function with the same set of parameters?
I am also wondering whether it even makes sense to have a class take such a long list of things, or whether it makes sense to divide most of them up in a sensible way.
    json.load(embedding_path.joinpath(f"{split_name}.json").open())
)
test_target_events = {}
for split_name in data_splits["test"]:
Oh man, this seems really confusing, particularly if you know the other pattern from hear-preprocess, where splits is a list, not a dict of lists. How do we fix this and make it clearer?
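For illustration, one reading of the new pattern: `data_splits` maps each role to the list of underlying embedding split names, so the plain and k-fold cases share one shape. The values and the helper `split_embedding_files` below are hypothetical, not code from this PR.

```python
# Pre-defined splits: each role is backed by exactly one embedding file.
plain_splits = {"train": ["train"], "valid": ["valid"], "test": ["test"]}

# K-fold: the train role is backed by several folds.
fold_splits = {"train": ["fold3", "fold4", "fold5"], "valid": ["fold2"], "test": ["fold1"]}


def split_embedding_files(data_splits, role):
    """Which embedding JSON files back a given role ("train"/"valid"/"test")."""
    return [f"{name}.json" for name in data_splits[role]]
```

Centralizing the list-vs-dict conversion in one helper like this would keep the rest of the file to a single pattern.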
@@ -845,7 +885,7 @@ def task_predictions_train(
    logger=logger,
)
train_dataloader = dataloader_from_split_name(
-    "train",
+    data_splits["train"],
Again this stuff is so confusing. Can we put all this weird logic in one place with a clear explanation, so this weird pattern doesn't occur throughout this file? The fewer patterns you have to remember, the better
Updates the hear eval pipeline to support K-Fold datasets. Depends on preprocessed data updates that are proposed in hearbenchmark/hear-preprocess#100 -- review those first. The small datasets that are used for tests in this will need to be updated in order for tests to run properly.
This builds off of @khumairraj's PR #310