#[cfg(test)] const MAX_COMPARE_SIZE: usize = 256; pub fn compare256_slice(src0: &[u8], src1: &[u8]) -> usize { let src0 = first_chunk::<_, 256>(src0).unwrap(); let src1 = first_chunk::<_, 256>(src1).unwrap(); compare256(src0, src1) } fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize { #[cfg(feature = "avx512")] #[cfg(target_arch = "x86_64")] if cfg!(target_feature = "avx512vl") && cfg!(target_feature = "avx512bw") { return unsafe { avx512::compare256(src0, src1) }; } #[cfg(target_arch = "x86_64")] if crate::cpu_features::is_enabled_avx2_and_bmi2() { return unsafe { avx2::compare256(src0, src1) }; } #[cfg(target_arch = "aarch64")] if crate::cpu_features::is_enabled_neon() { return unsafe { neon::compare256(src0, src1) }; } #[cfg(target_arch = "wasm32")] if crate::cpu_features::is_enabled_simd128() { return wasm32::compare256(src0, src1); } rust::compare256(src0, src1) } pub fn compare256_rle_slice(byte: u8, src: &[u8]) -> usize { rust::compare256_rle(byte, src) } #[inline] pub const fn first_chunk(slice: &[T]) -> Option<&[T; N]> { if slice.len() < N { None } else { // SAFETY: We explicitly check for the correct number of elements, // and do not let the reference outlive the slice. Some(unsafe { &*(slice.as_ptr() as *const [T; N]) }) } } mod rust { pub fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize { // only unrolls 4 iterations; zlib-ng unrolls 8 src0.iter().zip(src1).take_while(|(x, y)| x == y).count() } // run-length encoding pub fn compare256_rle(byte: u8, src: &[u8]) -> usize { assert!(src.len() >= 256, "too short {}", src.len()); let sv = u64::from_ne_bytes([byte; 8]); let mut len = 0; // this optimizes well because we statically limit the slice to 256 bytes. // the loop gets unrolled 4 times automatically. for chunk in src[..256].chunks_exact(8) { let mv = u64::from_le_bytes(chunk.try_into().unwrap()); let diff = sv ^ mv; if diff > 0 { let match_byte = diff.trailing_zeros() / 8; return len + match_byte as usize; } len += 8 } 256 } #[test] fn test_compare256() { let str1 = [b'a'; super::MAX_COMPARE_SIZE]; let mut str2 = [b'a'; super::MAX_COMPARE_SIZE]; for i in 0..str1.len() { str2[i] = 0; let match_len = compare256(&str1, &str2); assert_eq!(match_len, i); str2[i] = b'a'; } } #[test] fn test_compare256_rle() { let mut string = [b'a'; super::MAX_COMPARE_SIZE]; for i in 0..string.len() { string[i] = 0; let match_len = compare256_rle(b'a', &string); assert_eq!(match_len, i); string[i] = b'a'; } } } #[cfg(target_arch = "aarch64")] mod neon { use core::arch::aarch64::{ uint8x16x4_t, vceqq_u8, vget_lane_u64, vld4q_u8, vreinterpret_u64_u8, vreinterpretq_u16_u8, vshrn_n_u16, vsriq_n_u8, }; /// # Safety /// /// Behavior is undefined if the `neon` target feature is not enabled #[target_feature(enable = "neon")] pub unsafe fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize { type Chunk = uint8x16x4_t; let src0 = src0.chunks_exact(core::mem::size_of::()); let src1 = src1.chunks_exact(core::mem::size_of::()); let mut len = 0; for (a, b) in src0.zip(src1) { unsafe { // Load 4 vectors *deinterleaved* from the two slices // e.g. the first vector contains the 0, 4, 8, ... bytes of the input, the // second vector contains the 1, 5, 9, ... bytes of the input, etc. let a: Chunk = vld4q_u8(a.as_ptr()); let b: Chunk = vld4q_u8(b.as_ptr()); // Compare each vector element-wise, each resulting vector will contain // 0xFF for equal bytes, and 0x00 for unequal bytes. let cmp0 = vceqq_u8(a.0, b.0); let cmp1 = vceqq_u8(a.1, b.1); let cmp2 = vceqq_u8(a.2, b.2); let cmp3 = vceqq_u8(a.3, b.3); // Pack bits from the 4 vectors into a single vector to convert to a 64-bit integer. // shift the second vector right by one, insert the top bit from the first vector // The top two bits each element of the result are from the first and second vector let first_two_bits = vsriq_n_u8::<1>(cmp1, cmp0); // shift the fourth vector right by one, insert the top bit from the third vector // The top two bits each element of the result are from the third and fourth vector let last_two_bits = vsriq_n_u8::<1>(cmp3, cmp2); // shift last_two_bits (the top two bits of which are from the third and fourth // vector) right by 2, insert the top two bits from first_two_bits (the top two // bits of which are from the first and second vector). // The top four bits of each element of the result are from the // first, second, third, and fourth vector let first_four_bits = vsriq_n_u8::<2>(last_two_bits, first_two_bits); // duplicate the top 4 bits into the bottom 4 bits of each element. let bitmask_vector = vsriq_n_u8::<4>(first_four_bits, first_four_bits); // Reinterpret as 16-bit integers, and shift right by 4 bits narrowing: // shifting right by 4 bits means the top 4 bits of each 16 bit element contains the // low 4 bits of the 0th 8-bit element and the high 4 bits of the 1nth 8-bit // element. Narrowing takes the top 8 bits of each (16-bit) element. let result_vector = vshrn_n_u16::<4>(vreinterpretq_u16_u8(bitmask_vector)); // Convert the vector to a 64-bit integer, where each bit represents whether // the corresponding byte in the original vectors was equal. let bitmask = vget_lane_u64::<0>(vreinterpret_u64_u8(result_vector)); // We reinterpreted the vector as a 64-bit integer, so endianness matters. // We want things to be in little-endian (where the least significant bit is in the // first byte), but in big-endian, the first vector element will be the most // significant byte, so we need to convert to little-endian. let bitmask = bitmask.to_le(); if bitmask != u64::MAX { // Find the first byte that is not equal, which is the first bit that is not set let match_byte = bitmask.trailing_ones(); return len + match_byte as usize; } len += core::mem::size_of::(); } } 256 } #[test] fn test_compare256() { if crate::cpu_features::is_enabled_neon() { let str1 = [b'a'; super::MAX_COMPARE_SIZE]; let mut str2 = [b'a'; super::MAX_COMPARE_SIZE]; for i in 0..str1.len() { str2[i] = 0; let match_len = unsafe { compare256(&str1, &str2) }; assert_eq!(match_len, i); str2[i] = b'a'; } } } } #[cfg(target_arch = "x86_64")] mod avx2 { use core::arch::x86_64::{ __m256i, _mm256_cmpeq_epi8, _mm256_loadu_si256, _mm256_movemask_epi8, }; /// # Safety /// /// Behavior is undefined if the `avx` target feature is not enabled #[target_feature(enable = "avx2")] #[target_feature(enable = "bmi2")] #[target_feature(enable = "bmi1")] pub unsafe fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize { let src0 = src0.chunks_exact(32); let src1 = src1.chunks_exact(32); let mut len = 0; unsafe { for (chunk0, chunk1) in src0.zip(src1) { let ymm_src0 = _mm256_loadu_si256(chunk0.as_ptr() as *const __m256i); let ymm_src1 = _mm256_loadu_si256(chunk1.as_ptr() as *const __m256i); // element-wise compare of the 8-bit elements let ymm_cmp = _mm256_cmpeq_epi8(ymm_src0, ymm_src1); // turn an 32 * 8-bit vector into a 32-bit integer. // a bit in the output is one if the corresponding element is non-zero. let mask = _mm256_movemask_epi8(ymm_cmp) as u32; if mask != 0xFFFFFFFF { let match_byte = mask.trailing_ones(); return len + match_byte as usize; } len += 32; } } 256 } #[test] fn test_compare256() { if crate::cpu_features::is_enabled_avx2_and_bmi2() { let str1 = [b'a'; super::MAX_COMPARE_SIZE]; let mut str2 = [b'a'; super::MAX_COMPARE_SIZE]; for i in 0..str1.len() { str2[i] = 0; let match_len = unsafe { compare256(&str1, &str2) }; assert_eq!(match_len, i); str2[i] = b'a'; } } } } #[cfg(feature = "avx512")] #[cfg(target_arch = "x86_64")] mod avx512 { use core::arch::x86_64::{ _mm512_cmpeq_epu8_mask, _mm512_loadu_si512, _mm_cmpeq_epu8_mask, _mm_loadu_si128, }; /// # Safety /// /// Behavior is undefined if the `avx` target feature is not enabled #[target_feature(enable = "avx512vl")] #[target_feature(enable = "avx512bw")] pub unsafe fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize { // First do a 16byte round before increasing to 64bytes, this reduces the // penalty for the short matches, and those are usually the most common ones. // This requires us to overlap on the last round, giving a small penalty // on matches of 192+ bytes (Still faster than AVX2 though). unsafe { // 16 bytes let xmm_src0_0 = _mm_loadu_si128(src0.as_ptr().cast()); let xmm_src1_0 = _mm_loadu_si128(src1.as_ptr().cast()); let mask_0 = u32::from(_mm_cmpeq_epu8_mask(xmm_src0_0, xmm_src1_0)); // zero-extended to use __builtin_ctz if mask_0 != 0x0000FFFF { // There is potential for using __builtin_ctzg/__builtin_ctzs/_tzcnt_u16/__tzcnt_u16 here let match_byte = mask_0.trailing_ones(); return match_byte as usize; } // 64 bytes let zmm_src0_1 = _mm512_loadu_si512(src0[16..].as_ptr().cast()); let zmm_src1_1 = _mm512_loadu_si512(src1[16..].as_ptr().cast()); let mask_1 = _mm512_cmpeq_epu8_mask(zmm_src0_1, zmm_src1_1); if mask_1 != 0xFFFFFFFFFFFFFFFF { let match_byte = mask_1.trailing_ones(); return 16 + match_byte as usize; } // 64 bytes let zmm_src0_2 = _mm512_loadu_si512(src0[80..].as_ptr().cast()); let zmm_src1_2 = _mm512_loadu_si512(src1[80..].as_ptr().cast()); let mask_2 = _mm512_cmpeq_epu8_mask(zmm_src0_2, zmm_src1_2); if mask_2 != 0xFFFFFFFFFFFFFFFF { let match_byte = mask_2.trailing_ones(); return 80 + match_byte as usize; } // 64 bytes let zmm_src0_3 = _mm512_loadu_si512(src0[144..].as_ptr().cast()); let zmm_src1_3 = _mm512_loadu_si512(src1[144..].as_ptr().cast()); let mask_3 = _mm512_cmpeq_epu8_mask(zmm_src0_3, zmm_src1_3); if mask_3 != 0xFFFFFFFFFFFFFFFF { let match_byte = mask_3.trailing_ones(); return 144 + match_byte as usize; } // 64 bytes (overlaps the previous 16 bytes for fast tail processing) let zmm_src0_4 = _mm512_loadu_si512(src0[192..].as_ptr().cast()); let zmm_src1_4 = _mm512_loadu_si512(src1[192..].as_ptr().cast()); let mask_4 = _mm512_cmpeq_epu8_mask(zmm_src0_4, zmm_src1_4); if mask_4 != 0xFFFFFFFFFFFFFFFF { let match_byte = mask_4.trailing_ones(); return 192 + match_byte as usize; } } 256 } #[test] fn test_compare256() { if true { let str1 = [b'a'; super::MAX_COMPARE_SIZE]; let mut str2 = [b'a'; super::MAX_COMPARE_SIZE]; for i in 0..str1.len() { str2[i] = 0; let match_len = unsafe { compare256(&str1, &str2) }; assert_eq!(match_len, i); str2[i] = b'a'; } } } } #[cfg(target_arch = "wasm32")] mod wasm32 { use core::arch::wasm32::{u8x16_bitmask, u8x16_eq, v128, v128_load}; #[target_feature(enable = "simd128")] pub fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize { let src0 = src0.chunks_exact(16); let src1 = src1.chunks_exact(16); let mut len = 0; for (chunk0, chunk1) in src0.zip(src1) { // SAFETY: these are valid pointers to slice data. let v128_src0 = unsafe { v128_load(chunk0.as_ptr() as *const v128) }; let v128_src1 = unsafe { v128_load(chunk1.as_ptr() as *const v128) }; let v128_cmp = u8x16_eq(v128_src0, v128_src1); let mask = u8x16_bitmask(v128_cmp); if mask != 0xFFFF { let match_byte = mask.trailing_ones(); return len + match_byte as usize; } len += 16; } 256 } #[test] fn test_compare256() { if crate::cpu_features::is_enabled_simd128() { let str1 = [b'a'; super::MAX_COMPARE_SIZE]; let mut str2 = [b'a'; super::MAX_COMPARE_SIZE]; for i in 0..str1.len() { str2[i] = 0; let match_len = unsafe { compare256(&str1, &str2) }; assert_eq!(match_len, i); str2[i] = b'a'; } } } }