Skip to content

Commit

Permalink
Add fallback for __mod__ magic method (#1934)
Browse files Browse the repository at this point in the history
* Add fallback for `__mod__` magic method

* Add 'CHANGELOG' entry

* Complete tests
  • Loading branch information
lycantropos authored Oct 19, 2021
1 parent bf26dae commit 7349513
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix incorrect linking to version-specific DLL instead of `python3.dll` when cross-compiling to Windows with `abi3`. [#1880](https://github.com/PyO3/pyo3/pull/1880)
- Fix panic in generated `#[derive(FromPyObject)]` for enums. [#1888](https://github.com/PyO3/pyo3/pull/1888)
- Fix cross-compiling to Python 3.7 builds with the "m" abi flag. [#1908](https://github.com/PyO3/pyo3/pull/1908)
- Fix `__mod__` magic method fallback to `__rmod__`. [#1934](https://github.com/PyO3/pyo3/pull/1934).

## [0.14.5] - 2021-09-05

Expand Down
2 changes: 2 additions & 0 deletions pyo3-macros-backend/src/defs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,9 @@ pub const NUM: Proto = Proto {
SlotDef::new(&["__mul__", "__rmul__"], "Py_nb_multiply", "mul_rmul"),
SlotDef::new(&["__mul__"], "Py_nb_multiply", "mul"),
SlotDef::new(&["__rmul__"], "Py_nb_multiply", "rmul"),
SlotDef::new(&["__mod__", "__rmod__"], "Py_nb_remainder", "mod_rmod"),
SlotDef::new(&["__mod__"], "Py_nb_remainder", "mod_"),
SlotDef::new(&["__rmod__"], "Py_nb_remainder", "rmod"),
SlotDef::new(
&["__divmod__", "__rdivmod__"],
"Py_nb_divmod",
Expand Down
7 changes: 7 additions & 0 deletions src/class/number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,14 @@ py_binary_fallback_num_func!(
);
py_binary_num_func!(mul, PyNumberMulProtocol, T::__mul__);
py_binary_reversed_num_func!(rmul, PyNumberRMulProtocol, T::__rmul__);
py_binary_fallback_num_func!(
mod_rmod,
T,
PyNumberModProtocol::__mod__,
PyNumberRModProtocol::__rmod__
);
py_binary_num_func!(mod_, PyNumberModProtocol, T::__mod__);
py_binary_reversed_num_func!(rmod, PyNumberRModProtocol, T::__rmod__);
py_binary_fallback_num_func!(
divmod_rdivmod,
T,
Expand Down
22 changes: 22 additions & 0 deletions tests/test_arithmetics_protos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ impl PyNumberProtocol for BinaryArithmetic {
format!("{:?} - {:?}", lhs, rhs)
}

fn __mod__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} % {:?}", lhs, rhs)
}

fn __mul__(lhs: &PyAny, rhs: &PyAny) -> String {
format!("{:?} * {:?}", lhs, rhs)
}
Expand Down Expand Up @@ -195,6 +199,8 @@ fn binary_arithmetic() {
py_run!(py, c, "assert 1 - c == '1 - BA'");
py_run!(py, c, "assert c * 1 == 'BA * 1'");
py_run!(py, c, "assert 1 * c == '1 * BA'");
py_run!(py, c, "assert c % 1 == 'BA % 1'");
py_run!(py, c, "assert 1 % c == '1 % BA'");

py_run!(py, c, "assert c << 1 == 'BA << 1'");
py_run!(py, c, "assert 1 << c == '1 << BA'");
Expand Down Expand Up @@ -225,6 +231,10 @@ impl PyNumberProtocol for RhsArithmetic {
format!("{:?} - RA", other)
}

fn __rmod__(&self, other: &PyAny) -> String {
format!("{:?} % RA", other)
}

fn __rmul__(&self, other: &PyAny) -> String {
format!("{:?} * RA", other)
}
Expand Down Expand Up @@ -264,6 +274,8 @@ fn rhs_arithmetic() {
py_run!(py, c, "assert 1 + c == '1 + RA'");
py_run!(py, c, "assert c.__rsub__(1) == '1 - RA'");
py_run!(py, c, "assert 1 - c == '1 - RA'");
py_run!(py, c, "assert c.__rmod__(1) == '1 % RA'");
py_run!(py, c, "assert 1 % c == '1 % RA'");
py_run!(py, c, "assert c.__rmul__(1) == '1 * RA'");
py_run!(py, c, "assert 1 * c == '1 * RA'");
py_run!(py, c, "assert c.__rlshift__(1) == '1 << RA'");
Expand Down Expand Up @@ -299,6 +311,10 @@ impl PyNumberProtocol for LhsAndRhs {
format!("{:?} - {:?}", lhs, rhs)
}

fn __mod__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} % {:?}", lhs, rhs)
}

fn __mul__(lhs: PyRef<Self>, rhs: &PyAny) -> String {
format!("{:?} * {:?}", lhs, rhs)
}
Expand Down Expand Up @@ -339,6 +355,10 @@ impl PyNumberProtocol for LhsAndRhs {
format!("{:?} - RA", other)
}

fn __rmod__(&self, other: &PyAny) -> String {
format!("{:?} % RA", other)
}

fn __rmul__(&self, other: &PyAny) -> String {
format!("{:?} * RA", other)
}
Expand Down Expand Up @@ -388,6 +408,7 @@ fn lhs_fellback_to_rhs() {
// If the light hand value is `LhsAndRhs`, LHS is used.
py_run!(py, c, "assert c + 1 == 'LR + 1'");
py_run!(py, c, "assert c - 1 == 'LR - 1'");
py_run!(py, c, "assert c % 1 == 'LR % 1'");
py_run!(py, c, "assert c * 1 == 'LR * 1'");
py_run!(py, c, "assert c << 1 == 'LR << 1'");
py_run!(py, c, "assert c >> 1 == 'LR >> 1'");
Expand All @@ -399,6 +420,7 @@ fn lhs_fellback_to_rhs() {
// Fellback to RHS because of type mismatching
py_run!(py, c, "assert 1 + c == '1 + RA'");
py_run!(py, c, "assert 1 - c == '1 - RA'");
py_run!(py, c, "assert 1 % c == '1 % RA'");
py_run!(py, c, "assert 1 * c == '1 * RA'");
py_run!(py, c, "assert 1 << c == '1 << RA'");
py_run!(py, c, "assert 1 >> c == '1 >> RA'");
Expand Down

0 comments on commit 7349513

Please sign in to comment.