Skip to content

Commit

Permalink
[mono][jit] Adding support for Vector128::ExtractMostSignificantBits …
Browse files Browse the repository at this point in the history
…intrinsic on ARM64 miniJIT
  • Loading branch information
ivanpovazan committed Apr 13, 2023
1 parent 11a0671 commit 01e7f7a
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/mono/mono/arch/arm64/arm64-codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -1804,6 +1804,7 @@ arm_encode_arith_imm (int imm, guint32 *shift)
#define arm_neon_cmhi(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b1, (type), 0b00110, (rd), (rn), (rm))
#define arm_neon_cmhs(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b1, (type), 0b00111, (rd), (rn), (rm))
#define arm_neon_addp(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b0, (type), 0b10111, (rd), (rn), (rm))
// USHL (vector, unsigned): shifts each lane of rn by the signed per-lane amount in rm;
// a negative shift amount shifts that lane right instead of left.
#define arm_neon_ushl(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b1, (type), 0b01000, (rd), (rn), (rm))

// Generalized macros for float ops:
// width - determines if full register or its lower half is used one of {VREG_LOW, VREG_FULL}
Expand Down
2 changes: 2 additions & 0 deletions src/mono/mono/mini/cpu-arm64.mdesc
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,8 @@ create_scalar_unsafe_int: dest:x src1:i len:4
create_scalar_unsafe_float: dest:x src1:f len:4
arm64_bic: dest:x src1:x src2:x len:4
bitwise_select: dest:x src1:x src2:x src3:x len:12
arm64_ushl: dest:x src1:x src2:x len:4
arm64_ext_imm: dest:x src1:x src2:x len:4

generic_class_init: src1:a len:44 clob:c
gc_safe_point: src1:i len:12 clob:c
Expand Down
14 changes: 12 additions & 2 deletions src/mono/mono/mini/mini-arm64.c
Original file line number Diff line number Diff line change
Expand Up @@ -3920,8 +3920,18 @@ mono_arch_output_basic_block (MonoCompile *cfg, MonoBasicBlock *bb)
// case OP_XCONCAT:
// arm_neon_ext_16b(code, dreg, sreg1, sreg2, 8);
// break;

/* BRANCH */
case OP_ARM64_USHL: {
arm_neon_ushl (code, get_vector_size_macro (ins), get_type_size_macro (ins->inst_c1), dreg, sreg1, sreg2);
break;
}
case OP_ARM64_EXT_IMM: {
if (get_vector_size_macro (ins) == VREG_LOW)
arm_neon_ext_8b (code, dreg, sreg1, sreg2, ins->inst_c0);
else
arm_neon_ext_16b (code, dreg, sreg1, sreg2, ins->inst_c0);
break;
}
/* BRANCH */
case OP_BR:
mono_add_patch_info_rel (cfg, offset, MONO_PATCH_INFO_BB, ins->inst_target_bb, MONO_R_ARM64_B);
arm_b (code, code);
Expand Down
3 changes: 3 additions & 0 deletions src/mono/mono/mini/mini-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1759,6 +1759,7 @@ MINI_OP(OP_ARM64_ABSCOMPARE, "arm64_abscompare", XREG, XREG, XREG)
MINI_OP(OP_ARM64_XNARROW_SCALAR, "arm64_xnarrow_scalar", XREG, XREG, NONE)

MINI_OP3(OP_ARM64_EXT, "arm64_ext", XREG, XREG, XREG, IREG)
MINI_OP(OP_ARM64_EXT_IMM, "arm64_ext_imm", XREG, XREG, XREG)

MINI_OP3(OP_ARM64_SQRDMLAH, "arm64_sqrdmlah", XREG, XREG, XREG, XREG)
MINI_OP3(OP_ARM64_SQRDMLAH_BYSCALAR, "arm64_sqrdmlah_byscalar", XREG, XREG, XREG, XREG)
Expand All @@ -1775,6 +1776,8 @@ MINI_OP3(OP_ARM64_SQRDMLSH_SCALAR, "arm64_sqrdmlsh_scalar", XREG, XREG, XREG, XR
MINI_OP(OP_ARM64_TBL_INDIRECT, "arm64_tbl_indirect", XREG, IREG, XREG)
MINI_OP3(OP_ARM64_TBX_INDIRECT, "arm64_tbx_indirect", XREG, IREG, XREG, XREG)

MINI_OP(OP_ARM64_USHL, "arm64_ushl", XREG, XREG, XREG)

#endif // TARGET_ARM64

MINI_OP(OP_SIMD_FCVTL, "simd_convert_to_higher_precision", XREG, XREG, NONE)
Expand Down
150 changes: 148 additions & 2 deletions src/mono/mono/mini/simd-intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,96 @@ is_element_type_primitive (MonoType *vector_type)
}
}

