Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create SIMD-accelerated version of compute_gru function #191

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
8 changes: 8 additions & 0 deletions README
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
RNNoise is a noise suppression library based on a recurrent neural network.
A description of the algorithm is provided in the following paper:

J.-M. Valin, A Hybrid DSP/Deep Learning Approach to Real-Time Full-Band Speech
Enhancement, Proceedings of IEEE Multimedia Signal Processing (MMSP) Workshop,
arXiv:1709.08243, 2018.
https://arxiv.org/pdf/1709.08243.pdf

An interactive demo is available at: https://jmvalin.ca/demo/rnnoise/

To compile, just type:
% ./autogen.sh
Expand Down
2 changes: 1 addition & 1 deletion configure.ac
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ AC_SUBST(OP_LT_REVISION)
AC_SUBST(OP_LT_AGE)

CC_CHECK_CFLAGS_APPEND(
[-pedantic -Wall -Wextra -Wno-sign-compare -Wno-parentheses -Wno-long-long])
[-O3 -march=native -pedantic -Wall -Wextra -Wno-sign-compare -Wno-parentheses -Wno-long-long])

# Platform-specific tweaks
case $host in
Expand Down
2 changes: 1 addition & 1 deletion src/compile.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/bin/sh

gcc -DTRAINING=1 -Wall -W -O3 -g -I../include denoise.c kiss_fft.c pitch.c celt_lpc.c rnn.c rnn_data.c -o denoise_training -lm
gcc -DTRAINING=1 -march=native -Wall -W -O3 -g -I../include denoise.c kiss_fft.c pitch.c celt_lpc.c rnn.c rnn_data.c -o denoise_training -lm
12 changes: 5 additions & 7 deletions src/denoise.c
Original file line number Diff line number Diff line change
Expand Up @@ -408,13 +408,11 @@ static void frame_synthesis(DenoiseState *st, float *out, const kiss_fft_cpx *y)
}

/*
 * Second-order (biquad) IIR filter, transposed direct form II.
 * Filters N samples from x into y, carrying the two state variables in
 * mem[2] across calls.  The leading numerator/denominator coefficients are
 * implicitly 1; b[0..1] and a[0..1] are the remaining coefficients.
 *
 * The state updates are accumulated in double precision on purpose: each
 * state value is a small difference of comparatively large products, and
 * computing it in float loses precision that audibly degrades the filter
 * at low cutoff frequencies.  Do not remove the (double) casts.
 */
static void biquad(float *y, float mem[2], const float *x, const float *b, const float *a, int N) {
  for (int i = 0; i < N; i++) {
    float xi = x[i];
    float yi = xi + mem[0];
    mem[0] = mem[1] + (b[0]*(double)xi - a[0]*(double)yi);
    mem[1] = (b[1]*(double)xi - a[1]*(double)yi);
    y[i] = yi;
  }
}
Expand Down
242 changes: 206 additions & 36 deletions src/rnn.c
Original file line number Diff line number Diff line change
Expand Up @@ -84,30 +84,199 @@ void compute_dense(const DenseLayer *layer, float *output, const float *input)
M = layer->nb_inputs;
N = layer->nb_neurons;
stride = N;
for (i=0;i<N;i++)
for (i = 0;i < N;i++)
{
/* Compute update gate. */
float sum = layer->bias[i];
for (j=0;j<M;j++)
for (j = 0; j<M;j++)
sum += layer->input_weights[j*stride + i]*input[j];
output[i] = WEIGHTS_SCALE*sum;
}
if (layer->activation == ACTIVATION_SIGMOID) {
for (i=0;i<N;i++)
for (i = 0;i < N;i++)
output[i] = sigmoid_approx(output[i]);
} else if (layer->activation == ACTIVATION_TANH) {
for (i=0;i<N;i++)
for (i = 0;i < N;i++)
output[i] = tansig_approx(output[i]);
} else if (layer->activation == ACTIVATION_RELU) {
for (i=0;i<N;i++)
for (i = 0;i < N;i++)
output[i] = relu(output[i]);
} else {
*(int*)0=0;
}
}

// FMA is always available if AVX2 is available
#if !defined(__FMA__) && defined(__AVX2__)
#define __FMA__ 1
#endif

