From 3b11fcf792e4ba84b4c0880e1c05736b70381758 Mon Sep 17 00:00:00 2001 From: DBotThePony Date: Fri, 19 Apr 2024 12:34:56 +0700 Subject: [PATCH] Sacrifice some memory efficiency of VectorizedBitSet for wild performance boost --- .../defs/dungeon/VectorizedBitSet.java | 264 +++--------------- .../dbotthepony/kstarbound/test/MiscTests.kt | 28 ++ 2 files changed, 72 insertions(+), 220 deletions(-) diff --git a/src/main/java/ru/dbotthepony/kstarbound/defs/dungeon/VectorizedBitSet.java b/src/main/java/ru/dbotthepony/kstarbound/defs/dungeon/VectorizedBitSet.java index b2000d5b..dc4d8c03 100644 --- a/src/main/java/ru/dbotthepony/kstarbound/defs/dungeon/VectorizedBitSet.java +++ b/src/main/java/ru/dbotthepony/kstarbound/defs/dungeon/VectorizedBitSet.java @@ -1,245 +1,69 @@ package ru.dbotthepony.kstarbound.defs.dungeon; -import java.util.BitSet; - -// TODO: actually make it vectorized, some unsafe API maybe? -// currently it is quite slow, but achieves desired memory efficiency. public final class VectorizedBitSet { public final int bits; public final int width; public final int height; - private final BitSet data; + + private final int[] data; + private final int mask; + private final int bitMask; + private final int bitShift; + private final int bitShiftInv; public VectorizedBitSet(int bits, int width, int height) { - if (bits <= 0 || bits >= 13) - throw new IllegalArgumentException("Too many or no bits: " + bits + ". Maximum 12 supported"); + if (bits <= 0) + throw new IllegalArgumentException("Zero bits: " + bits); - this.bits = bits; this.width = width; this.height = height; - this.data = new BitSet(bits * width * height); - } + if (bits <= 4) { + this.bits = 4; - private static int bit(boolean value, int order) { - return (value ? 1 : 0) << order; - } + this.bitShift = 3; + this.bitShiftInv = 2; - private void bit(int index, int value, int order) { - if (((value >>> order) & 1) == 0) { - data.clear(index); + this.bitMask = 7; + this.mask = 0xF; + this.data = new int[(width * height + 7) / 8]; + } else if (bits <= 8) { + this.bits = 8; + this.bitMask = 3; + + this.bitShift = 2; + this.bitShiftInv = 3; + + this.mask = 0xFF; + this.data = new int[(width * height + 3) / 4]; + } else if (bits <= 16) { + this.bits = 16; + this.bitMask = 1; + + this.bitShift = 1; + this.bitShiftInv = 4; + + this.mask = 0xFFFF; + this.data = new int[(width * height + 1) / 2]; } else { - data.set(index); + throw new IllegalArgumentException("Too many bits: " + bits); } } public void set(int x, int y, int value) { - int bitIndex = (x + y * width) * bits; - - switch (this.bits) { - case 0: - - case 1: - bit(bitIndex, value, 0); - break; - - case 2: - bit(bitIndex, value, 0); - bit(bitIndex + 1, value, 1); - break; - - case 3: - bit(bitIndex, value, 0); - bit(bitIndex + 1, value, 1); - bit(bitIndex + 2, value, 2); - break; - - case 4: - bit(bitIndex, value, 0); - bit(bitIndex + 1, value, 1); - bit(bitIndex + 2, value, 2); - bit(bitIndex + 3, value, 3); - break; - - case 5: - bit(bitIndex, value, 0); - bit(bitIndex + 1, value, 1); - bit(bitIndex + 2, value, 2); - bit(bitIndex + 3, value, 3); - bit(bitIndex + 4, value, 4); - break; - - case 6: - bit(bitIndex, value, 0); - bit(bitIndex + 1, value, 1); - bit(bitIndex + 2, value, 2); - bit(bitIndex + 3, value, 3); - bit(bitIndex + 4, value, 4); - bit(bitIndex + 5, value, 5); - break; - - case 7: - bit(bitIndex, value, 0); - bit(bitIndex + 1, value, 1); - bit(bitIndex + 2, value, 2); - bit(bitIndex + 3, value, 3); - bit(bitIndex + 4, value, 4); - bit(bitIndex + 5, value, 5); - bit(bitIndex + 6, value, 6); - break; - - case 8: - bit(bitIndex, value, 0); - bit(bitIndex + 1, value, 1); - bit(bitIndex + 2, value, 2); - bit(bitIndex + 3, value, 3); - bit(bitIndex + 4, value, 4); - bit(bitIndex + 5, value, 5); - bit(bitIndex + 6, value, 6); - bit(bitIndex + 7, value, 7); - break; - - case 9: - bit(bitIndex, value, 0); - bit(bitIndex + 1, value, 1); - bit(bitIndex + 2, value, 2); - bit(bitIndex + 3, value, 3); - bit(bitIndex + 4, value, 4); - bit(bitIndex + 5, value, 5); - bit(bitIndex + 6, value, 6); - bit(bitIndex + 7, value, 7); - bit(bitIndex + 8, value, 8); - break; - - case 10: - bit(bitIndex, value, 0); - bit(bitIndex + 1, value, 1); - bit(bitIndex + 2, value, 2); - bit(bitIndex + 3, value, 3); - bit(bitIndex + 4, value, 4); - bit(bitIndex + 5, value, 5); - bit(bitIndex + 6, value, 6); - bit(bitIndex + 7, value, 7); - bit(bitIndex + 8, value, 8); - bit(bitIndex + 9, value, 9); - break; - - case 11: - bit(bitIndex, value, 0); - bit(bitIndex + 1, value, 1); - bit(bitIndex + 2, value, 2); - bit(bitIndex + 3, value, 3); - bit(bitIndex + 4, value, 4); - bit(bitIndex + 5, value, 5); - bit(bitIndex + 6, value, 6); - bit(bitIndex + 7, value, 7); - bit(bitIndex + 8, value, 8); - bit(bitIndex + 9, value, 9); - bit(bitIndex + 10, value, 10); - break; - - case 12: - bit(bitIndex, value, 0); - bit(bitIndex + 1, value, 1); - bit(bitIndex + 2, value, 2); - bit(bitIndex + 3, value, 3); - bit(bitIndex + 4, value, 4); - bit(bitIndex + 5, value, 5); - bit(bitIndex + 6, value, 6); - bit(bitIndex + 7, value, 7); - bit(bitIndex + 8, value, 8); - bit(bitIndex + 9, value, 9); - bit(bitIndex + 10, value, 10); - bit(bitIndex + 11, value, 11); - break; - - } + value = value & this.mask; + int index = x + y * this.width; + int arrayIndex = index >>> this.bitShift; + int innerIndex = (index & this.bitMask) << this.bitShiftInv; + this.data[arrayIndex] = (this.data[arrayIndex] &~ (this.mask << innerIndex)) | (value << innerIndex); } public int get(int x, int y) { - int bitIndex = (x + y * width) * bits; - BitSet data = this.data; + int index = x + y * this.width; + int arrayIndex = index >>> this.bitShift; + int innerIndex = (index & this.bitMask) << this.bitShiftInv; - return switch (this.bits) { - case 0, 1 -> bit(data.get(bitIndex), 0); - case 2 -> bit(data.get(bitIndex), 0) | - bit(data.get(bitIndex + 1), 1); - case 3 -> bit(data.get(bitIndex), 0) | - bit(data.get(bitIndex + 1), 1) | - bit(data.get(bitIndex + 2), 2); - case 4 -> bit(data.get(bitIndex), 0) | - bit(data.get(bitIndex + 1), 1) | - bit(data.get(bitIndex + 2), 2) | - bit(data.get(bitIndex + 3), 3); - case 5 -> bit(data.get(bitIndex), 0) | - bit(data.get(bitIndex + 1), 1) | - bit(data.get(bitIndex + 2), 2) | - bit(data.get(bitIndex + 3), 3) | - bit(data.get(bitIndex + 4), 4); - case 6 -> bit(data.get(bitIndex), 0) | - bit(data.get(bitIndex + 1), 1) | - bit(data.get(bitIndex + 2), 2) | - bit(data.get(bitIndex + 3), 3) | - bit(data.get(bitIndex + 4), 4) | - bit(data.get(bitIndex + 5), 5); - case 7 -> bit(data.get(bitIndex), 0) | - bit(data.get(bitIndex + 1), 1) | - bit(data.get(bitIndex + 2), 2) | - bit(data.get(bitIndex + 3), 3) | - bit(data.get(bitIndex + 4), 4) | - bit(data.get(bitIndex + 5), 5) | - bit(data.get(bitIndex + 6), 6); - case 8 -> bit(data.get(bitIndex), 0) | - bit(data.get(bitIndex + 1), 1) | - bit(data.get(bitIndex + 2), 2) | - bit(data.get(bitIndex + 3), 3) | - bit(data.get(bitIndex + 4), 4) | - bit(data.get(bitIndex + 5), 5) | - bit(data.get(bitIndex + 6), 6) | - bit(data.get(bitIndex + 7), 7); - case 9 -> bit(data.get(bitIndex), 0) | - bit(data.get(bitIndex + 1), 1) | - bit(data.get(bitIndex + 2), 2) | - bit(data.get(bitIndex + 3), 3) | - bit(data.get(bitIndex + 4), 4) | - bit(data.get(bitIndex + 5), 5) | - bit(data.get(bitIndex + 6), 6) | - bit(data.get(bitIndex + 7), 7) | - bit(data.get(bitIndex + 8), 8); - case 10 -> bit(data.get(bitIndex), 0) | - bit(data.get(bitIndex + 1), 1) | - bit(data.get(bitIndex + 2), 2) | - bit(data.get(bitIndex + 3), 3) | - bit(data.get(bitIndex + 4), 4) | - bit(data.get(bitIndex + 5), 5) | - bit(data.get(bitIndex + 6), 6) | - bit(data.get(bitIndex + 7), 7) | - bit(data.get(bitIndex + 8), 8) | - bit(data.get(bitIndex + 9), 9); - case 11 -> bit(data.get(bitIndex), 0) | - bit(data.get(bitIndex + 1), 1) | - bit(data.get(bitIndex + 2), 2) | - bit(data.get(bitIndex + 3), 3) | - bit(data.get(bitIndex + 4), 4) | - bit(data.get(bitIndex + 5), 5) | - bit(data.get(bitIndex + 6), 6) | - bit(data.get(bitIndex + 7), 7) | - bit(data.get(bitIndex + 8), 8) | - bit(data.get(bitIndex + 9), 9) | - bit(data.get(bitIndex + 10), 10); - case 12 -> bit(data.get(bitIndex), 0) | - bit(data.get(bitIndex + 1), 1) | - bit(data.get(bitIndex + 2), 2) | - bit(data.get(bitIndex + 3), 3) | - bit(data.get(bitIndex + 4), 4) | - bit(data.get(bitIndex + 5), 5) | - bit(data.get(bitIndex + 6), 6) | - bit(data.get(bitIndex + 7), 7) | - bit(data.get(bitIndex + 8), 8) | - bit(data.get(bitIndex + 9), 9) | - bit(data.get(bitIndex + 10), 10) | - bit(data.get(bitIndex + 11), 11); - default -> 0; - }; + int values = this.data[arrayIndex]; + return (values >>> innerIndex) & this.mask; } } diff --git a/src/test/kotlin/ru/dbotthepony/kstarbound/test/MiscTests.kt b/src/test/kotlin/ru/dbotthepony/kstarbound/test/MiscTests.kt index ccfb5b73..e01c0ed2 100644 --- a/src/test/kotlin/ru/dbotthepony/kstarbound/test/MiscTests.kt +++ b/src/test/kotlin/ru/dbotthepony/kstarbound/test/MiscTests.kt @@ -3,8 +3,36 @@ package ru.dbotthepony.kstarbound.test import org.junit.jupiter.api.Assertions.* import org.junit.jupiter.api.DisplayName import org.junit.jupiter.api.Test +import ru.dbotthepony.kommons.arrays.Int2DArray +import ru.dbotthepony.kstarbound.defs.dungeon.VectorizedBitSet +import ru.dbotthepony.kstarbound.util.random.random import java.rmi.UnexpectedException object MiscTests { + @Test + @DisplayName("Vectorized bit set test") + fun vectorizedBitSet() { + val random = random() + val values = Int2DArray.allocate(16, 16) + for (x in 0 until 16) { + for (y in 0 until 16) { + values[x, y] = random.nextInt(0, 4) + } + } + + val bits = VectorizedBitSet(4, 16, 16) + + for (x in 0 until 16) { + for (y in 0 until 16) { + bits[x, y] = values[x, y] + } + } + + for (x in 0 until 16) { + for (y in 0 until 16) { + assertEquals(values[x, y], bits[x, y]) + } + } + } }