/*
 * emit_msb_vector_mask:
 *
 *   Emit a 128-bit constant vector whose every lane of the element type
 * ARG_TYPE has only its most-significant bit set (e.g. 0x80 per byte lane,
 * 0x8000 per 16-bit lane, ...). Used to isolate the sign/MSB of each element
 * before computing ExtractMostSignificantBits. Asserts on non-primitive
 * element types.
 */
static MonoInst*
emit_msb_vector_mask (MonoCompile *cfg, MonoClass *arg_class, MonoTypeEnum arg_type)
{
	guint64 lane_bits;

	switch (arg_type) {
	case MONO_TYPE_I1:
	case MONO_TYPE_U1:
		lane_bits = 0x8080808080808080;
		break;
	case MONO_TYPE_I2:
	case MONO_TYPE_U2:
		lane_bits = 0x8000800080008000;
		break;
#if TARGET_SIZEOF_VOID_P == 4
	case MONO_TYPE_I:
	case MONO_TYPE_U:
#endif
	case MONO_TYPE_I4:
	case MONO_TYPE_U4:
	case MONO_TYPE_R4:
		lane_bits = 0x8000000080000000;
		break;
#if TARGET_SIZEOF_VOID_P == 8
	case MONO_TYPE_I:
	case MONO_TYPE_U:
#endif
	case MONO_TYPE_I8:
	case MONO_TYPE_U8:
	case MONO_TYPE_R8:
		lane_bits = 0x8000000000000000;
		break;
	default:
		g_assert_not_reached ();
	}

	// Both 64-bit halves of the mask carry the same lane pattern.
	guint64 mask_bytes [2] = { lane_bits, lane_bits };
	MonoInst *mask_vec = emit_xconst_v128 (cfg, arg_class, (guint8*)mask_bytes);
	mask_vec->klass = arg_class;
	return mask_vec;
}

/*
 * emit_msb_shift_vector_constant:
 *
 *   Emit the 128-bit constant of per-lane shift amounts used with USHL to
 * move each element's isolated MSB down to the bit position matching the
 * element's index. The amounts are stored as signed lane values; negative
 * amounts make USHL shift right (see the caller's note on ARM64 USHL
 * semantics). For byte lanes the amounts are relative to each 64-bit half.
 * Asserts on non-primitive element types.
 */
static MonoInst*
emit_msb_shift_vector_constant (MonoCompile *cfg, MonoClass *arg_class, MonoTypeEnum arg_type)
{
	guint64 lo, hi;

	switch (arg_type) {
	case MONO_TYPE_I1:
	case MONO_TYPE_U1:
		// Shift amounts -7..0 per byte, repeated for each 64-bit half.
		lo = hi = 0x00FFFEFDFCFBFAF9;
		break;
	case MONO_TYPE_I2:
	case MONO_TYPE_U2:
		lo = 0xFFF4FFF3FFF2FFF1;
		hi = 0xFFF8FFF7FFF6FFF5;
		break;
#if TARGET_SIZEOF_VOID_P == 4
	case MONO_TYPE_I:
	case MONO_TYPE_U:
#endif
	case MONO_TYPE_I4:
	case MONO_TYPE_U4:
	case MONO_TYPE_R4:
		lo = 0xFFFFFFE2FFFFFFE1;
		hi = 0xFFFFFFE4FFFFFFE3;
		break;
#if TARGET_SIZEOF_VOID_P == 8
	case MONO_TYPE_I:
	case MONO_TYPE_U:
#endif
	case MONO_TYPE_I8:
	case MONO_TYPE_U8:
	case MONO_TYPE_R8:
		lo = 0xFFFFFFFFFFFFFFC1;
		hi = 0xFFFFFFFFFFFFFFC2;
		break;
	default:
		g_assert_not_reached ();
	}

	guint64 shift_bytes [2] = { lo, hi };
	MonoInst *shift_vec = emit_xconst_v128 (cfg, arg_class, (guint8*)shift_bytes);
	shift_vec->klass = arg_class;
	return shift_vec;
}

