Consistently use the highest bit of vector masks when converting to i1 vectors #104693

Open · wants to merge 1 commit into master
191 changes: 103 additions & 88 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -1034,6 +1034,60 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}};
}

/// Returns the bit width of the `$ty` argument if it is an `Int` type; otherwise emits the `$diag` error.
macro_rules! require_int_ty {
($ty: expr, $diag: expr) => {
match $ty {
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
_ => {
return_error!($diag);
}
}
};
}

/// Returns the bit width of the `$ty` argument if it is an `Int` or `Uint` type; otherwise emits the `$diag` error.
macro_rules! require_int_or_uint_ty {
($ty: expr, $diag: expr) => {
match $ty {
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
ty::Uint(i) => {
i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
}
_ => {
return_error!($diag);
}
}
};
}

/// Converts a vector mask, where each element has a bit width equal to that of the data elements it is used with,
/// down to an i1-based mask that can be used by LLVM intrinsics.
///
/// Rust's SIMD semantics require each mask element to consist of either all ones or all zeroes,
/// but this information is not available to LLVM. Truncating the vector effectively uses the lowest bit,
/// but codegen for several targets is better if we instead use the highest bit, obtained by shifting.
///
/// For x86 SSE/AVX targets this is beneficial since most instructions with mask parameters only consider the highest bit.
/// So even though at the LLVM level there is an additional shift, the final assembly contains no shift or truncate and
/// the mask can be used as-is.
///
/// For aarch64 and other targets there is a benefit because a mask derived from the sign bit can be more
/// efficiently converted to an all-ones / all-zeroes mask by comparing whether each element is negative.
fn vector_mask_to_bitmask<'a, 'll, 'tcx>(
bx: &mut Builder<'a, 'll, 'tcx>,
i_xn: &'ll Value,
in_elem_bitwidth: u64,
in_len: u64,
) -> &'ll Value {
// Shift the MSB to the right by "in_elem_bitwidth - 1" into the first bit position.
let shift_idx = bx.cx.const_int(bx.type_ix(in_elem_bitwidth), (in_elem_bitwidth - 1) as _);
let shift_indices = vec![shift_idx; in_len as _];
let i_xn_msb = bx.lshr(i_xn, bx.const_vector(shift_indices.as_slice()));
// Truncate vector to an <i1 x N>
bx.trunc(i_xn_msb, bx.type_vector(bx.type_i1(), in_len))
}

let tcx = bx.tcx();
let sig =
tcx.normalize_erasing_late_bound_regions(ty::ParamEnv::reveal_all(), callee_ty.fn_sig(tcx));
@@ -1294,14 +1348,11 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
m_len == v_len,
InvalidMonomorphization::MismatchedLengths { span, name, m_len, v_len }
);
match m_elem_ty.kind() {
ty::Int(_) => {}
_ => return_error!(InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }),
}
// truncate the mask to a vector of i1s
let i1 = bx.type_i1();
let i1xn = bx.type_vector(i1, m_len as u64);
let m_i1s = bx.trunc(args[0].immediate(), i1xn);
let in_elem_bitwidth = require_int_ty!(
m_elem_ty.kind(),
InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }
);
let m_i1s = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, m_len);
return Ok(bx.select(m_i1s, args[1].immediate(), args[2].immediate()));
}

@@ -1319,32 +1370,12 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
let expected_bytes = expected_int_bits / 8 + ((expected_int_bits % 8 > 0) as u64);

// Integer vector <i{in_bitwidth} x in_len>:
let (i_xn, in_elem_bitwidth) = match in_elem.kind() {
ty::Int(i) => (
args[0].immediate(),
i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
),
ty::Uint(i) => (
args[0].immediate(),
i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
),
_ => return_error!(InvalidMonomorphization::VectorArgument {
span,
name,
in_ty,
in_elem
}),
};
let in_elem_bitwidth = require_int_or_uint_ty!(
in_elem.kind(),
InvalidMonomorphization::VectorArgument { span, name, in_ty, in_elem }
);

// Shift the MSB to the right by "in_elem_bitwidth - 1" into the first bit position.
let shift_indices =
vec![
bx.cx.const_int(bx.type_ix(in_elem_bitwidth), (in_elem_bitwidth - 1) as _);
in_len as _
];
let i_xn_msb = bx.lshr(i_xn, bx.const_vector(shift_indices.as_slice()));
// Truncate vector to an <i1 x N>
let i1xn = bx.trunc(i_xn_msb, bx.type_vector(bx.type_i1(), in_len));
let i1xn = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, in_len);
// Bitcast <i1 x N> to iN:
let i_ = bx.bitcast(i1xn, bx.type_ix(in_len));

@@ -1562,28 +1593,23 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}
);

match element_ty2.kind() {
ty::Int(_) => (),
_ => {
return_error!(InvalidMonomorphization::ThirdArgElementType {
span,
name,
expected_element: element_ty2,
third_arg: arg_tys[2]
});
let mask_elem_bitwidth = require_int_ty!(
element_ty2.kind(),
InvalidMonomorphization::ThirdArgElementType {
span,
name,
expected_element: element_ty2,
third_arg: arg_tys[2]
}
}
);

// Alignment of T, must be a constant integer value:
let alignment_ty = bx.type_i32();
let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);

// Truncate the mask vector to a vector of i1s:
let (mask, mask_ty) = {
let i1 = bx.type_i1();
let i1xn = bx.type_vector(i1, in_len);
(bx.trunc(args[2].immediate(), i1xn), i1xn)
};
let mask = vector_mask_to_bitmask(bx, args[2].immediate(), mask_elem_bitwidth, in_len);
let mask_ty = bx.type_vector(bx.type_i1(), in_len);

// Type of the vector of pointers:
let llvm_pointer_vec_ty = llvm_vector_ty(bx, element_ty1, in_len);
@@ -1668,8 +1694,8 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}
);

require!(
matches!(mask_elem.kind(), ty::Int(_)),
let m_elem_bitwidth = require_int_ty!(
mask_elem.kind(),
InvalidMonomorphization::ThirdArgElementType {
Review comment (Contributor Author): I think this should refer to the first argument, but there is a ui test for it and fixing it probably belongs in a separate PR

span,
name,
@@ -1678,17 +1704,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}
);

let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
let mask_ty = bx.type_vector(bx.type_i1(), mask_len);

// Alignment of T, must be a constant integer value:
let alignment_ty = bx.type_i32();
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);

// Truncate the mask vector to a vector of i1s:
let (mask, mask_ty) = {
let i1 = bx.type_i1();
let i1xn = bx.type_vector(i1, mask_len);
(bx.trunc(args[0].immediate(), i1xn), i1xn)
};

let llvm_pointer = bx.type_ptr();

// Type of the vector of elements:
@@ -1760,8 +1782,8 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}
);

require!(
matches!(mask_elem.kind(), ty::Int(_)),
let m_elem_bitwidth = require_int_ty!(
mask_elem.kind(),
InvalidMonomorphization::ThirdArgElementType {
span,
name,
@@ -1770,17 +1792,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}
);

let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
let mask_ty = bx.type_vector(bx.type_i1(), mask_len);

// Alignment of T, must be a constant integer value:
let alignment_ty = bx.type_i32();
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);

// Truncate the mask vector to a vector of i1s:
let (mask, mask_ty) = {
let i1 = bx.type_i1();
let i1xn = bx.type_vector(i1, in_len);
(bx.trunc(args[0].immediate(), i1xn), i1xn)
};

let ret_t = bx.type_void();

let llvm_pointer = bx.type_ptr();
@@ -1859,28 +1877,23 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
);

// The element type of the third argument must be a signed integer type of any width:
match element_ty2.kind() {
ty::Int(_) => (),
_ => {
return_error!(InvalidMonomorphization::ThirdArgElementType {
span,
name,
expected_element: element_ty2,
third_arg: arg_tys[2]
});
let mask_elem_bitwidth = require_int_ty!(
element_ty2.kind(),
InvalidMonomorphization::ThirdArgElementType {
span,
name,
expected_element: element_ty2,
third_arg: arg_tys[2]
}
}
);

// Alignment of T, must be a constant integer value:
let alignment_ty = bx.type_i32();
let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);

// Truncate the mask vector to a vector of i1s:
let (mask, mask_ty) = {
let i1 = bx.type_i1();
let i1xn = bx.type_vector(i1, in_len);
(bx.trunc(args[2].immediate(), i1xn), i1xn)
};
let mask = vector_mask_to_bitmask(bx, args[2].immediate(), mask_elem_bitwidth, in_len);
let mask_ty = bx.type_vector(bx.type_i1(), in_len);

let ret_t = bx.type_void();

@@ -2018,8 +2031,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
);
args[0].immediate()
} else {
match in_elem.kind() {
ty::Int(_) | ty::Uint(_) => {}
let bitwidth = match in_elem.kind() {
ty::Int(i) => {
i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
}
ty::Uint(i) => {
i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
}
_ => return_error!(InvalidMonomorphization::UnsupportedSymbol {
span,
name,
@@ -2028,12 +2046,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
in_elem,
ret_ty
}),
}
};

// boolean reductions operate on vectors of i1s:
let i1 = bx.type_i1();
let i1xn = bx.type_vector(i1, in_len as u64);
bx.trunc(args[0].immediate(), i1xn)
vector_mask_to_bitmask(bx, args[0].immediate(), bitwidth, in_len as _)
};
return match in_elem.kind() {
ty::Int(_) | ty::Uint(_) => {
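The new `vector_mask_to_bitmask` helper above shifts each lane's sign bit down into bit 0 and then truncates to an `<i1 x N>` vector. As a rough per-lane illustration of why this agrees with plain truncation for well-formed Rust SIMD masks (whose lanes are all ones or all zeroes), here is a minimal standalone Rust sketch; the `lane_to_i1` function and the `i32` lane type are illustrative assumptions, not code from this PR:

// Minimal sketch (illustrative only): a scalar, per-lane view of the
// sign-bit-based mask conversion, assuming an i32 mask lane.
fn lane_to_i1(lane: i32) -> bool {
    let bitwidth = i32::BITS; // plays the role of `in_elem_bitwidth`
    // Logical shift right by (bitwidth - 1) moves the sign bit into bit 0,
    // mirroring the `lshr` emitted by `vector_mask_to_bitmask`.
    let msb = (lane as u32) >> (bitwidth - 1);
    // Truncating to i1 keeps only bit 0.
    (msb & 1) == 1
}

fn main() {
    // A well-formed mask lane is either all ones (-1) or all zeroes (0),
    // so the sign bit and the lowest bit agree.
    assert!(lane_to_i1(-1));
    assert!(!lane_to_i1(0));
    // For a malformed lane the sign bit decides, which is what x86 mask
    // instructions read anyway.
    assert!(lane_to_i1(i32::MIN));
}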
4 changes: 2 additions & 2 deletions tests/assembly/simd-intrinsic-gather.rs
@@ -36,8 +36,8 @@ pub unsafe extern "C" fn gather_f64x4(mask: m64x4, ptrs: pf64x4) -> f64x4 {
// FIXME: This should also be checked to ensure it generates a gather instruction for avx2.
// Currently llvm scalarizes this code, see https://github.com/llvm/llvm-project/issues/59789
//
// x86-avx512: vpsllq ymm0, ymm0, 63
// x86-avx512-NEXT: vpmovq2m k1, ymm0
// x86-avx512-NOT: vpsllq
// x86-avx512: vpmovq2m k1, ymm0
// x86-avx512-NEXT: vpxor xmm0, xmm0, xmm0
// x86-avx512-NEXT: vgatherqpd ymm0 {k1}, ymmword ptr [1*ymm1]
simd_gather(f64x4([0_f64, 0_f64, 0_f64, 0_f64]), ptrs, mask)
27 changes: 13 additions & 14 deletions tests/assembly/simd-intrinsic-mask-load.rs
@@ -46,9 +46,9 @@ extern "rust-intrinsic" {
pub unsafe extern "C" fn load_i8x16(mask: m8x16, pointer: *const i8) -> i8x16 {
// Since avx2 supports no masked loads for bytes, the code tests each individual bit
// and jumps to code that inserts individual bytes.
// x86-avx2: vpsllw xmm0, xmm0, 7
// x86-avx2-NEXT: vpmovmskb eax, xmm0
// x86-avx2-NEXT: vpxor xmm0, xmm0
// x86-avx2-NOT: vpsllw
// x86-avx2-DAG: vpmovmskb eax
// x86-avx2-DAG: vpxor
// x86-avx2-NEXT: test al, 1
// x86-avx2-NEXT: jne
// x86-avx2-NEXT: test al, 2
@@ -57,32 +57,31 @@ pub unsafe extern "C" fn load_i8x16(mask: m8x16, pointer: *const i8) -> i8x16 {
// x86-avx2-NEXT: vmovd xmm0, [[REG]]
// x86-avx2-DAG: vpinsrb xmm0, xmm0, byte ptr [rdi + 1], 1
//
// x86-avx512: vpsllw xmm0, xmm0, 7
// x86-avx512-NEXT: vpmovb2m k1, xmm0
// x86-avx512-NOT: vpsllw
// x86-avx512: vpmovb2m k1, xmm0
// x86-avx512-NEXT: vmovdqu8 xmm0 {k1} {z}, xmmword ptr [rdi]
simd_masked_load(mask, pointer, i8x16([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
}

// CHECK-LABEL: load_f32x8
#[no_mangle]
pub unsafe extern "C" fn load_f32x8(mask: m32x8, pointer: *const f32) -> f32x8 {
// x86-avx2: vpslld ymm0, ymm0, 31
// x86-avx2-NEXT: vmaskmovps ymm0, ymm0, ymmword ptr [rdi]
// x86-avx2-NOT: vpslld
// x86-avx2: vmaskmovps ymm0, ymm0, ymmword ptr [rdi]
//
// x86-avx512: vpslld ymm0, ymm0, 31
// x86-avx512-NEXT: vpmovd2m k1, ymm0
// x86-avx512-NOT: vpslld
// x86-avx512: vpmovd2m k1, ymm0
// x86-avx512-NEXT: vmovups ymm0 {k1} {z}, ymmword ptr [rdi]
simd_masked_load(mask, pointer, f32x8([0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32]))
}

// CHECK-LABEL: load_f64x4
#[no_mangle]
pub unsafe extern "C" fn load_f64x4(mask: m64x4, pointer: *const f64) -> f64x4 {
// x86-avx2: vpsllq ymm0, ymm0, 63
// x86-avx2-NEXT: vmaskmovpd ymm0, ymm0, ymmword ptr [rdi]
// x86-avx2-NOT: vpsllq
// x86-avx2: vmaskmovpd ymm0, ymm0, ymmword ptr [rdi]
//
// x86-avx512: vpsllq ymm0, ymm0, 63
// x86-avx512-NEXT: vpmovq2m k1, ymm0
// x86-avx512-NEXT: vmovupd ymm0 {k1} {z}, ymmword ptr [rdi]
// x86-avx512-NOT: vpsllq
// x86-avx512: vpmovq2m k1, ymm0
simd_masked_load(mask, pointer, f64x4([0_f64, 0_f64, 0_f64, 0_f64]))
}