Skip to content

Commit

Permalink
CodedOutputStream: Avoid updating position to go beyond end of array.
Browse files Browse the repository at this point in the history
When writing varints.

This has twofold goals:
1. Correctness: if position overruns the array, checking space left may return a negative number. I'm not sure how bad that is, but let's avoid it.
2. Performance. This generates more optimal assembly code which can combine bounds checks, particularly on Android (I haven't looked at the generated assembly on the server JVM; it's possible the server JVM can already performance this hoist).

The `position` field is stored on the object, so Android ART generates assembly codes for `this.position++` like "load, add, store":

```
       ldr w3, [x1, #12]
       add w4, w3, #0x1 (1)
       str w4, [x1, #12]
```

There can be a lot of these loads/stores executed each step of a loop (e.g. writeFixed64NoTag updates position 8 times, and varint encoding could do it even more). It's faster if we can hoist these so we load once at the start of the function, and store once at the end of the function. This also has the nice benefit that it won't store if we've thrown an exception.

See before/after in Compiler Explorer: https://godbolt.org/z/bWWYqsxK4. I'm not an assembly expert, but it seems clear that the increment instructions like `add w4, w0, #0x1 (1)` are no longer always surrounded by loads and stores in the new version.

PiperOrigin-RevId: 681644516
  • Loading branch information
mhansen authored and copybara-github committed Oct 3, 2024
1 parent d88a3d0 commit 24160d5
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 11 deletions.
35 changes: 24 additions & 11 deletions java/core/src/main/java/com/google/protobuf/CodedOutputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -1347,11 +1347,12 @@ public final void writeInt32NoTag(int value) throws IOException {

@Override
public final void writeUInt32NoTag(int value) throws IOException {
int position = this.position; // Perf: hoist field to register to avoid load/stores.
try {
while (true) {
if ((value & ~0x7F) == 0) {
buffer[position++] = (byte) value;
return;
break;
} else {
buffer[position++] = (byte) ((value | 0x80) & 0xFF);
value >>>= 7;
Expand All @@ -1360,6 +1361,7 @@ public final void writeUInt32NoTag(int value) throws IOException {
} catch (IndexOutOfBoundsException e) {
throw new OutOfSpaceException(position, limit, 1, e);
}
this.position = position; // Only update position if we stayed within the array bounds.
}

@Override
Expand All @@ -1379,11 +1381,12 @@ public final void writeFixed32NoTag(int value) throws IOException {

@Override
public final void writeUInt64NoTag(long value) throws IOException {
int position = this.position; // Perf: hoist field to register to avoid load/stores.
if (HAS_UNSAFE_ARRAY_OPERATIONS && spaceLeft() >= MAX_VARINT_SIZE) {
while (true) {
if ((value & ~0x7FL) == 0) {
UnsafeUtil.putByte(buffer, position++, (byte) value);
return;
break;
} else {
UnsafeUtil.putByte(buffer, position++, (byte) (((int) value | 0x80) & 0xFF));
value >>>= 7;
Expand All @@ -1394,7 +1397,7 @@ public final void writeUInt64NoTag(long value) throws IOException {
while (true) {
if ((value & ~0x7FL) == 0) {
buffer[position++] = (byte) value;
return;
break;
} else {
buffer[position++] = (byte) (((int) value | 0x80) & 0xFF);
value >>>= 7;
Expand All @@ -1404,6 +1407,7 @@ public final void writeUInt64NoTag(long value) throws IOException {
throw new OutOfSpaceException(position, limit, 1, e);
}
}
this.position = position; // Only update position if we stayed within the array bounds.
}

@Override
Expand Down Expand Up @@ -2034,29 +2038,34 @@ public void writeInt32NoTag(int value) throws IOException {

@Override
public void writeUInt32NoTag(int value) throws IOException {
long position = this.position; // Perf: hoist field to register to avoid load/stores.
if (position <= oneVarintLimit) {
// Optimization to avoid bounds checks on each iteration.
while (true) {
if ((value & ~0x7F) == 0) {
UnsafeUtil.putByte(position++, (byte) value);
return;
break;
} else {
UnsafeUtil.putByte(position++, (byte) ((value | 0x80) & 0xFF));
value >>>= 7;
}
}
} else {
while (position < limit) {
while (true) {
if (position >= limit) {
throw new OutOfSpaceException(
String.format("Pos: %d, limit: %d, len: %d", position, limit, 1));
}
if ((value & ~0x7F) == 0) {
UnsafeUtil.putByte(position++, (byte) value);
return;
break;
} else {
UnsafeUtil.putByte(position++, (byte) ((value | 0x80) & 0xFF));
value >>>= 7;
}
}
throw new OutOfSpaceException(position, limit, 1);
}
this.position = position; // Only update position if we stayed within the array bounds.
}

@Override
Expand All @@ -2071,29 +2080,33 @@ public void writeFixed32NoTag(int value) throws IOException {

@Override
public void writeUInt64NoTag(long value) throws IOException {
long position = this.position; // Perf: hoist field to register to avoid load/stores.
if (position <= oneVarintLimit) {
// Optimization to avoid bounds checks on each iteration.
while (true) {
if ((value & ~0x7FL) == 0) {
UnsafeUtil.putByte(position++, (byte) value);
return;
break;
} else {
UnsafeUtil.putByte(position++, (byte) (((int) value | 0x80) & 0xFF));
value >>>= 7;
}
}
} else {
while (position < limit) {
while (true) {
if (position >= limit) {
throw new OutOfSpaceException(position, limit, 1);
}
if ((value & ~0x7FL) == 0) {
UnsafeUtil.putByte(position++, (byte) value);
return;
break;
} else {
UnsafeUtil.putByte(position++, (byte) (((int) value | 0x80) & 0xFF));
value >>>= 7;
}
}
throw new OutOfSpaceException(position, limit, 1);
}
this.position = position; // Only update position if we stayed within the array bounds.
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package com.google.protobuf;

import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static com.google.common.truth.TruthJUnit.assume;
import static org.junit.Assert.assertThrows;

Expand Down Expand Up @@ -313,6 +314,36 @@ public void testWriteFixed64NoTag_outOfBounds_throws() throws Exception {
}
}

@Test
public void testWriteUInt32NoTag_outOfBounds_throws() throws Exception {
// Streaming's buffering masks out of bounds writes.
assume().that(outputType).isNotEqualTo(OutputType.STREAM);

for (int i = 0; i < 5; i++) {
Coder coder = outputType.newCoder(i);
assertThrows(
OutOfSpaceException.class, () -> coder.stream().writeUInt32NoTag(Integer.MAX_VALUE));

// Space left should not go negative.
assertWithMessage("i=%s", i).that(coder.stream().spaceLeft()).isAtLeast(0);
}
}

@Test
public void testWriteUInt64NoTag_outOfBounds_throws() throws Exception {
// Streaming's buffering masks out of bounds writes.
assume().that(outputType).isNotEqualTo(OutputType.STREAM);

for (int i = 0; i < 9; i++) {
Coder coder = outputType.newCoder(i);
assertThrows(
OutOfSpaceException.class, () -> coder.stream().writeUInt64NoTag(Long.MAX_VALUE));

// Space left should not go negative.
assertWithMessage("i=%s", i).that(coder.stream().spaceLeft()).isAtLeast(0);
}
}

/** Test encodeZigZag32() and encodeZigZag64(). */
@Test
public void testEncodeZigZag() throws Exception {
Expand Down

0 comments on commit 24160d5

Please sign in to comment.