static MonoInst*
emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsig, MonoInst **args)
{
Expand Down Expand Up @@ -1234,7 +1324,6 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
case SN_ConvertToUInt64:
case SN_Create:
case SN_Dot:
case SN_ExtractMostSignificantBits:
case SN_GetElement:
case SN_GetLower:
case SN_GetUpper:
Expand Down Expand Up @@ -1542,7 +1631,64 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
return NULL;
#ifdef TARGET_WASM
return emit_simd_ins_for_sig (cfg, klass, OP_WASM_SIMD_BITMASK, -1, -1, fsig, args);
#else
#elif defined(TARGET_ARM64)
MonoInst* result_ins = NULL;
MonoClass* arg_class = mono_class_from_mono_type_internal (fsig->params [0]);
int size = mono_class_value_size (arg_class, NULL);
// FIXME: Support Vector64
if (size != 16)
return NULL;

// On arm64, there is no single instruction to perform the operation.
// To emulate the behavior we implement a similar approach to coreCLR by performing the following set of operations:
// The general case (all element types except byte/sbyte):
// 1. Extract the most significant bit of each element of the source vector by anding elements with MSB mask
// 2. Shift each element to the right by its index
// 3. Sum all the element values together to get the final result of ExtractMSB
// Handling byte/sbyte element types:
// 1. Extract the most significant bit of each element of the source vector by anding elements with MSB mask
// 2. Shift each element to the right by its relative index to the high/low (64bit) boundary.
// - This basically means we treat the source 128bit, as two 64bit (high and low) vectors and shift their elements respectively
// 3. Sum the low 64bits of the shifted vector
// 4. Sum the high 64bits of the shifted vector
// 5. Shift left the result of 4) by 8, to adjust the value
// 6. OR the results from 3) and 5) to get the final result of ExtractMSB

MonoInst* msb_mask_vec = emit_msb_vector_mask (cfg, arg_class, arg0_type);
MonoInst* and_res_vec = emit_simd_ins_for_binary_op (cfg, arg_class, fsig, args, arg0_type, SN_BitwiseAnd);
and_res_vec->sreg2 = msb_mask_vec->dreg;

// NOTE: (on ARM64 ushl shifts a vector left or right depending on the sign of the shift constant)
MonoInst* msb_shift_vec = emit_msb_shift_vector_constant (cfg, arg_class, arg0_type);
MonoInst* shift_res_vec = emit_simd_ins (cfg, arg_class, OP_ARM64_USHL, and_res_vec->dreg, msb_shift_vec->dreg);
shift_res_vec->inst_c1 = arg0_type;

if (arg0_type == MONO_TYPE_I1 || arg0_type == MONO_TYPE_U1) {
// Always perform unsigned operations as vector sum and extract operations could sign-extend the result into the GP register
// making the final result invalid. This is not needed for wider types as the maximum sum of extracted MSBs cannot be larger than 8 bits
arg0_type = MONO_TYPE_U1;

// In order to sum the high and low 64 bits of the shifted vector separately, we use a zeroed vector and the extract operation
MonoInst* zero_vec = emit_xzero(cfg, arg_class);

MonoInst* ext_low_vec = emit_simd_ins (cfg, arg_class, OP_ARM64_EXT_IMM, zero_vec->dreg, shift_res_vec->dreg);
ext_low_vec->inst_c0 = 8;
ext_low_vec->inst_c1 = arg0_type;
MonoInst* sum_low_vec = emit_sum_vector (cfg, fsig->params [0], arg0_type, ext_low_vec);

MonoInst* ext_high_vec = emit_simd_ins (cfg, arg_class, OP_ARM64_EXT_IMM, shift_res_vec->dreg, zero_vec->dreg);
ext_high_vec->inst_c0 = 8;
ext_high_vec->inst_c1 = arg0_type;
MonoInst* sum_high_vec = emit_sum_vector (cfg, fsig->params [0], arg0_type, ext_high_vec);

MONO_EMIT_NEW_BIALU_IMM (cfg, OP_SHL_IMM, sum_high_vec->dreg, sum_high_vec->dreg, 8);
EMIT_NEW_BIALU (cfg, result_ins, OP_IOR, sum_high_vec->dreg, sum_high_vec->dreg, sum_low_vec->dreg);
} else {
result_ins = emit_sum_vector (cfg, fsig->params [0], arg0_type, shift_res_vec);
}
return result_ins;
#elif defined(TARGET_AMD64)
// FIXME: Support ExtractMostSignificantBits on AMD64
return NULL;
#endif
}
Expand Down

0 comments on commit 01e7f7a

Please sign in to comment.