Skip to content

Commit

Permalink
Merge pull request #539 from Shnatsel/stb-paeth
Browse files Browse the repository at this point in the history
Port of stb_image optimized paeth unfiltering
  • Loading branch information
Shnatsel authored Dec 4, 2024
2 parents 9020cd9 + 95fabd4 commit 7e4a5a4
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 80 deletions.
10 changes: 7 additions & 3 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ jobs:
feature_check:
strategy:
matrix:
features: ["", "benchmarks"]
runs-on: ubuntu-latest
features: ["", "unstable", "benchmarks"]
os: [ubuntu-latest, macos-latest] # macos-latest is ARM
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
Expand All @@ -54,7 +55,10 @@ jobs:
rustup target add powerpc-unknown-linux-gnu
cargo build --target powerpc-unknown-linux-gnu
test_all:
runs-on: ubuntu-latest
strategy:
matrix:
os: [ubuntu-latest, macos-latest] # macos-latest is ARM
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- run: rustup default stable
Expand Down
139 changes: 62 additions & 77 deletions src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ use crate::common::BytesPerPixel;
/// TODO(https://github.com/rust-lang/rust/issues/86656): Stop gating this module behind the
/// "unstable" feature of the `png` crate. This should be possible once the "portable_simd"
/// feature of Rust gets stabilized.
#[cfg(feature = "unstable")]
///
/// This is only known to help on x86, with no change measured on most benchmarks on ARM,
/// and even severely regressing some of them.
/// So despite the code being portable, we only enable this for x86.
/// We can add more platforms once this code is proven to be beneficial for them.
#[cfg(all(feature = "unstable", target_arch = "x86_64"))]
mod simd {
use std::simd::cmp::{SimdOrd, SimdPartialEq, SimdPartialOrd};
use std::simd::num::{SimdInt, SimdUint};
use std::simd::{u8x4, u8x8, LaneCount, Simd, SimdElement, SupportedLaneCount};

Expand Down Expand Up @@ -39,18 +43,6 @@ mod simd {
out.into()
}

/// This is an equivalent of the `PaethPredictor` function from
/// [the spec](http://www.libpng.org/pub/png/spec/1.2/PNG-Filters.html#Filter-type-4-Paeth)
/// except that it simultaneously calculates the predictor for all SIMD lanes.
/// Mapping between parameter names and pixel positions can be found in
/// [a diagram here](https://www.w3.org/TR/png/#filter-byte-positions).
///
/// Examples of how different pixel types may be represented as multiple SIMD lanes:
/// - RGBA => 4 lanes of `i16x4` contain R, G, B, A
/// - RGB => 4 lanes of `i16x4` contain R, G, B, and a ignored 4th value
///
/// The SIMD algorithm below is based on [`libpng`](https://github.com/glennrp/libpng/blob/f8e5fa92b0e37ab597616f554bee254157998227/intel/filter_sse2_intrinsics.c#L261-L280).
///
/// Functionally equivalent to `simd::paeth_predictor` but does not temporarily convert
/// the SIMD elements to `i16`.
fn paeth_predictor_u8<const N: usize>(
Expand All @@ -61,44 +53,11 @@ mod simd {
where
LaneCount<N>: SupportedLaneCount,
{
// Calculates the absolute difference between `a` and `b`.
fn abs_diff_simd<const N: usize>(a: Simd<u8, N>, b: Simd<u8, N>) -> Simd<u8, N>
where
LaneCount<N>: SupportedLaneCount,
{
a.simd_max(b) - b.simd_min(a)
let mut out = [0; N];
for i in 0..N {
out[i] = super::filter_paeth_decode(a[i].into(), b[i].into(), c[i].into());
}

// Uses logic from `filter::filter_paeth` to calculate absolute values
// entirely in `Simd<u8, N>`. This method avoids unpacking and packing
// penalties resulting from conversion to and from `Simd<i16, N>`.
// ```
// let pa = b.max(c) - c.min(b);
// let pb = a.max(c) - c.min(a);
// let pc = if (a < c) == (c < b) {
// pa.max(pb) - pa.min(pb)
// } else {
// 255
// };
// ```
let pa = abs_diff_simd(b, c);
let pb = abs_diff_simd(a, c);
let pc = a
.simd_lt(c)
.simd_eq(c.simd_lt(b))
.select(abs_diff_simd(pa, pb), Simd::splat(255));

let smallest = pc.simd_min(pa.simd_min(pb));

// Paeth algorithm breaks ties favoring a over b over c, so we execute the following
// lane-wise selection:
//
// if smalest == pa
// then select a
// else select (if smallest == pb then select b else select c)
smallest
.simd_eq(pa)
.select(a, smallest.simd_eq(pb).select(b, c))
out.into()
}

/// Memory of previous pixels (as needed to unfilter `FilterType::Paeth`).
Expand Down Expand Up @@ -318,32 +277,44 @@ impl Default for AdaptiveFilterType {
}
}

#[cfg(target_arch = "x86_64")]
fn filter_paeth_decode(a: u8, b: u8, c: u8) -> u8 {
// Decoding seems to optimize better with this algorithm
let pa = (i16::from(b) - i16::from(c)).abs();
let pb = (i16::from(a) - i16::from(c)).abs();
let pc = ((i16::from(a) - i16::from(c)) + (i16::from(b) - i16::from(c))).abs();

let mut out = a;
let mut min = pa;

if pb < min {
min = pb;
out = b;
}
if pc < min {
out = c;
}

out
// Decoding optimizes better with this algorithm than with `filter_paeth()`
//
// This formulation looks very different from the reference in the PNG spec, but is
// actually equivalent and has favorable data dependencies and admits straightforward
// generation of branch-free code, which helps performance significantly.
//
// Adapted from public domain PNG implementation:
// https://github.com/nothings/stb/blob/5c205738c191bcb0abc65c4febfa9bd25ff35234/stb_image.h#L4657-L4668
let thresh = i16::from(c) * 3 - (i16::from(a) + i16::from(b));
let lo = a.min(b);
let hi = a.max(b);
let t0 = if hi as i16 <= thresh { lo } else { c };
let t1 = if thresh <= lo as i16 { hi } else { t0 };
return t1;
}

#[cfg(feature = "unstable")]
#[cfg(all(feature = "unstable", target_arch = "x86_64"))]
fn filter_paeth_decode_i16(a: i16, b: i16, c: i16) -> i16 {
// Like `filter_paeth_decode` but vectorizes better when wrapped in SIMD
let pa = (b - c).abs();
let pb = (a - c).abs();
let pc = ((a - c) + (b - c)).abs();
// Like `filter_paeth_decode` but vectorizes better when wrapped in SIMD types.
// Used for bpp=3 and bpp=6
let thresh = c * 3 - (a + b);
let lo = a.min(b);
let hi = a.max(b);
let t0 = if hi <= thresh { lo } else { c };
let t1 = if thresh <= lo { hi } else { t0 };
return t1;
}

#[cfg(not(target_arch = "x86_64"))]
fn filter_paeth_decode(a: u8, b: u8, c: u8) -> u8 {
// On ARM this algorithm performs much better than the one above adapted from stb,
// and this is the better-studied algorithm we've always used here,
// so we default to it on all non-x86 platforms.
let pa = (i16::from(b) - i16::from(c)).abs();
let pb = (i16::from(a) - i16::from(c)).abs();
let pc = ((i16::from(a) - i16::from(c)) + (i16::from(b) - i16::from(c))).abs();

let mut out = a;
let mut min = pa;
Expand Down Expand Up @@ -769,7 +740,8 @@ pub(crate) fn unfilter(
}
}
BytesPerPixel::Three => {
#[cfg(feature = "unstable")]
// Do not enable this algorithm on ARM, that would be a big performance hit
#[cfg(all(feature = "unstable", target_arch = "x86_64"))]
simd::unfilter_paeth3(previous, current);

#[cfg(not(feature = "unstable"))]
Expand Down Expand Up @@ -797,7 +769,7 @@ pub(crate) fn unfilter(
}
}
BytesPerPixel::Four => {
#[cfg(feature = "unstable")]
#[cfg(all(feature = "unstable", target_arch = "x86_64"))]
simd::unfilter_paeth_u8::<4>(previous, current);

#[cfg(not(feature = "unstable"))]
Expand Down Expand Up @@ -828,7 +800,7 @@ pub(crate) fn unfilter(
}
}
BytesPerPixel::Six => {
#[cfg(feature = "unstable")]
#[cfg(all(feature = "unstable", target_arch = "x86_64"))]
simd::unfilter_paeth6(previous, current);

#[cfg(not(feature = "unstable"))]
Expand Down Expand Up @@ -865,7 +837,7 @@ pub(crate) fn unfilter(
}
}
BytesPerPixel::Eight => {
#[cfg(feature = "unstable")]
#[cfg(all(feature = "unstable", target_arch = "x86_64"))]
simd::unfilter_paeth_u8::<8>(previous, current);

#[cfg(not(feature = "unstable"))]
Expand Down Expand Up @@ -1160,6 +1132,19 @@ mod test {
}
}

#[test]
#[ignore] // takes ~20s without optimizations
fn paeth_impls_are_equivalent() {
use super::{filter_paeth, filter_paeth_decode};
for a in 0..=255 {
for b in 0..=255 {
for c in 0..=255 {
assert_eq!(filter_paeth(a, b, c), filter_paeth_decode(a, b, c));
}
}
}
}

#[test]
fn roundtrip_ascending_previous_line() {
// A multiple of 8, 6, 4, 3, 2, 1
Expand Down

0 comments on commit 7e4a5a4

Please sign in to comment.