
Feature/deprecate binary cross entropy loss #203

Merged

Conversation

@blakechi (Contributor) commented Nov 24, 2022

This PR resolves an issue found by @tomaarsen. Thanks @tomaarsen for the nice PR #187, which I used as the template!

Pull request overview

Major changes

  • Deprecate torch.nn.BCELoss for binary classification and replace it with torch.nn.CrossEntropyLoss, treating binary classification as 2-class classification (see the sketch after this list).
  • Update docstring
  • Update test_modeling.py
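
A minimal sketch of the change in spirit, assuming a simple linear head; this is illustrative, not the actual SetFitHead code:

import torch
import torch.nn as nn

batch_size, hidden_dim = 4, 768
features = torch.randn(batch_size, hidden_dim)
labels = torch.tensor([0, 1, 1, 0])

# Before: one output unit, sigmoid, and BCELoss (targets must be cast to float).
bce_head = nn.Linear(hidden_dim, 1)
probs = torch.sigmoid(bce_head(features)).squeeze(-1)
loss_before = nn.BCELoss()(probs, labels.float())

# After: two output units and CrossEntropyLoss (targets stay as integer class indices).
ce_head = nn.Linear(hidden_dim, 2)
logits = ce_head(features)
loss_after = nn.CrossEntropyLoss()(logits, labels)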

Minor changes

  • Add an eps term for numerical stability when scaling logits with the temperature (see the sketch below)
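
A minimal sketch of the eps pattern; the variable names are illustrative and the exact form in the PR may differ:

import torch

logits = torch.randn(4, 2)
temperature = 1e-12  # a temperature this close to zero would otherwise blow up the division
eps = 1e-8

# A small eps in the denominator keeps the scaled logits finite
# even when the temperature is (close to) zero.
scaled_logits = logits / (temperature + eps)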

Details

Problem

As @tomaarsen mentioned here:

As it turns out, fix 2 is a bit tricky due to the different requirements of the different loss functions. The input (i.e. the prediction) is always float32, so that's all sorted. Then, the BCELoss wants the target (i.e. the labels) to be the same dtype as the input, so also float32. This is why I've opted to make fix 2.

However, the other loss used by the SetFitHead, CrossEntropyLoss, wants the target to be float64 whenever the target has a different shape than the input.

BCELoss and CrossEntropyLoss require different target data types, which adds complexity when casting the labels to the correct dtype and makes the code harder to read.

Solution

Therefore, this PR deprecates BCELoss and uses CrossEntropyLoss for binary classification.
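
One side benefit, shown in this small hypothetical snippet, is that a single CrossEntropyLoss code path handles binary (as 2-class) and multi-class heads with the same integer-index targets:

import torch
import torch.nn as nn

labels = torch.tensor([0, 1, 1, 0])
loss_fn = nn.CrossEntropyLoss()

loss_fn(torch.randn(4, 2), labels)                      # binary, treated as 2-class
loss_fn(torch.randn(4, 5), torch.tensor([0, 4, 2, 1]))  # multi-class, same code path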

Result

I ran some experiments on the CR and EnronSpam datasets to check whether the performance changed.
Here are the results with N=8, following Table 2 in the paper:

Head                    | CR         | EnronSpam  | Amazon-CF
sklearn (from paper)    | 88.5 (1.9) | 90.1 (3.4) | 40.3 (11.8)
pytorch (BCE)           | 88.8 (1.3) | 90.0 (3.3) | 76.4 (11.2)
pytorch (CrossEntropy)  | 88.8 (1.4) | 90.3 (3.2) | 73.1 (10.5)

(Hyperparameters: batch size = 16, L2 weight (weight decay) = 0, head learning rate = 1e-2, body kept frozen)

From the results, performance is similar except on Amazon-CF. That looked strange to me, and after several reruns I still got similar results. Here is the notebook I used to run the experiments for pytorch (CrossEntropy).

Conclusion

The advantages of this change (copied from here):

We still need to validate whether the Amazon-CF results are correct. @lewtun, could you check the notebook to validate the experiments? Thanks!

@tomaarsen (Member) left a comment

This is looking great! Thanks for this. Well written PR, too!

I reran your experiments and can corroborate the results. It all looks above board, too. Very interesting to see such large differences compared to the LogisticRegression head. Perhaps this is caused by a class imbalance or something similar?
I did some further tests with other datasets, i.e. non-binary ones, and it still works as intended.

I'll put #187 on hold until this is merged.

@blakechi (Contributor, Author)

Thanks for your feedback, @tomaarsen! That really helps validate the experiments!

I checked the Amazon-CF test dataset and it matches your guess. The class proportions are 10% (503 samples) for label 1 vs. 90% (5497 samples) for label 0. Now I'm wondering why I didn't get similar results earlier in this PR. Maybe something changed? I'll look into it.

@tomaarsen (Member)

@blakechi I had a hunch about the results and the class imbalance. No reputable researcher would evaluate a classification task with a 90:10 class imbalance using accuracy, and most certainly no reasonable data scientist would then score only ~43% on the task, hah!
The discrepancy is shown here:

TEST_DATASET_TO_METRIC = {
    "emotion": "accuracy",
    "SentEval-CR": "accuracy",
    "sst5": "accuracy",
    "ag_news": "accuracy",
    "enron_spam": "accuracy",
    "amazon_counterfactual_en": "matthews_correlation",
}

In the paper, amazon_counterfactual_en is evaluated using the Matthews correlation coefficient rather than accuracy; see Table 6 in Appendix A for the details.
However, in the script, if you supply just a list of datasets, each one is evaluated using accuracy alone. Your earlier test used the TEST_DATASET_TO_METRIC mapping from the snippet above, which caused the Amazon-CF task to be evaluated using MCC.
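
As a quick illustration (not from the PR) of why accuracy is misleading on a 90:10 split while MCC is not, using scikit-learn:

from sklearn.metrics import accuracy_score, matthews_corrcoef

# 90:10 class imbalance and a degenerate classifier that always predicts the majority class.
y_true = [0] * 90 + [1] * 10
y_pred = [0] * 100

print(accuracy_score(y_true, y_pred))     # 0.9 -- looks deceptively strong
print(matthews_corrcoef(y_true, y_pred))  # 0.0 -- reveals no predictive power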


I've re-run the experiments for the LogisticRegression head and the differentiable SetFitHead with these commands:

python scripts/setfit/run_fewshot.py --sample_sizes=8 --lr=0.01 --keep_body_frozen --datasets amazon_counterfactual_en --batch_size 4
python scripts/setfit/run_fewshot.py --sample_sizes=8 --classifier=pytorch --lr=0.01 --keep_body_frozen --datasets amazon_counterfactual_en --batch_size 4

with the following change to scripts/setfit/run_fewshot.py:

diff --git a/scripts/setfit/run_fewshot.py b/scripts/setfit/run_fewshot.py
index 088e25d..ae03992 100644
--- a/scripts/setfit/run_fewshot.py
+++ b/scripts/setfit/run_fewshot.py
@@ -91,7 +91,7 @@ def main():
     elif args.is_test_set:
         dataset_to_metric = TEST_DATASET_TO_METRIC
     else:
-        dataset_to_metric = {dataset: "accuracy" for dataset in args.datasets}
+        dataset_to_metric = {dataset: "matthews_correlation" for dataset in args.datasets}
 
     # Configure loss function
     loss_class = LOSS_NAME_TO_CLASS[args.loss]

This has resulted in these outputs:

Head                           | Amazon CF (MCC)
LogisticRegression             | 44.4 (8.6)
SetFitHead w. CrossEntropyLoss | 44.6 (16.9)

This is much more in line with what we would expect. In other words, this PR seems to work as intended, without any fun additional surprises.

  • Tom Aarsen

@blakechi (Contributor, Author)

Great finding! Yeah, that explains why the numbers are so different. Much appreciated for running the experiments as well!

@blakechi (Contributor, Author)

Putting this PR on hold since #207 might make some changes that overlap.

@tomaarsen tomaarsen marked this pull request as ready for review January 10, 2023 14:58
@tomaarsen (Member) left a comment

I think this ought to be ready for merging now. I resolved some merge conflicts and verified that training after this PR works equivalently to the main branch. I ran several experiments using sst2 and sst5 with success, and clearly the tests pass, too.

Additionally, this PR should supersede #187 and solve #186.
