Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use generics for bit depth throughout the encoder #3121

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions src/activity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ impl ActivityMask {
}

#[hawktracer(activity_mask_fill_scales)]
pub fn fill_scales(
&self, bit_depth: usize, activity_scales: &mut Box<[DistortionScale]>,
pub fn fill_scales<const BD: usize>(
&self, activity_scales: &mut Box<[DistortionScale]>,
) {
for (dst, &src) in activity_scales.iter_mut().zip(self.variances.iter()) {
*dst = ssim_boost(src, src, bit_depth);
*dst = ssim_boost::<BD>(src, src);
}
}
}
Expand Down Expand Up @@ -146,21 +146,20 @@ fn ssim_boost_rsqrt(x: u64) -> RsqrtOutput {
}

#[inline(always)]
pub fn ssim_boost(svar: u32, dvar: u32, bit_depth: usize) -> DistortionScale {
DistortionScale(apply_ssim_boost(
pub fn ssim_boost<const BD: usize>(svar: u32, dvar: u32) -> DistortionScale {
DistortionScale(apply_ssim_boost::<BD>(
DistortionScale::default().0,
svar,
dvar,
bit_depth,
))
}

/// Apply ssim boost to a given input
#[inline(always)]
pub fn apply_ssim_boost(
input: u32, svar: u32, dvar: u32, bit_depth: usize,
pub fn apply_ssim_boost<const BD: usize>(
input: u32, svar: u32, dvar: u32,
) -> u32 {
let coeff_shift = bit_depth - 8;
let coeff_shift = BD - 8;

// Scale dvar and svar to lbd range to prevent overflows.
let svar = (svar >> (2 * coeff_shift)) as u64;
Expand Down Expand Up @@ -199,7 +198,7 @@ mod ssim_boost_tests {
let max_pix_diff = (1 << 12) - 1;
let max_pix_sse = max_pix_diff * max_pix_diff;
let max_variance = max_pix_diff * 8 * 8 / 4;
apply_ssim_boost(max_pix_sse * 8 * 8, max_variance, max_variance, 12);
apply_ssim_boost::<12>(max_pix_sse * 8 * 8, max_variance, max_variance);
}

/// Floating point reference version of `ssim_boost`
Expand Down Expand Up @@ -234,8 +233,8 @@ mod ssim_boost_tests {
let dvar = rng.gen_range(0..(1 << scale));

let float = reference_ssim_boost(svar, dvar, 12);
let fixed =
apply_ssim_boost(1 << 23, svar, dvar, 12) as f64 / (1 << 23) as f64;
let fixed = apply_ssim_boost::<12>(1 << 23, svar, dvar) as f64
/ (1 << 23) as f64;

// Compare the two versions
max_relative_error =
Expand All @@ -259,8 +258,13 @@ mod ssim_boost_tests {
let scale = ((1 << bd) - 1) << (6 - 2 + bd - 8);
for svar in scale..(scale << 2) {
let float = ((scale << 1) as f64 / svar as f64).cbrt();
let fixed =
apply_ssim_boost(1 << 23, svar, svar, bd) as f64 / (1 << 23) as f64;
let fixed = match bd {
8 => apply_ssim_boost::<8>(1 << 23, svar, svar),
10 => apply_ssim_boost::<10>(1 << 23, svar, svar),
12 => apply_ssim_boost::<12>(1 << 23, svar, svar),
_ => unimplemented!(),
} as f64
/ (1 << 23) as f64;

// Compare the two versions
max_relative_error =
Expand Down
11 changes: 9 additions & 2 deletions src/api/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,15 @@ impl Config {
// First-pass parameters depend on whether second-pass is in effect.
// So `init_first_pass` must follow `init_second_pass`.
if self.rate_control.emit_pass_data {
let maybe_pass1_log_base_q = (self.rate_control.summary.is_none())
.then(|| inner.rc_state.select_pass1_log_base_q(&inner, 0));
let maybe_pass1_log_base_q =
(self.rate_control.summary.is_none()).then(|| {
match self.enc.bit_depth {
8 => inner.rc_state.select_pass1_log_base_q::<_, 8>(&inner, 0),
10 => inner.rc_state.select_pass1_log_base_q::<_, 10>(&inner, 0),
12 => inner.rc_state.select_pass1_log_base_q::<_, 12>(&inner, 0),
_ => unimplemented!(),
}
});
inner.rc_state.init_first_pass(maybe_pass1_log_base_q);
}

Expand Down
14 changes: 12 additions & 2 deletions src/api/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,12 @@ impl<T: Pixel> Context<T> {
}

let inner = &mut self.inner;
let run = move || inner.send_frame(frame, params);
let run = move || match inner.config.bit_depth {
8 => inner.send_frame::<8>(frame, params),
10 => inner.send_frame::<10>(frame, params),
12 => inner.send_frame::<12>(frame, params),
_ => unimplemented!(),
};

match &self.pool {
Some(pool) => pool.install(run),
Expand Down Expand Up @@ -302,7 +307,12 @@ impl<T: Pixel> Context<T> {
#[inline]
pub fn receive_packet(&mut self) -> Result<Packet<T>, EncoderStatus> {
let inner = &mut self.inner;
let mut run = move || inner.receive_packet();
let mut run = move || match inner.config.bit_depth {
8 => inner.receive_packet::<8>(),
10 => inner.receive_packet::<10>(),
12 => inner.receive_packet::<12>(),
_ => unimplemented!(),
};

match &self.pool {
Some(pool) => pool.install(run),
Expand Down
Loading