Skip to content

Commit 90fdd3d

Browse files
maj160 authored and shssoichiro committed
Add 10-bit cdef_dist ASM
1 parent c2dfb35 commit 90fdd3d

File tree

1 file changed

+140
-8
lines changed

1 file changed

+140
-8
lines changed

src/asm/x86/dist/cdef_dist.rs

+140-8
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@ use crate::dist::*;
1313
use crate::tiling::PlaneRegion;
1414
use crate::util::Pixel;
1515
use crate::util::PixelType;
16+
use std::arch::x86_64::*;
1617

1718
type CdefDistKernelFn = unsafe extern fn(
1819
src: *const u8,
@@ -22,6 +23,13 @@ type CdefDistKernelFn = unsafe extern fn(
2223
ret_ptr: *mut u32,
2324
);
2425

26+
/// Signature of a high-bit-depth (`u16` sample) CDEF distortion kernel.
///
/// A kernel receives the source and destination block pointers together with
/// their row strides and hands back three `u32` statistics. The dispatcher in
/// this file destructures that tuple as `(svar, dvar, sse)`.
///
/// Unlike the 8-bit kernels, which write their results through an output
/// pointer, these kernels return the values directly.
type CdefDistKernelHBDFn = unsafe fn(
  src: *const u16,
  src_stride: isize,
  dst: *const u16,
  dst_stride: isize,
) -> (u32, u32, u32);
32+
2533
extern {
2634
fn rav1e_cdef_dist_kernel_4x4_sse2(
2735
src: *const u8, src_stride: isize, dst: *const u8, dst_stride: isize,
@@ -63,12 +71,12 @@ pub fn cdef_dist_kernel<T: Pixel>(
6371
#[cfg(feature = "check_asm")]
6472
let ref_dist = call_rust();
6573

66-
let mut ret_buf = [0u32; 3];
67-
match T::type_enum() {
74+
let (svar, dvar, sse) = match T::type_enum() {
6875
PixelType::U8 => {
6976
if let Some(func) =
7077
CDEF_DIST_KERNEL_FNS[cpu.as_index()][kernel_fn_index(w, h)]
7178
{
79+
let mut ret_buf = [0u32; 3];
7280
// SAFETY: Calls Assembly code.
7381
unsafe {
7482
func(
@@ -79,16 +87,30 @@ pub fn cdef_dist_kernel<T: Pixel>(
7987
ret_buf.as_mut_ptr(),
8088
)
8189
}
90+
91+
(ret_buf[0], ret_buf[1], ret_buf[2])
8292
} else {
8393
return call_rust();
8494
}
8595
}
86-
PixelType::U16 => return call_rust(),
87-
}
88-
89-
let svar = ret_buf[0];
90-
let dvar = ret_buf[1];
91-
let sse = ret_buf[2];
96+
PixelType::U16 => {
97+
if let Some(func) =
98+
CDEF_DIST_KERNEL_HBD_FNS[cpu.as_index()][kernel_fn_index(w, h)]
99+
{
100+
// SAFETY: Calls Assembly code.
101+
unsafe {
102+
func(
103+
src.data_ptr() as *const _,
104+
T::to_asm_stride(src.plane_cfg.stride),
105+
dst.data_ptr() as *const _,
106+
T::to_asm_stride(dst.plane_cfg.stride),
107+
)
108+
}
109+
} else {
110+
return call_rust();
111+
}
112+
}
113+
};
92114

93115
let dist = apply_ssim_boost(sse, svar, dvar, bit_depth);
94116
#[cfg(feature = "check_asm")]
@@ -128,6 +150,98 @@ cpu_function_lookup_table!(
128150
[SSE2]
129151
);
130152

153+
/// Horizontally adds the eight 32-bit lanes of `ymm` and returns the total.
///
/// Every step uses `_mm_add_epi32`, so the total wraps on overflow.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn mm256_sum_i32(ymm: __m256i) -> i32 {
  // 8 lanes -> 4: add the upper 128-bit half onto the lower half.
  let upper = _mm256_extracti128_si256(ymm, 1);
  let lower = _mm256_castsi256_si128(ymm);
  let quad = _mm_add_epi32(lower, upper);

  // 4 lanes -> 2: lanes 2 and 3 are added onto lanes 0 and 1.
  let pair = _mm_add_epi32(quad, _mm_unpackhi_epi64(quad, quad));

  // 2 lanes -> 1: lane 1 is added onto lane 0.
  let total = _mm_add_epi32(pair, _mm_shuffle_epi32(pair, 0b00_00_00_01));

  // Only lane 0 holds the complete sum.
  _mm_cvtsi128_si32(total)
}
166+
167+
#[target_feature(enable = "avx2")]
168+
#[inline]
169+
unsafe fn rav1e_cdef_dist_kernel_8x8_hbd_avx2(
170+
src: *const u16, src_stride: isize, dst: *const u16, dst_stride: isize,
171+
) -> (u32, u32, u32) {
172+
let src = src as *const u8;
173+
let dst = dst as *const u8;
174+
175+
unsafe fn sum16(src: *const u8, src_stride: isize) -> u32 {
176+
let h = 8;
177+
let res = (0..h)
178+
.map(|row| _mm_load_si128(src.offset(row * src_stride) as *const _))
179+
.reduce(|a, b| _mm_add_epi16(a, b))
180+
.unwrap();
181+
182+
let m32 = _mm256_cvtepi16_epi32(res);
183+
mm256_sum_i32(m32) as u32
184+
}
185+
unsafe fn mpadd32(
186+
src: *const u8, src_stride: isize, dst: *const u8, dst_stride: isize,
187+
) -> u32 {
188+
let h = 8;
189+
let res = (0..h / 2)
190+
.map(|row| {
191+
let s1 = _mm_load_si128(src.offset(2 * row * src_stride) as *const _);
192+
let s2 =
193+
_mm_load_si128(src.offset((2 * row + 1) * src_stride) as *const _);
194+
let s = _mm256_inserti128_si256(_mm256_castsi128_si256(s1), s2, 1);
195+
196+
let d1 = _mm_load_si128(dst.offset(2 * row * dst_stride) as *const _);
197+
let d2 =
198+
_mm_load_si128(dst.offset((2 * row + 1) * dst_stride) as *const _);
199+
let d = _mm256_inserti128_si256(_mm256_castsi128_si256(d1), d2, 1);
200+
201+
_mm256_madd_epi16(s, d)
202+
})
203+
.reduce(|a, b| _mm256_add_epi32(a, b))
204+
.unwrap();
205+
mm256_sum_i32(res) as u32
206+
}
207+
208+
let sum_s = sum16(src, src_stride);
209+
let sum_d = sum16(dst, dst_stride);
210+
let sum_s2 = mpadd32(src, src_stride, src, src_stride);
211+
let sum_d2 = mpadd32(dst, dst_stride, dst, dst_stride);
212+
let sum_sd = mpadd32(src, src_stride, dst, dst_stride);
213+
214+
// To get the distortion, compute sum of squared error and apply a weight
215+
// based on the variance of the two planes.
216+
let sse = sum_d2 + sum_s2 - 2 * sum_sd;
217+
218+
// Convert to 64-bits to avoid overflow when squaring
219+
let sum_s = sum_s as u64;
220+
let sum_d = sum_d as u64;
221+
222+
let svar = (sum_s2 as u64 - (sum_s * sum_s) / 64) as u32;
223+
let dvar = (sum_d2 as u64 - (sum_d * sum_d) / 64) as u32;
224+
225+
(svar, dvar, sse)
226+
}
227+
228+
static CDEF_DIST_KERNEL_HBD_FNS_AVX2: [Option<CdefDistKernelHBDFn>;
229+
CDEF_DIST_KERNEL_FNS_LENGTH] = {
230+
let mut out: [Option<CdefDistKernelHBDFn>; CDEF_DIST_KERNEL_FNS_LENGTH] =
231+
[None; CDEF_DIST_KERNEL_FNS_LENGTH];
232+
233+
out[kernel_fn_index(8, 8)] = Some(rav1e_cdef_dist_kernel_8x8_hbd_avx2);
234+
235+
out
236+
};
237+
238+
cpu_function_lookup_table!(
239+
CDEF_DIST_KERNEL_HBD_FNS:
240+
[[Option<CdefDistKernelHBDFn>; CDEF_DIST_KERNEL_FNS_LENGTH]],
241+
default: [None; CDEF_DIST_KERNEL_FNS_LENGTH],
242+
[AVX2]
243+
);
244+
131245
#[cfg(test)]
132246
pub mod test {
133247
use super::*;
@@ -204,16 +318,34 @@ pub mod test {
204318
cdef_diff_tester(8, random_planes::<u8>);
205319
}
206320

321+
/// Runs `cdef_diff_tester` on planes from `random_planes::<u16>` at bit depth
/// 10 and then at bit depth 12.
#[test]
fn cdef_dist_simd_random_hbd() {
  for &bit_depth in &[10usize, 12] {
    cdef_diff_tester(bit_depth, random_planes::<u16>);
  }
}
326+
207327
/// Runs `cdef_diff_tester` at bit depth 8 on planes from `max_planes::<u8>`.
#[test]
fn cdef_dist_simd_large() {
  let bit_depth = 8;
  cdef_diff_tester(bit_depth, max_planes::<u8>);
}
211331

332+
/// Runs `cdef_diff_tester` on planes from `max_planes::<u16>` at bit depth 10
/// and then at bit depth 12.
#[test]
fn cdef_dist_simd_large_hbd() {
  for &bit_depth in &[10usize, 12] {
    cdef_diff_tester(bit_depth, max_planes::<u16>);
  }
}
337+
212338
/// Runs `cdef_diff_tester` at bit depth 8 on planes from
/// `max_diff_planes::<u8>`.
#[test]
fn cdef_dist_simd_large_diff() {
  let bit_depth = 8;
  cdef_diff_tester(bit_depth, max_diff_planes::<u8>);
}
216342

343+
/// Runs `cdef_diff_tester` on planes from `max_diff_planes::<u16>` at bit
/// depth 10 and then at bit depth 12.
#[test]
fn cdef_dist_simd_large_diff_hbd() {
  for &bit_depth in &[10usize, 12] {
    cdef_diff_tester(bit_depth, max_diff_planes::<u16>);
  }
}
348+
217349
fn cdef_diff_tester<T: Pixel>(
218350
bd: usize, gen_planes: fn(bd: usize) -> (Plane<T>, Plane<T>),
219351
) {

0 commit comments

Comments
 (0)