#if defined(__AVX2__) && defined(__FMA__)
#include <immintrin.h>

/*
 * AVX2/FMA-accelerated GRU layer evaluation; same contract as compute_gru:
 * reads `input` (gru->nb_inputs floats) and updates `state`
 * (gru->nb_neurons floats) in place.  Biases and weights are stored as
 * signed 8-bit integers; sums are scaled by WEIGHTS_SCALE before the
 * non-linearities, matching the scalar path.
 *
 * Neurons are processed 8 at a time: int8 weights are widened to f32 lanes
 * and accumulated with FMA; a scalar loop handles the N % 8 tail.
 *
 * Fix vs. the previous version: the 8 int8 values per chunk are loaded
 * with _mm_loadl_epi64 (8 bytes) instead of _mm_loadu_si128 (16 bytes).
 * _mm256_cvtepi8_epi32 only consumes the low 8 bytes, so the wider load
 * gained nothing and could read up to 8 bytes past the end of the bias
 * and weight arrays on the last chunk.
 */
void compute_gru_avx2(const GRULayer *gru, float *state, const float *input)
{
   int j;
   float z[MAX_NEURONS];   /* update gate */
   float r[MAX_NEURONS];   /* reset gate */
   float h[MAX_NEURONS];   /* candidate state (pre-activation) */
   const int M = gru->nb_inputs;
   const int N = gru->nb_neurons;
   const int stride = 3 * N;    /* z/r/h weight blocks are interleaved per row */

   const int chunk_size = 8;    /* f32 lanes per AVX2 vector */
   const int n_chunk_count = N / chunk_size;

   /* Update (z) and reset (r) gate pre-activations, 8 neurons at a time. */
   for (int i_chunk = 0; i_chunk < n_chunk_count; i_chunk++) {
      /* Load exactly 8 int8 biases, sign-extend to i32, convert to f32. */
      __m128i i8_z_sum = _mm_loadl_epi64((const __m128i*) &gru->bias[i_chunk * chunk_size]);
      __m128i i8_r_sum = _mm_loadl_epi64((const __m128i*) &gru->bias[N + (i_chunk * chunk_size)]);
      __m256 z_sum = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(i8_z_sum));
      __m256 r_sum = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(i8_r_sum));

      for (j = 0; j < M; j++) {
         __m128i z_w_i8 = _mm_loadl_epi64((const __m128i*) &gru->input_weights[j*stride + (i_chunk * chunk_size)]);
         __m128i r_w_i8 = _mm_loadl_epi64((const __m128i*) &gru->input_weights[N + j*stride + (i_chunk * chunk_size)]);
         __m256 z_w = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(z_w_i8));
         __m256 r_w = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(r_w_i8));

         __m256 input_v = _mm256_broadcast_ss(&input[j]);

         z_sum = _mm256_fmadd_ps(z_w, input_v, z_sum);
         r_sum = _mm256_fmadd_ps(r_w, input_v, r_sum);
      }
      for (j = 0; j < N; j++) {
         __m128i z_w_i8 = _mm_loadl_epi64((const __m128i*) &gru->recurrent_weights[j*stride + (i_chunk * chunk_size)]);
         __m128i r_w_i8 = _mm_loadl_epi64((const __m128i*) &gru->recurrent_weights[N + j*stride + (i_chunk * chunk_size)]);
         __m256 z_w = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(z_w_i8));
         __m256 r_w = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(r_w_i8));

         __m256 state_v = _mm256_broadcast_ss(&state[j]);

         z_sum = _mm256_fmadd_ps(z_w, state_v, z_sum);
         r_sum = _mm256_fmadd_ps(r_w, state_v, r_sum);
      }

      _mm256_storeu_ps(&z[i_chunk * chunk_size], z_sum);
      _mm256_storeu_ps(&r[i_chunk * chunk_size], r_sum);
   }
   /* Scalar tail for the last N % 8 neurons. */
   for (int i = n_chunk_count * chunk_size; i < N; i++) {
      float z_sum = gru->bias[i];
      float r_sum = gru->bias[N + i];

      for (j = 0; j < M; j++) {
         z_sum += gru->input_weights[j*stride + i] * input[j];
         r_sum += gru->input_weights[N + j*stride + i] * input[j];
      }
      for (j = 0; j < N; j++) {
         z_sum += gru->recurrent_weights[j*stride + i] * state[j];
         r_sum += gru->recurrent_weights[N + j*stride + i] * state[j];
      }

      z[i] = z_sum;
      r[i] = r_sum;
   }
   /* Both gates use the sigmoid non-linearity. */
   for (int i = 0; i < N; i++) {
      z[i] = sigmoid_approx(WEIGHTS_SCALE * z[i]);
      r[i] = sigmoid_approx(WEIGHTS_SCALE * r[i]);
   }

   /* Candidate state h; its weights/biases live at offset 2*N. */
   for (int i_chunk = 0; i_chunk < n_chunk_count; i_chunk++) {
      __m128i i8_sum = _mm_loadl_epi64((const __m128i*) &gru->bias[2*N + (i_chunk * chunk_size)]);
      __m256 sum = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(i8_sum));

      for (j = 0; j < M; j++) {
         __m128i w_i8 = _mm_loadl_epi64((const __m128i*) &gru->input_weights[2*N + j*stride + (i_chunk * chunk_size)]);
         __m256 w = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(w_i8));

         __m256 input_v = _mm256_broadcast_ss(&input[j]);

         sum = _mm256_fmadd_ps(w, input_v, sum);
      }

      for (j = 0; j < N; j++) {
         __m128i w_i8 = _mm_loadl_epi64((const __m128i*) &gru->recurrent_weights[2*N + j*stride + (i_chunk * chunk_size)]);
         __m256 w = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(w_i8));

         /* The recurrent contribution is gated by the reset gate. */
         float state_times_r = state[j] * r[j];
         __m256 state_times_r_v = _mm256_broadcast_ss(&state_times_r);

         sum = _mm256_fmadd_ps(w, state_times_r_v, sum);
      }

      _mm256_storeu_ps(&h[i_chunk * chunk_size], sum);
   }
   /* Scalar tail for the candidate state. */
   for (int i = n_chunk_count * chunk_size; i < N; i++) {
      float sum = gru->bias[2*N + i];
      for (j = 0; j < M; j++)
         sum += gru->input_weights[2*N + j*stride + i] * input[j];
      for (j = 0; j < N; j++)
         sum += gru->recurrent_weights[2*N + j*stride + i] * state[j] * r[j];

      h[i] = sum;
   }

   /* Apply the configured activation and blend old state with candidate:
      state = z*state + (1-z)*h. */
   for (int i = 0; i < N; i++) {
      float sum = h[i];

      if (gru->activation == ACTIVATION_SIGMOID) sum = sigmoid_approx(WEIGHTS_SCALE*sum);
      else if (gru->activation == ACTIVATION_TANH) sum = tansig_approx(WEIGHTS_SCALE*sum);
      else if (gru->activation == ACTIVATION_RELU) sum = relu(WEIGHTS_SCALE*sum);
      else *(int*)0=0;   /* deliberate crash on unknown activation */
      state[i] = z[i]*state[i] + (1-z[i])*sum;
   }
}
#endif

