Skip to content

Commit

Permalink
Merge pull request #4106 from pleroy/YUSoSlow2
Browse files Browse the repository at this point in the history
Speed up Stehlé-Zimmermann by using locally-constructed Taylor polynomials
  • Loading branch information
pleroy authored Oct 5, 2024
2 parents 287add5 + 9c11419 commit 9b9d46c
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 159 deletions.
39 changes: 30 additions & 9 deletions functions/accurate_table_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,29 @@ using namespace principia::base::_thread_pool;
using namespace principia::numerics::_polynomial_in_monomial_basis;

using AccurateFunction = std::function<cpp_bin_float_50(cpp_rational const&)>;
using ApproximateFunction = std::function<double(cpp_rational const&)>;

// The use of factories below greatly speeds up the search (in one case, from
// multiple hours to 11 s) without affecting correctness. It seems that Taylor
// polynomials constructed at the `starting_argument` get transformed into
// polynomials with larger and larger coefficients as we scan more and more
// distant slices. This in turn makes polynomial composition and lattice
// reduction progressively more expensive. Locally constructed Taylor
// polynomials behave much better (100'000× speed-ups have been observed for
// some slices).

// A polynomial in the monomial basis, with the given `degree`, whose two type
// parameters are both `ArgValue`.
template<typename ArgValue, int degree>
using AccuratePolynomial =
PolynomialInMonomialBasis<ArgValue, ArgValue, degree>;
// A callable which, given a `cpp_rational`, returns an `AccuratePolynomial`.
// NOTE(review): presumably the `cpp_rational` is the point around which the
// polynomial is to be constructed; confirm against the callers.
template<typename ArgValue, int degree>
using AccuratePolynomialFactory =
std::function<AccuratePolynomial<ArgValue, degree>(cpp_rational const&)>;

// The remainders don't need to be extremely precise, so for speed
// they are computed using double.
using ApproximateFunction = std::function<double(cpp_rational const&)>;
// A callable which, given a `cpp_rational`, returns an `ApproximateFunction`.
// NOTE(review): presumably the `cpp_rational` is the point around which the
// remainder is to be constructed; confirm against the callers.
using ApproximateFunctionFactory =
std::function<ApproximateFunction(cpp_rational const&)>;


