
Allow passing numpy.ndarray to SetFitHead.predict_proba #207

Merged (7 commits) on Jan 10, 2023

Conversation

Contributor

@jegork jegork commented Nov 28, 2022

Hello!

Currently, calling SetFitModel.predict_proba("test") leads to:

  File "/Users/jegorkitskerkin/Documents/vectory/./vectory/scorer/setfit.py", line 45, in get
    print(self.model.predict_proba(query))
  File "/Users/jegorkitskerkin/opt/miniconda3/envs/vectory/lib/python3.10/site-packages/setfit/modeling.py", line 314, in predict_proba
    return self.model_head.predict_proba(embeddings)
  File "/Users/jegorkitskerkin/opt/miniconda3/envs/vectory/lib/python3.10/site-packages/setfit/modeling.py", line 138, in predict_proba
    return self(x_test)
  File "/Users/jegorkitskerkin/opt/miniconda3/envs/vectory/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/jegorkitskerkin/opt/miniconda3/envs/vectory/lib/python3.10/site-packages/setfit/modeling.py", line 122, in forward
    logits = self.linear(x)
  File "/Users/jegorkitskerkin/opt/miniconda3/envs/vectory/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/jegorkitskerkin/opt/miniconda3/envs/vectory/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
TypeError: linear(): argument 'input' (position 1) must be Tensor, not numpy.ndarray

This happens because SentenceTransformer.encode returns a numpy.ndarray rather than tensors, so I have moved the numpy.ndarray to torch.Tensor conversion from SetFitHead.predict to .predict_proba (as predict calls predict_proba internally).

Feel free to suggest better solutions, as this one has to call isinstance(x_test, torch.Tensor) in both .predict and .predict_proba (so that .predict returns the same datatype as its input).
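A minimal sketch of the fix described above (simplified stand-in class and names, not the exact PR diff): predict_proba converts an incoming numpy.ndarray to a torch.Tensor before the linear layer, and both methods mirror the caller's input type on the way out.

```python
import numpy as np
import torch


class SetFitHeadSketch(torch.nn.Module):
    """Simplified stand-in for SetFitHead, illustrating the conversion."""

    def __init__(self, in_features: int = 4, num_classes: int = 3) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(in_features, num_classes)

    def predict_proba(self, x_test):
        is_numpy = isinstance(x_test, np.ndarray)
        # Convert numpy input so torch.nn.functional.linear does not fail.
        if is_numpy:
            x_test = torch.from_numpy(x_test).float()
        probs = torch.softmax(self.linear(x_test), dim=-1)
        # Mirror the caller's input type.
        return probs.detach().numpy() if is_numpy else probs

    def predict(self, x_test):
        probs = self.predict_proba(x_test)
        if isinstance(probs, np.ndarray):
            return probs.argmax(axis=-1)
        return torch.argmax(probs, dim=-1)


head = SetFitHeadSketch()
print(type(head.predict(np.random.rand(2, 4).astype(np.float32))))  # numpy in, numpy out
print(type(head.predict(torch.rand(2, 4))))                         # tensor in, tensor out
```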

Contributor

@blakechi blakechi left a comment


Hi @jegork,

Thank you so much for fixing this issue. I left some suggestions; the core idea is to move the entire conversion into predict_proba. These are just suggestions, and I'm happy to hear your feedback and discuss a better solution.

@tomaarsen
Member

To add on to @blakechi's comments, I think this PR is a good place to discuss the convert_to_tensor parameter of SentenceTransformer's encode method. That method is used here in SetFit:

    def predict(self, x_test: torch.Tensor) -> torch.Tensor:
        embeddings = self.model_body.encode(x_test, normalize_embeddings=self.normalize_embeddings)
        return self.model_head.predict(embeddings)

    def predict_proba(self, x_test: torch.Tensor) -> torch.Tensor:
        embeddings = self.model_body.encode(x_test, normalize_embeddings=self.normalize_embeddings)
        return self.model_head.predict_proba(embeddings)

Because the SetFitHead.predict(_proba) methods are not generally interacted with directly, but only via SetFitModel.predict(_proba), I propose to set this parameter to True (and convert_to_numpy to False), after which we can completely remove support for np.ndarray. After all, why would a torch nn.Module subclass that is used as the last part of a custom torch pipeline need to support np.ndarray?

Optionally, we may implement an as_numpy (or to_numpy) parameter that converts the torch Tensor to an np.ndarray, but only after the head has done its computation.


Another related issue that I had noted is that the type hints for the predict(_proba) methods in the previous snippet are nonsense: (self, x_test: torch.Tensor) -> torch.Tensor:.

self.model_body.encode accepts Union[List[str], str], so x_test should have that type. The embeddings are np.ndarray because SentenceTransformer produces an np.ndarray by default, and self.model_head.predict(_proba) preserves that type, causing the method to return an np.ndarray, not a torch.Tensor.

So, the current type hints should be (self, x_test: Union[List[str], str]) -> np.ndarray:, and with my proposal above implemented, they would be (self, x_test: Union[List[str], str], as_numpy: bool = False) -> Union[torch.Tensor, np.ndarray]:.
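The proposal above could be sketched roughly as follows (the class, stubbed encode, and head here are simplified illustrations, not SetFit's actual implementation): encode always yields Tensors, and the optional as_numpy flag converts only the final output.

```python
from typing import List, Union

import numpy as np
import torch


class ModelSketch:
    """Sketch of the proposed SetFitModel.predict with an as_numpy flag.

    encode() stands in for model_body.encode(..., convert_to_tensor=True),
    and a plain Linear layer stands in for the head.
    """

    def __init__(self) -> None:
        self.head = torch.nn.Linear(8, 3)

    def encode(self, x_test: List[str]) -> torch.Tensor:
        # Stand-in: always return Tensors, per the proposal.
        return torch.rand(len(x_test), 8)

    def predict(self, x_test: List[str], as_numpy: bool = False) -> Union[torch.Tensor, np.ndarray]:
        embeddings = self.encode(x_test)
        outputs = torch.argmax(self.head(embeddings), dim=-1)
        # Convert only after the head has done its computation.
        if as_numpy:
            outputs = outputs.detach().cpu().numpy()
        return outputs


model = ModelSketch()
print(type(model.predict(["a", "b"])))                 # torch.Tensor by default
print(type(model.predict(["a", "b"], as_numpy=True)))  # numpy.ndarray on request
```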

Would love to hear your thoughts.

  • Tom Aarsen

@blakechi
Contributor

blakechi commented Nov 30, 2022

Now I'm a fan of @tomaarsen's proposed solution! 😂 Great proposal! Yes, doing it that way is much cleaner, and it makes more sense in terms of integrating with sklearn's head: the conversion should live where the two different heads are integrated, which is SetFitModel, while SetFitHead takes care of torch.Tensor-related tasks.

Just to add on a bit: we may also add a check for which head is currently in use to determine whether SentenceTransformer.encode's convert_to_numpy should be True or False, since the sklearn head can't accept a torch.Tensor.

It may depend on which side we prefer, but for consistency of the outputs we may want to always convert them to one type (either np.ndarray or torch.Tensor) and correspondingly offer either to_numpy or to_tensor for flexibility.
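That head check could look something like this sketch (fake_encode, encode_for_head, and SklearnLikeHead are hypothetical names for illustration; SetFit's real code differs):

```python
import numpy as np
import torch


class SklearnLikeHead:
    """Dummy stand-in for a sklearn estimator head such as LogisticRegression."""


def fake_encode(x_test, convert_to_tensor=False, convert_to_numpy=True):
    # Stand-in for SentenceTransformer.encode, honoring the convert_* flags.
    data = np.random.rand(len(x_test), 4).astype(np.float32)
    return torch.from_numpy(data) if convert_to_tensor else data


def encode_for_head(encode, x_test, head):
    # A torch head gets Tensors; a sklearn-style head gets numpy arrays,
    # since sklearn estimators can't accept torch.Tensors in general.
    use_tensor = isinstance(head, torch.nn.Module)
    return encode(x_test, convert_to_tensor=use_tensor, convert_to_numpy=not use_tensor)


print(type(encode_for_head(fake_encode, ["a"], SklearnLikeHead())))     # ndarray
print(type(encode_for_head(fake_encode, ["a"], torch.nn.Linear(4, 2))))  # Tensor
```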

And lastly, it would be great if you (@jegork) could also fix the typing issue @tomaarsen mentioned in this PR :)

@jegork
Contributor Author

jegork commented Dec 6, 2022

Hey @blakechi,
Thanks for your guidance! I've made .predict_proba output the same data type (tensor/ndarray) as the input it was given and added the if statements in .predict that you mentioned.

@tomaarsen thanks for your input! That indeed seems like a more coherent solution. Do you want me to implement your idea as part of this PR?

@blakechi, the typing issue you mentioned appears in SetFitModel. Since this PR relates to a different issue, do you think it's a good idea to fix it here? I can open a new PR and address it there instead.

@tomaarsen
Member

If we want to go with the approach of removing np.ndarray support from the SetFitHead, then it seems fitting for this PR, as the current behavior kind of runs counter to what is already being proposed here.

@jegork
Contributor Author

jegork commented Dec 8, 2022

@tomaarsen I have implemented your suggestion; it would be nice if you could take a look. However, the method signature you suggested ((self, x_test: Union[List[str], str]) -> np.ndarray) seems to be incorrect: .predict/.predict_proba can only accept List[str], not a bare str. Otherwise, in the current implementation, it leads to errors when the head is LogisticRegression:

    Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.

I decided to keep it as is for now; if you think it's worth also supporting str, I can reshape the input when needed based on the datatype passed in.
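If single-string support were added later, the reshape would look roughly like this (a sketch with a hypothetical helper name, not part of this PR): sklearn estimators expect 2D input, so a single sample's 1D embedding must be reshaped to (1, n_features).

```python
import numpy as np


def ensure_batched(embeddings: np.ndarray) -> np.ndarray:
    # A single sample's embedding comes back 1D; sklearn heads such as
    # LogisticRegression require a 2D (n_samples, n_features) array.
    if embeddings.ndim == 1:
        return embeddings.reshape(1, -1)
    return embeddings


print(ensure_batched(np.zeros(5)).shape)       # (1, 5)
print(ensure_batched(np.zeros((3, 5))).shape)  # (3, 5)
```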

Member

@lewtun lewtun left a comment


Thanks a lot for fixing these methods and their types @jegork and welcome to the 🤗 community!

The PR looks great as it is, and we can later include support for predicting on a single string if there's a strong request from the community.

I've left one nit, and you'll need to merge the latest changes from main to resolve the merge conflicts.

Also thank you to @blakechi and @tomaarsen for fantastic feedback on this PR 🔥

@jegork
Contributor Author

jegork commented Dec 12, 2022

Thanks for your reply, @lewtun! I've renamed out -> outputs and fetched the latest changes from main.

Member

@tomaarsen tomaarsen left a comment


Tests are failing for me locally. I get 8 errors like so:

FAILED tests/test_trainer.py::SetFitTrainerTest::test_raise_when_metric_value_is_invalid - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

This is caused by a detail that @blakechi already mentioned earlier in #207 (comment):

Just to add on a bit: we may also add a check for which head is currently in use to determine whether SentenceTransformer.encode's convert_to_numpy should be True or False, since the sklearn head can't accept a torch.Tensor.

In short, the issue originates here now:

        embeddings = self.model_body.encode(
            x_test, normalize_embeddings=self.normalize_embeddings, convert_to_tensor=True
        )

        outputs = self.model_head.predict(embeddings)

The self.model_body can produce GPU Tensors, while the current self.model_head cannot accept a torch Tensor if it is a simple LogisticRegression. Well, it can: it internally performs an np.asarray(...) call, which happens to work for a tensor on the CPU, but not for one on the GPU. In the latter case, Torch will tell you to call .cpu() or .detach().cpu() first.
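The failure mode and its usual workaround can be sketched as follows (to_sklearn_input is a hypothetical helper; the CUDA failure only manifests on a machine with a GPU):

```python
import numpy as np
import torch


def to_sklearn_input(embeddings) -> np.ndarray:
    """Sketch: make embeddings safe to hand to a sklearn head.

    np.asarray() happens to work on a CPU tensor, but raises a TypeError
    for a CUDA tensor, which must be moved to host memory first via
    .detach().cpu().
    """
    if isinstance(embeddings, torch.Tensor):
        return embeddings.detach().cpu().numpy()
    return np.asarray(embeddings)


cpu_embeddings = torch.rand(2, 4)
print(to_sklearn_input(cpu_embeddings).shape)  # safe regardless of device
```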

All of this indicates that we need a much larger overhaul of the two head implementations: there are far too many isinstance calls that should not be needed. I'll try to plan out a design.

This was referenced Dec 15, 2022
@tomaarsen tomaarsen added the bug Something isn't working label Dec 20, 2022
@lewtun
Member

lewtun commented Dec 28, 2022

Hi folks, just coming back to this after my vacation - apologies for the delay!

@tomaarsen are you still getting errors on the tests locally? I just checked out this branch and it seems to run fine :) If it's green on your end, I think we can merge this!

@tomaarsen
Member

tomaarsen commented Dec 28, 2022

Welcome back!

I'm afraid that the test failures persist. See my review for details on the bug. The CI doesn't catch it because it doesn't have a GPU.

@lewtun
Member

lewtun commented Dec 29, 2022

Thanks for the clarification @tomaarsen! Let's try to sort out this messy two-head business and see if there's a clean way to deal with these inconsistencies between torch tensors and numpy arrays.

Previously, the embeddings were always Tensors, which doesn't work with the Logistic Regression head, in particular if the tensors are on CUDA.
@tomaarsen
Member

I've merged main into this PR to update it. Furthermore, I've resolved the aforementioned type conversion issues by relying on the has_differentiable_head property introduced in #257. I use this property to decide between Tensors and numpy arrays for the self.model_head.predict and self.model_head.predict_proba calls, and then again to determine whether output type transformations are required.
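A rough sketch of that resolution (the property name has_differentiable_head comes from #257, but the class, stubbed encode, and heads below are simplified illustrations, not the merged code):

```python
import numpy as np
import torch


class HeadRoutingSketch:
    """Illustration of routing on has_differentiable_head."""

    def __init__(self, head, has_differentiable_head: bool) -> None:
        self.model_head = head
        self.has_differentiable_head = has_differentiable_head

    def _encode(self, x_test):
        # Stand-in for model_body.encode: Tensors feed the torch head,
        # numpy arrays feed the sklearn-style head.
        data = np.random.rand(len(x_test), 4).astype(np.float32)
        return torch.from_numpy(data) if self.has_differentiable_head else data

    def predict(self, x_test, as_numpy: bool = False):
        outputs = self.model_head(self._encode(x_test))
        # Output type transformation, again decided by the head in use.
        if self.has_differentiable_head and as_numpy:
            outputs = outputs.detach().cpu().numpy()
        elif not self.has_differentiable_head and not as_numpy:
            outputs = torch.from_numpy(np.asarray(outputs))
        return outputs


torch_model = HeadRoutingSketch(torch.nn.Linear(4, 2), has_differentiable_head=True)
sk_model = HeadRoutingSketch(lambda emb: emb.mean(axis=1), has_differentiable_head=False)
print(type(torch_model.predict(["a", "b"], as_numpy=True)))  # numpy.ndarray
print(type(sk_model.predict(["a", "b"])))                    # torch.Tensor
```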

I still believe that we would benefit from refactoring, but I recognize that we should already push out this fix before a potential refactor is ready.

I believe this PR is quite important, as several issues have been opened to report the bug first described in #207 (comment). I'd like to merge this ASAP.

@jegork
Contributor Author

jegork commented Jan 9, 2023

It would be nice to get this merged; thank you for your assistance, @tomaarsen!

@tomaarsen
Member

tomaarsen commented Jan 9, 2023

I'll merge it if you think it looks good now! @jegork

Others are also welcome to review it :)

@jegork
Contributor Author

jegork commented Jan 10, 2023

@tomaarsen looks good to me!

@tomaarsen tomaarsen merged commit fa07883 into huggingface:main Jan 10, 2023
@tomaarsen
Member

Thanks for this @jegork! Apologies for the delay in tackling this.