void compute_gru(const GRULayer *gru, float *state, const float *input)
{
// Check if we support AVX2 and FMA and use the SIMD-accelerated function if so
#if defined(__AVX2__) && defined(__FMA__)
if (__builtin_cpu_supports("avx2") && __builtin_cpu_supports("fma")) {
compute_gru_avx2(gru, state, input);
return;
}
#endif

int i, j;
int N, M;
int stride;
Expand All @@ -117,42 +286,43 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
M = gru->nb_inputs;
N = gru->nb_neurons;
stride = 3*N;
for (i=0;i<N;i++)
for (i = 0;i < N;i++)
{
/* Compute update gate. */
float sum = gru->bias[i];
for (j=0;j<M;j++)
sum += gru->input_weights[j*stride + i]*input[j];
for (j=0;j<N;j++)
sum += gru->recurrent_weights[j*stride + i]*state[j];
z[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
}
for (i=0;i<N;i++)
{
/* Compute reset gate. */
float sum = gru->bias[N + i];
for (j=0;j<M;j++)
sum += gru->input_weights[N + j*stride + i]*input[j];
for (j=0;j<N;j++)
sum += gru->recurrent_weights[N + j*stride + i]*state[j];
r[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
float z_sum = gru->bias[i];
float r_sum = gru->bias[N + i];

for (j = 0; j<M;j++) {
/* Compute update gate. */
z_sum += gru->input_weights[j*stride + i]*input[j];
/* Compute reset gate. */
r_sum += gru->input_weights[N + j*stride + i]*input[j];
}
for (j = 0; j<N;j++) {
/* Compute update gate. */
z_sum += gru->recurrent_weights[j*stride + i]*state[j];
/* Compute reset gate. */
r_sum += gru->recurrent_weights[N + j*stride + i]*state[j];
}

z[i] = sigmoid_approx(WEIGHTS_SCALE*z_sum);
r[i] = sigmoid_approx(WEIGHTS_SCALE*r_sum);
}
for (i=0;i<N;i++)
{
/* Compute output. */

/* Compute output. */
for (i = 0;i < N;i++) {
float sum = gru->bias[2*N + i];
for (j=0;j<M;j++)
for (j = 0; j<M;j++)
sum += gru->input_weights[2*N + j*stride + i]*input[j];
for (j=0;j<N;j++)
for (j = 0; j<N;j++)
sum += gru->recurrent_weights[2*N + j*stride + i]*state[j]*r[j];
if (gru->activation == ACTIVATION_SIGMOID) sum = sigmoid_approx(WEIGHTS_SCALE*sum);
else if (gru->activation == ACTIVATION_TANH) sum = tansig_approx(WEIGHTS_SCALE*sum);
else if (gru->activation == ACTIVATION_RELU) sum = relu(WEIGHTS_SCALE*sum);
else *(int*)0=0;
h[i] = z[i]*state[i] + (1-z[i])*sum;
}
for (i=0;i<N;i++)
state[i] = h[i];
for (i = 0;i < N;i++)
state[i] = h[i ];
}

#define INPUT_SIZE 42
Expand All @@ -165,14 +335,14 @@ void compute_rnn(RNNState *rnn, float *gains, float *vad, const float *input) {
compute_dense(rnn->model->input_dense, dense_out, input);
compute_gru(rnn->model->vad_gru, rnn->vad_gru_state, dense_out);
compute_dense(rnn->model->vad_output, vad, rnn->vad_gru_state);
for (i=0;i<rnn->model->input_dense_size;i++) noise_input[i] = dense_out[i];
for (i=0;i<rnn->model->vad_gru_size;i++) noise_input[i+rnn->model->input_dense_size] = rnn->vad_gru_state[i];
for (i=0;i<INPUT_SIZE;i++) noise_input[i+rnn->model->input_dense_size+rnn->model->vad_gru_size] = input[i];
for (i = 0;i<rnn->model->input_dense_size;i++) noise_input[i] = dense_out[i];
for (i = 0;i<rnn->model->vad_gru_size;i++) noise_input[i+rnn->model->input_dense_size] = rnn->vad_gru_state[i];
for (i = 0;i<INPUT_SIZE;i++) noise_input[i+rnn->model->input_dense_size+rnn->model->vad_gru_size] = input[i];
compute_gru(rnn->model->noise_gru, rnn->noise_gru_state, noise_input);

for (i=0;i<rnn->model->vad_gru_size;i++) denoise_input[i] = rnn->vad_gru_state[i];
for (i=0;i<rnn->model->noise_gru_size;i++) denoise_input[i+rnn->model->vad_gru_size] = rnn->noise_gru_state[i];
for (i=0;i<INPUT_SIZE;i++) denoise_input[i+rnn->model->vad_gru_size+rnn->model->noise_gru_size] = input[i];
for (i = 0;i<rnn->model->vad_gru_size;i++) denoise_input[i] = rnn->vad_gru_state[i];
for (i = 0;i<rnn->model->noise_gru_size;i++) denoise_input[i+rnn->model->vad_gru_size] = rnn->noise_gru_state[i];
for (i = 0;i<INPUT_SIZE;i++) denoise_input[i+rnn->model->vad_gru_size+rnn->model->noise_gru_size] = input[i];
compute_gru(rnn->model->denoise_gru, rnn->denoise_gru_state, denoise_input);
compute_dense(rnn->model->denoise_output, gains, rnn->denoise_gru_state);
}