template<std::int64_t zeroes>
cpp_rational GalExhaustiveSearch(std::vector<AccurateFunction> const& functions,
Expand Down Expand Up @@ -54,8 +72,9 @@ absl::StatusOr<cpp_rational> StehléZimmermannSimultaneousSearch(
template<std::int64_t zeroes>
absl::StatusOr<cpp_rational> StehléZimmermannSimultaneousFullSearch(
std::array<AccurateFunction, 2> const& functions,
std::array<AccuratePolynomial<cpp_rational, 2>, 2> const& polynomials,
std::array<ApproximateFunction, 2> const& remainders,
std::array<AccuratePolynomialFactory<cpp_rational, 2>, 2> const&
polynomials,
std::array<ApproximateFunctionFactory, 2> const& remainders,
cpp_rational const& starting_argument,
ThreadPool<void>* search_pool = nullptr);

Expand All @@ -66,9 +85,9 @@ template<std::int64_t zeroes>
std::vector<absl::StatusOr<cpp_rational>>
StehléZimmermannSimultaneousMultisearch(
std::array<AccurateFunction, 2> const& functions,
std::vector<std::array<AccuratePolynomial<cpp_rational, 2>, 2>> const&
polynomials,
std::vector<std::array<ApproximateFunction, 2>> const& remainders,
std::vector<std::array<AccuratePolynomialFactory<cpp_rational, 2>, 2>>
const& polynomials,
std::vector<std::array<ApproximateFunctionFactory, 2>> const& remainders,
std::vector<cpp_rational> const& starting_arguments);

// Same as above, but instead of accumulating all the results and returning them
Expand All @@ -77,9 +96,9 @@ StehléZimmermannSimultaneousMultisearch(
template<std::int64_t zeroes>
void StehléZimmermannSimultaneousStreamingMultisearch(
std::array<AccurateFunction, 2> const& functions,
std::vector<std::array<AccuratePolynomial<cpp_rational, 2>, 2>> const&
polynomials,
std::vector<std::array<ApproximateFunction, 2>> const& remainders,
std::vector<std::array<AccuratePolynomialFactory<cpp_rational, 2>, 2>>
const& polynomials,
std::vector<std::array<ApproximateFunctionFactory, 2>> const& remainders,
std::vector<cpp_rational> const& starting_arguments,
std::function<void(/*index=*/std::int64_t,
absl::StatusOr<cpp_rational>)> const& callback);
Expand All @@ -88,7 +107,9 @@ void StehléZimmermannSimultaneousStreamingMultisearch(

using internal::AccurateFunction;
using internal::AccuratePolynomial;
using internal::AccuratePolynomialFactory;
using internal::ApproximateFunction;
using internal::ApproximateFunctionFactory;
using internal::GalExhaustiveMultisearch;
using internal::GalExhaustiveSearch;
using internal::StehléZimmermannSimultaneousFullSearch;
Expand Down
111 changes: 58 additions & 53 deletions functions/accurate_table_generator_body.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,6 @@ using namespace principia::quantities::_quantities;
constexpr std::int64_t T_max = 4;
static_assert(T_max >= 1);

// When starting a new interval, we take the value `T` that led to a
// rejection of the previous interval and multiply it by this value. This
// avoids restarting from a large value of `T` and doing pointless halving.
constexpr std::int64_t T_multiplier = 2;
static_assert(T_multiplier >= 1);

template<std::int64_t zeroes>
bool HasDesiredZeroes(cpp_bin_float_50 const& y) {
std::int64_t y_exponent;
Expand All @@ -75,11 +69,18 @@ bool AllFunctionValuesHaveDesiredZeroes(

// Groups together the data that describe a search: the two functions, the
// corresponding polynomials and remainders, and the argument around which the
// search takes place.
struct StehléZimmermannSpecification {
std::array<AccurateFunction, 2> functions;
std::array<AccuratePolynomial<cpp_rational, 2>, 2> polynomials;
std::array<ApproximateFunction, 2> remainders;
std::array<AccuratePolynomialFactory<cpp_rational, 2>, 2> polynomials;
std::array<ApproximateFunctionFactory, 2> remainders;
cpp_rational argument;
};

// Invokes each of the two `factories` with `argument` and returns the two
// results as an array, in the same order as the factories.
template<typename Factory>
std::array<std::invoke_result_t<Factory, cpp_rational>, 2> EvaluateFactoriesAt(
    std::array<Factory, 2> const& factories,
    cpp_rational const& argument) {
  auto const& [first_factory, second_factory] = factories;
  // The braced list guarantees that the first factory runs before the second.
  return {first_factory(argument), second_factory(argument)};
}

// In general, scales the argument, functions, polynomials, and remainders to
// lie within [1/2, 1[. There is a subtlety though if the input is such that
// either the argument or a function is a power of 2 or close enough to a power
Expand Down Expand Up @@ -144,37 +145,41 @@ StehléZimmermannSpecification ScaleToBinade01(

std::array<double, 2> function_scales;
std::array<AccurateFunction, 2> scaled_functions;
std::array<ApproximateFunction, 2> scaled_remainders;
std::array<AccuratePolynomialFactory<cpp_rational, 2>, 2> scaled_polynomials;
std::array<ApproximateFunctionFactory, 2> scaled_remainders;
for (std::int64_t i = 0; i < scaled_functions.size(); ++i) {
function_scales[i] = compute_scale(functions[i](lower_bound),
functions[i](upper_bound));
scaled_functions[i] = [argument_scale,
function_scale = function_scales[i],
function = functions[i],
i](cpp_rational const& argument) {
function =
functions[i]](cpp_rational const& argument) {
return function_scale * function(argument / argument_scale);
};
scaled_polynomials[i] = [argument_scale,
function_scale = function_scales[i],
polynomial = polynomials[i]](
cpp_rational const& argument₀) {
return function_scale * Compose(polynomial(argument₀ / argument_scale),
AccuratePolynomial<cpp_rational, 1>(
{0, 1 / argument_scale}));
};
scaled_remainders[i] = [argument_scale,
function_scale = function_scales[i],
remainder = remainders[i],
i](cpp_rational const& argument) {
return function_scale * remainder(argument / argument_scale);
remainder =
remainders[i]](cpp_rational const& argument₀) {
return [argument₀,
argument_scale,
function_scale,
remainder](cpp_rational const& argument) {
return function_scale *
remainder(argument₀ / argument_scale)(argument / argument_scale);
};
};
}

auto build_scaled_polynomial =
[argument_scale, &starting_argument](
double const function_scale,
AccuratePolynomial<cpp_rational, 2> const& polynomial) {
return function_scale * Compose(polynomial,
AccuratePolynomial<cpp_rational, 1>(
{0, 1 / argument_scale}));
};
return {.functions = scaled_functions,
.polynomials = {build_scaled_polynomial(function_scales[0],
polynomials[0]),
build_scaled_polynomial(function_scales[1],
polynomials[1])},
.polynomials = scaled_polynomials,
.remainders = scaled_remainders,
.argument = scaled_argument};
}
Expand Down Expand Up @@ -280,14 +285,9 @@ absl::StatusOr<cpp_rational> StehléZimmermannSimultaneousSliceSearch(
constexpr std::int64_t N = 1LL << std::numeric_limits<double>::digits;

// [SZ05], section 3.2, proves that T³ = O(M * N). Of course, the
// multiplicative factor is not known. In practice it seems that a value of
// T₀ that's too large is very costly as it results in many intervals that are
// rejected with `OutOfRange` and must be halved and retried. A value that's
// too small on the other hand can slow down progress. The fudge factor 1/128
// attempts to strike a balance between these problems; it has been chosen by
// benchmarking SinCos18 around 1167/2048.
// multiplicative factor is not known, but 1 works well in practice.
std::int64_t const T₀ = static_cast<std::int64_t>(
Cbrt(static_cast<double>(M) * static_cast<double>(N)) / 128.0);
Cbrt(static_cast<double>(M) * static_cast<double>(N)));

// Construct intervals of measure `2 * T₀` above and below `scaled.argument`
// and search for solutions on each side alternately.
Expand All @@ -298,15 +298,21 @@ absl::StatusOr<cpp_rational> StehléZimmermannSimultaneousSliceSearch(
.min = scaled.argument - cpp_rational(2 * (slice_index + 1) * T₀, N),
.max = scaled.argument - cpp_rational(2 * slice_index * T₀, N)};

// Evaluate the factories at the centre of each half of the slice.
auto const high_polynomials =
EvaluateFactoriesAt(scaled.polynomials, initial_high_interval.midpoint());
auto const high_remainders =
EvaluateFactoriesAt(scaled.remainders, initial_high_interval.midpoint());
auto const low_polynomials =
EvaluateFactoriesAt(scaled.polynomials, initial_low_interval.midpoint());
auto const low_remainders =
EvaluateFactoriesAt(scaled.remainders, initial_low_interval.midpoint());

// The radii of the intervals remaining to cover above and below the
// `scaled.argument`.
std::int64_t high_T_to_cover = T₀;
std::int64_t low_T_to_cover = T₀;

// The last value of `T` for which the search was run and found no solution.
std::int64_t last_high_T = T₀;
std::int64_t last_low_T = T₀;

// When exiting this loop, we have completely processed
// `initial_high_interval` and `initial_low_interval`.
for (;;) {
Expand All @@ -316,7 +322,7 @@ absl::StatusOr<cpp_rational> StehléZimmermannSimultaneousSliceSearch(
}

if (high_T_to_cover > 0) {
std::int64_t T = std::min(T_multiplier * last_high_T, high_T_to_cover);
std::int64_t T = high_T_to_cover;
// This loop exits (breaks or returns) when `T <= T_max` because
// exhaustive search always gives an answer.
for (;;) {
Expand All @@ -328,8 +334,8 @@ absl::StatusOr<cpp_rational> StehléZimmermannSimultaneousSliceSearch(
VLOG(3) << "T = " << T << ", high_T_to_cover = " << high_T_to_cover;
auto const status_or_solution =
StehléZimmermannSimultaneousSearch<zeroes>(scaled.functions,
scaled.polynomials,
scaled.remainders,
high_polynomials,
high_remainders,
high_interval_midpoint,
N,
T);
Expand All @@ -344,7 +350,6 @@ absl::StatusOr<cpp_rational> StehléZimmermannSimultaneousSliceSearch(
} else if (absl::IsNotFound(status)) {
// No solutions here, go to the next interval.
high_T_to_cover -= T;
last_high_T = T;
break;
} else {
return status;
Expand All @@ -353,7 +358,7 @@ absl::StatusOr<cpp_rational> StehléZimmermannSimultaneousSliceSearch(
}
}
if (low_T_to_cover > 0) {
std::int64_t T = std::min(T_multiplier * last_low_T, low_T_to_cover);
std::int64_t T = low_T_to_cover;
// This loop exits (breaks or returns) when `T <= T_max` because
// exhaustive search always gives an answer.
for (;;) {
Expand All @@ -365,8 +370,8 @@ absl::StatusOr<cpp_rational> StehléZimmermannSimultaneousSliceSearch(
VLOG(3) << "T = " << T << ", low_T_to_cover = " << low_T_to_cover;
auto const status_or_solution =
StehléZimmermannSimultaneousSearch<zeroes>(scaled.functions,
scaled.polynomials,
scaled.remainders,
low_polynomials,
low_remainders,
low_interval_midpoint,
N,
T);
Expand All @@ -382,7 +387,6 @@ absl::StatusOr<cpp_rational> StehléZimmermannSimultaneousSliceSearch(
} else if (absl::IsNotFound(status)) {
// No solutions here, go to the next interval.
low_T_to_cover -= T;
last_low_T = T;
break;
} else {
return status;
Expand Down Expand Up @@ -624,8 +628,9 @@ absl::StatusOr<cpp_rational> StehléZimmermannSimultaneousSearch(
template<std::int64_t zeroes>
absl::StatusOr<cpp_rational> StehléZimmermannSimultaneousFullSearch(
std::array<AccurateFunction, 2> const& functions,
std::array<AccuratePolynomial<cpp_rational, 2>, 2> const& polynomials,
std::array<ApproximateFunction, 2> const& remainders,
std::array<AccuratePolynomialFactory<cpp_rational, 2>, 2> const&
polynomials,
std::array<ApproximateFunctionFactory, 2> const& remainders,
cpp_rational const& starting_argument,
ThreadPool<void>* const search_pool) {
// Start by scaling the specification of the search. The rest of this
Expand Down Expand Up @@ -780,9 +785,9 @@ template<std::int64_t zeroes>
std::vector<absl::StatusOr<cpp_rational>>
StehléZimmermannSimultaneousMultisearch(
std::array<AccurateFunction, 2> const& functions,
std::vector<std::array<AccuratePolynomial<cpp_rational, 2>, 2>> const&
polynomials,
std::vector<std::array<ApproximateFunction, 2>> const& remainders,
std::vector<std::array<AccuratePolynomialFactory<cpp_rational, 2>, 2>>
const& polynomials,
std::vector<std::array<ApproximateFunctionFactory, 2>> const& remainders,
std::vector<cpp_rational> const& starting_arguments) {
std::vector<absl::StatusOr<cpp_rational>> result;
result.resize(starting_arguments.size());
Expand All @@ -801,9 +806,9 @@ StehléZimmermannSimultaneousMultisearch(
template<std::int64_t zeroes>
void StehléZimmermannSimultaneousStreamingMultisearch(
std::array<AccurateFunction, 2> const& functions,
std::vector<std::array<AccuratePolynomial<cpp_rational, 2>, 2>> const&
polynomials,
std::vector<std::array<ApproximateFunction, 2>> const& remainders,
std::vector<std::array<AccuratePolynomialFactory<cpp_rational, 2>, 2>>
const& polynomials,
std::vector<std::array<ApproximateFunctionFactory, 2>> const& remainders,
std::vector<cpp_rational> const& starting_arguments,
std::function<void(/*index=*/std::int64_t,
absl::StatusOr<cpp_rational>)> const& callback) {
Expand Down
Loading

0 comments on commit 9b9d46c

Please sign in to comment.