A JAX library of common machine learning loss functions
pip install jax-loss
The following loss functions are included. If you don't see the one you need, feel free to open an issue or a PR!
- Binary Cross Entropy With Logits Loss
- L1
- Mean Squared Error
- Cross Entropy
- Negative Log Likelihood
- KL Divergence
- Binary Cross Entropy
- Triplet Margin
- Cosine Embedding Loss
- Hinge
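
To illustrate what a few of the listed losses compute, here is a minimal sketch in plain JAX. It does not use the jax-loss API, which this README does not show: the function names `l1_loss`, `mse_loss`, and `bce_with_logits_loss` are hypothetical and chosen only for this example, and mean reduction over all elements is an assumption.

```python
import jax
import jax.numpy as jnp


def l1_loss(predictions, targets):
    """Mean absolute error between predictions and targets."""
    return jnp.mean(jnp.abs(predictions - targets))


def mse_loss(predictions, targets):
    """Mean squared error between predictions and targets."""
    return jnp.mean((predictions - targets) ** 2)


def bce_with_logits_loss(logits, targets):
    """Binary cross entropy computed directly from raw logits.

    Uses the numerically stable form
    max(x, 0) - x * z + log(1 + exp(-|x|)),
    where x is the logit and z is the 0/1 target.
    """
    per_element = (
        jnp.maximum(logits, 0.0)
        - logits * targets
        + jnp.log1p(jnp.exp(-jnp.abs(logits)))
    )
    return jnp.mean(per_element)


predictions = jnp.array([0.5, 1.5, 2.0])
targets = jnp.array([1.0, 1.0, 2.0])

l1_value = l1_loss(predictions, targets)
mse_value = mse_loss(predictions, targets)

# Losses written this way are ordinary JAX functions, so they can be
# differentiated with jax.grad (here, with respect to the predictions).
mse_grads = jax.grad(mse_loss)(predictions, targets)

# With all-zero logits the model predicts probability 0.5 everywhere,
# so the loss is log(2) regardless of the targets.
bce_value = bce_with_logits_loss(jnp.zeros(3), jnp.array([0.0, 1.0, 1.0]))
```

Because each loss is a pure function of arrays, it can also be wrapped in `jax.jit` or `jax.vmap` without changes.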