@@ -13,6 +13,7 @@ use crate::dist::*;
13
13
use crate :: tiling:: PlaneRegion ;
14
14
use crate :: util:: Pixel ;
15
15
use crate :: util:: PixelType ;
16
+ use std:: arch:: x86_64:: * ;
16
17
17
18
type CdefDistKernelFn = unsafe extern fn (
18
19
src : * const u8 ,
@@ -22,6 +23,13 @@ type CdefDistKernelFn = unsafe extern fn(
22
23
ret_ptr : * mut u32 ,
23
24
) ;
24
25
26
// Function-pointer type for the kernels that operate on 16-bit (`u16`)
// samples. Each kernel receives a source pointer and stride followed by a
// destination pointer and stride.
// Unlike `CdefDistKernelFn`, which is `unsafe extern fn` and writes its results
// through `ret_ptr`, this is a plain `unsafe fn` that returns the three values
// directly; `cdef_dist_kernel` destructures them as `(svar, dvar, sse)`.
+ type CdefDistKernelHBDFn = unsafe fn (
27
+ src : * const u16 ,
28
+ src_stride : isize ,
29
+ dst : * const u16 ,
30
+ dst_stride : isize ,
31
+ ) -> ( u32 , u32 , u32 ) ;
32
+
25
33
extern {
26
34
fn rav1e_cdef_dist_kernel_4x4_sse2 (
27
35
src : * const u8 , src_stride : isize , dst : * const u8 , dst_stride : isize ,
@@ -63,12 +71,12 @@ pub fn cdef_dist_kernel<T: Pixel>(
63
71
#[ cfg( feature = "check_asm" ) ]
64
72
let ref_dist = call_rust ( ) ;
65
73
66
- let mut ret_buf = [ 0u32 ; 3 ] ;
67
- match T :: type_enum ( ) {
74
+ let ( svar, dvar, sse) = match T :: type_enum ( ) {
68
75
PixelType :: U8 => {
69
76
if let Some ( func) =
70
77
CDEF_DIST_KERNEL_FNS [ cpu. as_index ( ) ] [ kernel_fn_index ( w, h) ]
71
78
{
79
+ let mut ret_buf = [ 0u32 ; 3 ] ;
72
80
// SAFETY: Calls Assembly code.
73
81
unsafe {
74
82
func (
@@ -79,16 +87,30 @@ pub fn cdef_dist_kernel<T: Pixel>(
79
87
ret_buf. as_mut_ptr ( ) ,
80
88
)
81
89
}
90
+
91
+ ( ret_buf[ 0 ] , ret_buf[ 1 ] , ret_buf[ 2 ] )
82
92
} else {
83
93
return call_rust ( ) ;
84
94
}
85
95
}
86
- PixelType :: U16 => return call_rust ( ) ,
87
- }
88
-
89
- let svar = ret_buf[ 0 ] ;
90
- let dvar = ret_buf[ 1 ] ;
91
- let sse = ret_buf[ 2 ] ;
96
+ PixelType :: U16 => {
97
+ if let Some ( func) =
98
+ CDEF_DIST_KERNEL_HBD_FNS [ cpu. as_index ( ) ] [ kernel_fn_index ( w, h) ]
99
+ {
100
+ // SAFETY: Calls Assembly code.
101
+ unsafe {
102
+ func (
103
+ src. data_ptr ( ) as * const _ ,
104
+ T :: to_asm_stride ( src. plane_cfg . stride ) ,
105
+ dst. data_ptr ( ) as * const _ ,
106
+ T :: to_asm_stride ( dst. plane_cfg . stride ) ,
107
+ )
108
+ }
109
+ } else {
110
+ return call_rust ( ) ;
111
+ }
112
+ }
113
+ } ;
92
114
93
115
let dist = apply_ssim_boost ( sse, svar, dvar, bit_depth) ;
94
116
#[ cfg( feature = "check_asm" ) ]
@@ -128,6 +150,98 @@ cpu_function_lookup_table!(
128
150
[ SSE2 ]
129
151
) ;
130
152
153
/// Horizontally adds the eight 32-bit lanes of `ymm` and returns the total.
///
/// The additions are performed with `_mm_add_epi32`, so the result wraps on
/// overflow like the underlying instruction does.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn mm256_sum_i32(ymm: __m256i) -> i32 {
  // 8 lanes -> 4 lanes: fold the upper 128-bit half onto the lower half.
  let upper = _mm256_extracti128_si256(ymm, 1);
  let lower = _mm256_castsi256_si128(ymm);
  let quad = _mm_add_epi32(lower, upper);

  // 4 lanes -> 2 lanes: bring lanes 2 and 3 down onto lanes 0 and 1.
  let pair = _mm_add_epi32(quad, _mm_shuffle_epi32(quad, 0b11_10_11_10));

  // 2 lanes -> 1 lane: broadcast lane 1 and add it to lane 0.
  let total = _mm_add_epi32(pair, _mm_shuffle_epi32(pair, 0b01_01_01_01));

  // Lane 0 now holds the sum of all eight input lanes.
  _mm_cvtsi128_si32(total)
}
166
+
167
+ #[ target_feature( enable = "avx2" ) ]
168
+ #[ inline]
169
+ unsafe fn rav1e_cdef_dist_kernel_8x8_hbd_avx2 (
170
+ src : * const u16 , src_stride : isize , dst : * const u16 , dst_stride : isize ,
171
+ ) -> ( u32 , u32 , u32 ) {
172
+ let src = src as * const u8 ;
173
+ let dst = dst as * const u8 ;
174
+
175
+ unsafe fn sum16 ( src : * const u8 , src_stride : isize ) -> u32 {
176
+ let h = 8 ;
177
+ let res = ( 0 ..h)
178
+ . map ( |row| _mm_load_si128 ( src. offset ( row * src_stride) as * const _ ) )
179
+ . reduce ( |a, b| _mm_add_epi16 ( a, b) )
180
+ . unwrap ( ) ;
181
+
182
+ let m32 = _mm256_cvtepi16_epi32 ( res) ;
183
+ mm256_sum_i32 ( m32) as u32
184
+ }
185
+ unsafe fn mpadd32 (
186
+ src : * const u8 , src_stride : isize , dst : * const u8 , dst_stride : isize ,
187
+ ) -> u32 {
188
+ let h = 8 ;
189
+ let res = ( 0 ..h / 2 )
190
+ . map ( |row| {
191
+ let s1 = _mm_load_si128 ( src. offset ( 2 * row * src_stride) as * const _ ) ;
192
+ let s2 =
193
+ _mm_load_si128 ( src. offset ( ( 2 * row + 1 ) * src_stride) as * const _ ) ;
194
+ let s = _mm256_inserti128_si256 ( _mm256_castsi128_si256 ( s1) , s2, 1 ) ;
195
+
196
+ let d1 = _mm_load_si128 ( dst. offset ( 2 * row * dst_stride) as * const _ ) ;
197
+ let d2 =
198
+ _mm_load_si128 ( dst. offset ( ( 2 * row + 1 ) * dst_stride) as * const _ ) ;
199
+ let d = _mm256_inserti128_si256 ( _mm256_castsi128_si256 ( d1) , d2, 1 ) ;
200
+
201
+ _mm256_madd_epi16 ( s, d)
202
+ } )
203
+ . reduce ( |a, b| _mm256_add_epi32 ( a, b) )
204
+ . unwrap ( ) ;
205
+ mm256_sum_i32 ( res) as u32
206
+ }
207
+
208
+ let sum_s = sum16 ( src, src_stride) ;
209
+ let sum_d = sum16 ( dst, dst_stride) ;
210
+ let sum_s2 = mpadd32 ( src, src_stride, src, src_stride) ;
211
+ let sum_d2 = mpadd32 ( dst, dst_stride, dst, dst_stride) ;
212
+ let sum_sd = mpadd32 ( src, src_stride, dst, dst_stride) ;
213
+
214
+ // To get the distortion, compute sum of squared error and apply a weight
215
+ // based on the variance of the two planes.
216
+ let sse = sum_d2 + sum_s2 - 2 * sum_sd;
217
+
218
+ // Convert to 64-bits to avoid overflow when squaring
219
+ let sum_s = sum_s as u64 ;
220
+ let sum_d = sum_d as u64 ;
221
+
222
+ let svar = ( sum_s2 as u64 - ( sum_s * sum_s) / 64 ) as u32 ;
223
+ let dvar = ( sum_d2 as u64 - ( sum_d * sum_d) / 64 ) as u32 ;
224
+
225
+ ( svar, dvar, sse)
226
+ }
227
+
228
+ static CDEF_DIST_KERNEL_HBD_FNS_AVX2 : [ Option < CdefDistKernelHBDFn > ;
229
+ CDEF_DIST_KERNEL_FNS_LENGTH ] = {
230
+ let mut out: [ Option < CdefDistKernelHBDFn > ; CDEF_DIST_KERNEL_FNS_LENGTH ] =
231
+ [ None ; CDEF_DIST_KERNEL_FNS_LENGTH ] ;
232
+
233
+ out[ kernel_fn_index ( 8 , 8 ) ] = Some ( rav1e_cdef_dist_kernel_8x8_hbd_avx2) ;
234
+
235
+ out
236
+ } ;
237
+
238
// Dispatch table of the 16-bit kernels. `cdef_dist_kernel` reads it as
// `CDEF_DIST_KERNEL_HBD_FNS[cpu.as_index()][kernel_fn_index(w, h)]`.
// Only an AVX2 table is listed here; the `default` row contains no kernels.
// NOTE(review): presumably CPU levels without a listed table get the
// `default` row, so the caller falls back to `call_rust()` — confirm against
// the `cpu_function_lookup_table!` definition.
+ cpu_function_lookup_table ! (
239
+ CDEF_DIST_KERNEL_HBD_FNS :
240
+ [ [ Option <CdefDistKernelHBDFn >; CDEF_DIST_KERNEL_FNS_LENGTH ] ] ,
241
+ default : [ None ; CDEF_DIST_KERNEL_FNS_LENGTH ] ,
242
+ [ AVX2 ]
243
+ ) ;
244
+
131
245
#[ cfg( test) ]
132
246
pub mod test {
133
247
use super :: * ;
@@ -204,16 +318,34 @@ pub mod test {
204
318
cdef_diff_tester ( 8 , random_planes :: < u8 > ) ;
205
319
}
206
320
321
/// Runs `cdef_diff_tester` on random `u16` planes at 10 and 12 bits.
#[test]
fn cdef_dist_simd_random_hbd() {
  for &bit_depth in &[10, 12] {
    cdef_diff_tester(bit_depth, random_planes::<u16>);
  }
}
326
+
207
327
/// Runs `cdef_diff_tester` on the `max_planes` `u8` input at 8 bits.
#[test]
fn cdef_dist_simd_large() {
  let bit_depth = 8;
  cdef_diff_tester(bit_depth, max_planes::<u8>);
}
211
331
332
/// Runs `cdef_diff_tester` on the `max_planes` `u16` input at 10 and 12 bits.
#[test]
fn cdef_dist_simd_large_hbd() {
  for &bit_depth in &[10, 12] {
    cdef_diff_tester(bit_depth, max_planes::<u16>);
  }
}
337
+
212
338
/// Runs `cdef_diff_tester` on the `max_diff_planes` `u8` input at 8 bits.
#[test]
fn cdef_dist_simd_large_diff() {
  let bit_depth = 8;
  cdef_diff_tester(bit_depth, max_diff_planes::<u8>);
}
216
342
343
/// Runs `cdef_diff_tester` on the `max_diff_planes` `u16` input at 10 and 12
/// bits.
#[test]
fn cdef_dist_simd_large_diff_hbd() {
  for &bit_depth in &[10, 12] {
    cdef_diff_tester(bit_depth, max_diff_planes::<u16>);
  }
}
348
+
217
349
fn cdef_diff_tester < T : Pixel > (
218
350
bd : usize , gen_planes : fn ( bd : usize ) -> ( Plane < T > , Plane < T > ) ,
219
351
) {
0 commit comments