From 8e7f6ee5c3041ec11bdcb97ef0882fbd440ff4cc Mon Sep 17 00:00:00 2001
From: DBotThePony
Date: Sun, 29 Dec 2024 15:55:55 +0700
Subject: [PATCH] Universe chunks database with dictionaries

---
 .../kstarbound/json/BinaryJsonReader.kt       |  18 +-
 .../kstarbound/json/BinaryJsonWriter.kt       |  12 +-
 .../kstarbound/server/world/ServerUniverse.kt | 305 +++++++++++++++++-
 3 files changed, 308 insertions(+), 27 deletions(-)

diff --git a/src/main/kotlin/ru/dbotthepony/kstarbound/json/BinaryJsonReader.kt b/src/main/kotlin/ru/dbotthepony/kstarbound/json/BinaryJsonReader.kt
index f9a7de15..6d3d65e7 100644
--- a/src/main/kotlin/ru/dbotthepony/kstarbound/json/BinaryJsonReader.kt
+++ b/src/main/kotlin/ru/dbotthepony/kstarbound/json/BinaryJsonReader.kt
@@ -1,5 +1,6 @@
 package ru.dbotthepony.kstarbound.json
 
+import com.github.luben.zstd.ZstdDictDecompress
 import com.github.luben.zstd.ZstdInputStreamNoFinalizer
 import com.google.gson.JsonArray
 import com.google.gson.JsonElement
@@ -36,7 +37,7 @@ private enum class InflateType {
 	NONE
 }
 
-private fun <T> ByteArray.callRead(inflate: InflateType, callable: DataInputStream.() -> T): T {
+private fun <T> ByteArray.callRead(inflate: InflateType, dictionary: ZstdDictDecompress? = null, callable: DataInputStream.() -> T): T {
 	val stream = FastByteArrayInputStream(this)
 
 	when (inflate) {
@@ -53,12 +54,17 @@ private fun <T> ByteArray.callRead(inflate: InflateType, callable: DataInputStre
 		}
 
 		InflateType.ZSTD -> {
-			val data = DataInputStream(BufferedInputStream(ZstdInputStreamNoFinalizer(stream), 0x10000))
+			val f = ZstdInputStreamNoFinalizer(stream)
+
+			if (dictionary != null)
+				f.setDict(dictionary)
+
+			val data = DataInputStream(BufferedInputStream(f, 0x10000))
 
 			try {
 				return callable(data)
 			} finally {
-				data.close()
+				f.close()
 			}
 		}
 
@@ -76,9 +82,9 @@ fun ByteArray.readJsonElementInflated(): JsonElement = callRead(InflateType.ZLIB
 fun ByteArray.readJsonObjectInflated(): JsonObject = callRead(InflateType.ZLIB) { readJsonObject() }
 fun ByteArray.readJsonArrayInflated(): JsonArray = callRead(InflateType.ZLIB) { readJsonArray() }
 
-fun ByteArray.readJsonElementZstd(): JsonElement = callRead(InflateType.ZSTD) { readJsonElement() }
-fun ByteArray.readJsonObjectZstd(): JsonObject = callRead(InflateType.ZSTD) { readJsonObject() }
-fun ByteArray.readJsonArrayZstd(): JsonArray = callRead(InflateType.ZSTD) { readJsonArray() }
+fun ByteArray.readJsonElementZstd(dictionary: ZstdDictDecompress? = null): JsonElement = callRead(InflateType.ZSTD, dictionary = dictionary) { readJsonElement() }
+fun ByteArray.readJsonObjectZstd(dictionary: ZstdDictDecompress? = null): JsonObject = callRead(InflateType.ZSTD, dictionary = dictionary) { readJsonObject() }
+fun ByteArray.readJsonArrayZstd(dictionary: ZstdDictDecompress? = null): JsonArray = callRead(InflateType.ZSTD, dictionary = dictionary) { readJsonArray() }
 
 /**
  * Allows reading binary JSON directly into a [JsonElement]
diff --git a/src/main/kotlin/ru/dbotthepony/kstarbound/json/BinaryJsonWriter.kt b/src/main/kotlin/ru/dbotthepony/kstarbound/json/BinaryJsonWriter.kt
index 9cd50eb5..c3b49401 100644
--- a/src/main/kotlin/ru/dbotthepony/kstarbound/json/BinaryJsonWriter.kt
+++ b/src/main/kotlin/ru/dbotthepony/kstarbound/json/BinaryJsonWriter.kt
@@ -1,5 +1,6 @@
 package ru.dbotthepony.kstarbound.json
 
+import com.github.luben.zstd.ZstdDictCompress
 import com.github.luben.zstd.ZstdOutputStreamNoFinalizer
 import com.google.gson.JsonArray
 import com.google.gson.JsonElement
@@ -21,7 +22,7 @@ private enum class DeflateType {
 	NONE
 }
 
-private fun <T> T.callWrite(deflate: DeflateType, zstdCompressionLevel: Int = 6, callable: DataOutputStream.(T) -> Unit): ByteArray {
+private fun <T> T.callWrite(deflate: DeflateType, zstdCompressionLevel: Int = 6, zstdDictionary: ZstdDictCompress? = null, callable: DataOutputStream.(T) -> Unit): ByteArray {
 	val stream = FastByteArrayOutputStream()
 
 	when (deflate) {
@@ -38,6 +39,9 @@ private fun <T> T.callWrite(deflate: DeflateType, zstdCompressionLevel: Int = 6,
 			val s = ZstdOutputStreamNoFinalizer(stream)
 			s.setLevel(zstdCompressionLevel)
 
+			if (zstdDictionary != null)
+				s.setDict(zstdDictionary)
+
 			DataOutputStream(BufferedOutputStream(s, 0x10000)).use {
 				callable(it, this)
 			}
@@ -57,9 +61,9 @@ fun JsonElement.writeJsonElementDeflated(): ByteArray = callWrite(DeflateType.ZL
 fun JsonObject.writeJsonObjectDeflated(): ByteArray = callWrite(DeflateType.ZLIB) { writeJsonObject(it) }
 fun JsonArray.writeJsonArrayDeflated(): ByteArray = callWrite(DeflateType.ZLIB) { writeJsonArray(it) }
 
-fun JsonElement.writeJsonElementZstd(level: Int = 6): ByteArray = callWrite(DeflateType.ZSTD, zstdCompressionLevel = level) { writeJsonElement(it) }
-fun JsonObject.writeJsonObjectZstd(level: Int = 6): ByteArray = callWrite(DeflateType.ZSTD, zstdCompressionLevel = level) { writeJsonObject(it) }
-fun JsonArray.writeJsonArrayZstd(level: Int = 6): ByteArray = callWrite(DeflateType.ZSTD, zstdCompressionLevel = level) { writeJsonArray(it) }
+fun JsonElement.writeJsonElementZstd(level: Int = 6, dictionary: ZstdDictCompress? = null): ByteArray = callWrite(DeflateType.ZSTD, zstdCompressionLevel = level, zstdDictionary = dictionary) { writeJsonElement(it) }
+fun JsonObject.writeJsonObjectZstd(level: Int = 6, dictionary: ZstdDictCompress? = null): ByteArray = callWrite(DeflateType.ZSTD, zstdCompressionLevel = level, zstdDictionary = dictionary) { writeJsonObject(it) }
+fun JsonArray.writeJsonArrayZstd(level: Int = 6, dictionary: ZstdDictCompress? = null): ByteArray = callWrite(DeflateType.ZSTD, zstdCompressionLevel = level, zstdDictionary = dictionary) { writeJsonArray(it) }
 
 fun DataOutputStream.writeJsonElement(value: JsonElement) {
 	when (value) {
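The two files above extend the binary JSON codec with optional zstd dictionaries on both ends. A minimal round-trip sketch of the new API (illustrative code, not part of the patch); the one invariant is that both sides are built from the same dictionary bytes, which is what the per-row `dictionary` version introduced below guarantees:

    import com.github.luben.zstd.ZstdDictCompress
    import com.github.luben.zstd.ZstdDictDecompress
    import com.google.gson.JsonObject
    import ru.dbotthepony.kstarbound.json.readJsonObjectZstd
    import ru.dbotthepony.kstarbound.json.writeJsonObjectZstd

    fun roundTrip(obj: JsonObject, dictBytes: ByteArray): JsonObject {
        // compressor and decompressor views over the same trained dictionary bytes
        val compress = ZstdDictCompress(dictBytes, 6)
        val decompress = ZstdDictDecompress(dictBytes)

        val blob = obj.writeJsonObjectZstd(level = 6, dictionary = compress)
        return blob.readJsonObjectZstd(dictionary = decompress)
    }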
diff --git a/src/main/kotlin/ru/dbotthepony/kstarbound/server/world/ServerUniverse.kt b/src/main/kotlin/ru/dbotthepony/kstarbound/server/world/ServerUniverse.kt
index efa9a7eb..a2f3e6b8 100644
--- a/src/main/kotlin/ru/dbotthepony/kstarbound/server/world/ServerUniverse.kt
+++ b/src/main/kotlin/ru/dbotthepony/kstarbound/server/world/ServerUniverse.kt
@@ -1,9 +1,17 @@
 package ru.dbotthepony.kstarbound.server.world
 
 import com.github.benmanes.caffeine.cache.Caffeine
+import com.github.luben.zstd.Zstd
+import com.github.luben.zstd.ZstdDictCompress
+import com.github.luben.zstd.ZstdDictDecompress
+import com.github.luben.zstd.ZstdException
+import com.github.luben.zstd.ZstdInputStreamNoFinalizer
+import com.github.luben.zstd.ZstdOutputStreamNoFinalizer
 import com.google.gson.JsonArray
 import com.google.gson.JsonObject
 import it.unimi.dsi.fastutil.ints.IntArraySet
+import it.unimi.dsi.fastutil.io.FastByteArrayInputStream
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
 import it.unimi.dsi.fastutil.objects.ObjectArrayList
 import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet
 import kotlinx.coroutines.CoroutineScope
@@ -12,6 +20,7 @@ import kotlinx.coroutines.async
 import kotlinx.coroutines.cancel
 import kotlinx.coroutines.future.asCompletableFuture
 import kotlinx.coroutines.future.await
+import org.apache.logging.log4j.LogManager
 import ru.dbotthepony.kommons.gson.JsonArrayCollector
 import ru.dbotthepony.kommons.gson.contains
 import ru.dbotthepony.kommons.gson.get
@@ -48,10 +57,12 @@ import ru.dbotthepony.kstarbound.world.UniversePos
 import java.io.Closeable
 import java.io.File
 import java.lang.ref.Cleaner.Cleanable
+import java.nio.ByteBuffer
 import java.sql.Connection
 import java.sql.DriverManager
 import java.sql.PreparedStatement
 import java.sql.ResultSet
+import java.time.Duration
 import java.util.concurrent.CompletableFuture
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.TimeUnit
@@ -112,9 +123,17 @@ class ServerUniverse(folder: File? = null) : Universe(), Closeable {
 				`z` INTEGER NOT NULL,
 				`parameters` BLOB NOT NULL,
 				`planets` BLOB NOT NULL,
+				`dictionary` INTEGER NOT NULL,
 				PRIMARY KEY(`x`, `y`, `z`)
 			)
 		""".trimIndent())
+
+			it.execute("""
+				CREATE TABLE IF NOT EXISTS `dictionary` (
+					`version` INTEGER NOT NULL PRIMARY KEY,
+					`data` BLOB NOT NULL
+				)
+			""".trimIndent())
 		}
 
 		database.autoCommit = false
@@ -125,10 +144,13 @@ class ServerUniverse(folder: File? = null) : Universe(), Closeable {
 	private val scope = CoroutineScope(ScheduledCoroutineExecutor(carrier) + SupervisorJob())
 
 	private val selectChunk = database.prepareStatement("SELECT `systems`, `constellations` FROM `chunk` WHERE `x` = ? AND `y` = ?")
-	private val selectSystem = database.prepareStatement("SELECT `parameters`, `planets` FROM `system` WHERE `x` = ? AND `y` = ? AND `z` = ?")
+	private val selectSystem = database.prepareStatement("SELECT `parameters`, `planets`, `dictionary` FROM `system` WHERE `x` = ? AND `y` = ? AND `z` = ?")
+	private val selectDictionary = database.prepareStatement("SELECT `data` FROM `dictionary` WHERE `version` = ?")
+	private val selectSamples = database.prepareStatement("SELECT `x`, `y`, `z`, `parameters`, `planets` FROM `system` WHERE `dictionary` = ? ORDER BY RANDOM() LIMIT 2000")
 
 	private val insertChunk = database.prepareStatement("INSERT INTO `chunk` (`x`, `y`, `systems`, `constellations`) VALUES (?, ?, ?, ?)")
-	private val insertSystem = database.prepareStatement("INSERT INTO `system` (`x`, `y`, `z`, `parameters`, `planets`) VALUES (?, ?, ?, ?, ?)")
+	private val insertSystem = database.prepareStatement("REPLACE INTO `system` (`x`, `y`, `z`, `parameters`, `planets`, `dictionary`) VALUES (?, ?, ?, ?, ?, ?)")
+	private val insertDictionary = database.prepareStatement("INSERT INTO `dictionary` (`version`, `data`) VALUES (?, ?)")
 
 	private class SerializedChunk(val x: Int, val y: Int, val systems: ByteArray, val constellations: ByteArray) {
 		fun write(statement: PreparedStatement) {
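Why every `system` row records the `dictionary` version its blobs were written with (`0` reserved for "no dictionary"): a zstd frame that references a dictionary cannot be reconstructed without that same dictionary, so the version has to survive alongside the data for as long as the row does. A standalone sketch of the failure mode, assuming zstd-jni's byte-array convenience methods (not code from this patch):

    import com.github.luben.zstd.Zstd
    import com.github.luben.zstd.ZstdDictCompress
    import com.github.luben.zstd.ZstdException

    fun main() {
        // raw-content dictionary; the payload is deliberately dictionary-shaped so the
        // encoder is all but guaranteed to emit back-references into the dictionary
        val dictBytes = (0 until 256).joinToString(",") { """"system_$it":{"typeName":"star","size":$it}""" }.toByteArray()
        val payload = dictBytes.copyOf(dictBytes.size / 2)

        val packed = Zstd.compress(payload, ZstdDictCompress(dictBytes, 6))

        try {
            Zstd.decompress(packed, payload.size) // decoding without the dictionary
            println("decoded anyway: the encoder happened not to reference the dictionary")
        } catch (err: ZstdException) {
            println("not decodable without its dictionary: ${err.message}")
        }
    }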
@@ -164,19 +186,16 @@ class ServerUniverse(folder: File? = null) : Universe(), Closeable {
 				}.collect(JsonArrayCollector).writeJsonArrayZstd(4)
 			)
 		}
-
-		fun write(statement: PreparedStatement) {
-			serialize().write(statement)
-		}
 	}
 
-	private class SerializedSystem(val x: Int, val y: Int, val z: Int, val parameters: ByteArray, val planets: ByteArray) {
+	private class SerializedSystem(val x: Int, val y: Int, val z: Int, val parameters: ByteArray, val planets: ByteArray, val dictionary: Int) {
 		fun write(statement: PreparedStatement) {
 			statement.setInt(1, x)
 			statement.setInt(2, y)
 			statement.setInt(3, z)
 			statement.setBytes(4, parameters)
 			statement.setBytes(5, planets)
+			statement.setInt(6, dictionary)
 			statement.execute()
 		}
@@ -191,19 +210,16 @@
 			}
 		}
 
-		fun serialize(): SerializedSystem {
+		fun serialize(dict: Dicts? = null): SerializedSystem {
 			return SerializedSystem(
 				x, y, z,
-				Starbound.gson.toJsonTree(parameters).writeJsonElementZstd(8),
+				Starbound.gson.toJsonTree(parameters).writeJsonElementZstd(6, dictionary = dict?.compress),
 				planets.entries.stream()
 					.map { jsonArrayOf(it.key.first, it.key.second, it.value) }
-					.collect(JsonArrayCollector).writeJsonArrayZstd(8)
+					.collect(JsonArrayCollector).writeJsonArrayZstd(6, dictionary = dict?.compress),
+				dict?.version ?: 0
 			)
 		}
-
-		fun write(statement: PreparedStatement) {
-			serialize().write(statement)
-		}
 	}
 
 	// first, chunks in process of loading/generating must not be evicted
@@ -254,6 +270,239 @@
 		.executor(Starbound.EXECUTOR)
 		.build<Vector3i, CompletableFuture<System?>>()
 
+	private class Dicts(val version: Int, bytesParameters: ByteArray) : Closeable {
+		private var compressL: Lazy<ZstdDictCompress>? = lazy(LazyThreadSafetyMode.NONE) { ZstdDictCompress(bytesParameters, 6) }
+		private var decompressL: Lazy<ZstdDictDecompress>? = lazy(LazyThreadSafetyMode.NONE) { ZstdDictDecompress(bytesParameters) }
+
+		val compress get() = compressL!!.value
+		val decompress get() = decompressL!!.value
+
+		override fun close() {
+			/*if (compressL?.isInitialized() == true) {
+				compressL!!.value.close()
+				compressL = null
+			}
+
+			if (decompressL?.isInitialized() == true) {
+				decompressL!!.value.close()
+			decompressL = null
+			}*/
+
+			// zstd dicts use finalizers
+			// while we could close the dicts explicitly, they might still be in use
+			// by off-thread compression that is still in progress
+			compressL = null
+			decompressL = null
+		}
+	}
+
+	private val dictionaryCache = Caffeine.newBuilder()
+		.maximumSize(32L)
+		.expireAfterAccess(Duration.ofMinutes(5))
+		.executor(Starbound.EXECUTOR)
+		.evictionListener<Int, KOptional<Dicts>> { _, value, _ -> value?.orNull()?.close() }
+		.build<Int, KOptional<Dicts>>()
+
+	private fun loadDictionary(id: Int): Dicts? {
+		return dictionaryCache.get(id) { v ->
+			selectDictionary.setInt(1, v)
+
+			selectDictionary.executeQuery().use {
+				if (it.next()) {
+					val bytes = decompress(it.getBytes(1), null)
+					return@get KOptional(Dicts(v, bytes))
+				} else {
+					return@get KOptional()
+				}
+			}
+		}.orNull()
+	}
+
+	private var latestDictionary = 0
+	private var latestDictionaryCapacity = 0
+	private var dictionary: Dicts? = null
+
+	init {
+		database.createStatement().use {
+			it.executeQuery("""
+				SELECT MAX(`version`) FROM `dictionary`
+			""".trimIndent()).use {
+				it.next()
+				// if dictionary table is empty, MAX() returns NULL, which getInt() reads as 0
+				latestDictionary = it.getInt(1)
+			}
+
+			it.executeQuery("""
+				SELECT COUNT(*) FROM `system` WHERE `dictionary` = $latestDictionary
+			""".trimIndent()).use {
+				it.next()
+				latestDictionaryCapacity = it.getInt(1)
+			}
+		}
+
+		if (latestDictionary != 0) {
+			dictionary = loadDictionary(latestDictionary)!!
+		}
+	}
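+
+	// Blobs are stored as bare zstd streams with no decompressed-size prefix, so they
+	// are inflated incrementally in 16 KiB chunks; a null dictionary decodes
+	// pre-dictionary (version 0) rows as plain zstd.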
+	private fun decompress(input: ByteArray, dictionary: ZstdDictDecompress?): ByteArray {
+		val parts = ArrayList<ByteArray>()
+		val stream = ZstdInputStreamNoFinalizer(FastByteArrayInputStream(input))
+
+		if (dictionary != null)
+			stream.setDict(dictionary)
+
+		while (true) {
+			val alloc = ByteArray(1024 * 16)
+			val read = stream.read(alloc)
+
+			if (read <= 0) {
+				break
+			} else if (read == alloc.size) {
+				parts.add(alloc)
+			} else {
+				parts.add(alloc.copyOf(read))
+			}
+		}
+
+		stream.close()
+
+		val output = ByteArray(parts.sumOf { it.size })
+		var i = 0
+
+		for (part in parts) {
+			java.lang.System.arraycopy(part, 0, output, i, part.size)
+			i += part.size
+		}
+
+		return output
+	}
+
+	private fun recompress(input: ByteArray, dict: ZstdDictCompress): ByteArray {
+		val stream = FastByteArrayOutputStream()
+
+		ZstdOutputStreamNoFinalizer(stream, 6).use {
+			it.setDict(dict)
+			it.write(input)
+		}
+
+		return stream.array.copyOf(stream.length)
+	}
+
+	private fun recompress(input: ByteArray): ByteArray {
+		val stream = FastByteArrayOutputStream()
+
+		ZstdOutputStreamNoFinalizer(stream, 6).use {
+			it.write(input)
+		}
+
+		return stream.array.copyOf(stream.length)
+	}
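+
+	// Trains the first dictionary once 200 systems have been stored, then a fresh
+	// replacement after every 2000 systems written against the current dictionary;
+	// samples are drawn randomly (selectSamples) so training reflects current data.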
+	private fun dictionaryMaintenance() {
+		val limit = if (latestDictionary == 0) 200 else 2000
+
+		if (latestDictionaryCapacity >= limit) {
+			LOGGER.info("Optimizing star map compression algorithm, star map might become unresponsive for a moment...")
+
+			class Tuple(val x: Int, val y: Int, val z: Int, parameters: ByteArray, planets: ByteArray) {
+				val parameters = Starbound.EXECUTOR.supplyAsync { decompress(parameters, dictionary?.decompress) }
+				val planets = Starbound.EXECUTOR.supplyAsync { decompress(planets, dictionary?.decompress) }
+			}
+
+			// current dictionary is old enough,
+			// create a new one to adapt to possible data changes, e.g. due to added or removed mods
+
+			// collect samples
+			val samples = ArrayList<Tuple>()
+
+			selectSamples.setInt(1, latestDictionary)
+
+			selectSamples.executeQuery().use {
+				while (it.next()) {
+					samples.add(Tuple(it.getInt(1), it.getInt(2), it.getInt(3), it.getBytes(4), it.getBytes(5)))
+				}
+			}
+
+			// wait for samples to be decompressed
+			val sampleBuffer = ByteBuffer.allocateDirect(samples.sumOf { it.parameters.join().size + it.planets.join().size })
+
+			for (sample in samples) {
+				sampleBuffer.put(sample.parameters.join())
+				sampleBuffer.put(sample.planets.join())
+			}
+
+			sampleBuffer.position(0)
+
+			// create dictionary
+			val buffer = ByteBuffer.allocateDirect(1024 * 1024 * 4) // up to 4 MiB dictionary (before compression; zstd does not compress trained dictionaries itself)
+			// 4 MiB seems to be the sweet spot: the dictionary isn't too big (smaller indices to reference inside the dictionary),
+			// it takes a small amount of space, and training finishes reasonably fast.
+			// Too big a dictionary causes over-fitting and generally *reduces* compression ratio,
+			// while too small a dictionary doesn't contain enough data to be effective.
+
+			val status = Zstd.trainFromBufferDirect(
+				sampleBuffer,
+				IntArray(samples.size) { samples[it].parameters.join().size + samples[it].planets.join().size },
+				buffer,
+				false, 6
+			)
+
+			if (Zstd.isError(status)) {
+				throw ZstdException(status)
+			}
+
+			val copyBytes = ByteArray(status.toInt())
+			buffer.position(0)
+			buffer.get(copyBytes)
+
+			val dicts = Dicts(++latestDictionary, copyBytes)
+			dictionary = dicts
+			dictionaryCache.put(latestDictionary, KOptional(dicts))
+
+			insertDictionary.setInt(1, latestDictionary)
+			insertDictionary.setBytes(2, recompress(copyBytes))
+			insertDictionary.execute()
+			database.commit()
+
+			latestDictionaryCapacity = 0
+
+			LOGGER.info("Star map compression optimized")
+
+			if (latestDictionary == 1) {
+				LOGGER.info("Recompressing star map chunks with new dictionary...")
+
+				// previous data wasn't compressed with any dictionary, so recompress it with the new one
+				val recompressed = ArrayList<CompletableFuture<Triple<Tuple, ByteArray, ByteArray>>>()
+
+				for (tuple in samples) {
+					recompressed.add(
+						Starbound.EXECUTOR.supplyAsync {
+							Triple(tuple, recompress(tuple.parameters.join(), dicts.compress), recompress(tuple.planets.join(), dicts.compress))
+						}
+					)
+				}
+
+				try {
+					for ((tuple, parameters, planets) in recompressed.map { it.join() }) {
+						insertSystem.setInt(1, tuple.x)
+						insertSystem.setInt(2, tuple.y)
+						insertSystem.setInt(3, tuple.z)
+						insertSystem.setBytes(4, parameters)
+						insertSystem.setBytes(5, planets)
+						insertSystem.setInt(6, 1)
+						insertSystem.execute()
+					}
+
+					database.commit()
+				} catch (err: Throwable) {
+					database.rollback()
+					throw err
+				}
+
+				LOGGER.info("Recompressed star map chunks with new dictionary")
+			}
+		}
+	}
+
 	private fun loadSystem(pos: Vector3i): CompletableFuture<System?>? {
 		selectSystem.setInt(1, pos.x)
 		selectSystem.setInt(2, pos.y)
 		selectSystem.setInt(3, pos.z)
@@ -263,12 +512,21 @@
 			if (it.next()) {
 				val parametersBytes = it.getBytes(1)
 				val planetsBytes = it.getBytes(2)
+				val dictionaryVersion = it.getInt(3)
+
+				val dict: ZstdDictDecompress?
+
+				if (dictionaryVersion == 0) {
+					dict = null
+				} else {
+					dict = loadDictionary(dictionaryVersion)?.decompress
+				}
 
 				// deserialize in off-thread since it involves big json structures
 				Starbound.EXECUTOR.supplyAsync {
-					val parameters: CelestialParameters = Starbound.gson.fromJson(parametersBytes.readJsonElementZstd())!!
+					val parameters: CelestialParameters = Starbound.gson.fromJson(parametersBytes.readJsonElementZstd(dictionary = dict))!!
 
-					val planets: Map<Pair<Int, Int>, CelestialParameters> = planetsBytes.readJsonArrayZstd().associate {
+					val planets: Map<Pair<Int, Int>, CelestialParameters> = planetsBytes.readJsonArrayZstd(dictionary = dict).associate {
 						it as JsonArray
 						(it[0].asInt to it[1].asInt) to Starbound.gson.fromJson(it[2])!!
 					}
@@ -521,6 +779,8 @@
 		val random = random(staticRandom64(chunkPos.x, chunkPos.y, "ChunkIndexMix"))
 		val region = chunkRegion(chunkPos)
 
+		val dict = dictionary
+
 		return CompletableFuture.supplyAsync(Supplier {
 			val constellationCandidates = ArrayList()
 			val systemPositions = ArrayList<Vector3i>()
@@ -534,7 +794,7 @@
 				val system = generateSystem(random, pos) ?: continue
 				systemPositions.add(pos)
-				systems.add(CompletableFuture.supplyAsync(Supplier { system.serialize() }, Starbound.EXECUTOR))
+				systems.add(CompletableFuture.supplyAsync(Supplier { system.serialize(dict) }, Starbound.EXECUTOR))
 				systemCache.put(Vector3i(system.x, system.y, system.z), CompletableFuture.completedFuture(system))
 
 				if (
@@ -554,6 +814,13 @@
 				serialized.write(insertChunk)
 				systems.forEach { it.get().write(insertSystem) }
 				database.commit()
+
+				if (latestDictionary == (dict?.version ?: 0)) {
+					latestDictionaryCapacity += systems.size
+					// enqueue dictionary maintenance
+					carrier.execute(::dictionaryMaintenance)
+				}
+
 				chunk
 			}, carrier)
 		}, Starbound.EXECUTOR).thenCompose { it }
@@ -738,4 +1005,8 @@
 	}
 
 	override val region: AABBi = AABBi(Vector2i(baseInformation.xyCoordRange.x, baseInformation.xyCoordRange.x), Vector2i(baseInformation.xyCoordRange.y, baseInformation.xyCoordRange.y))
+
+	companion object {
+		private val LOGGER = LogManager.getLogger()
+	}
 }
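For reference, the dictionary training done in dictionaryMaintenance() can be reproduced standalone. A minimal sketch, assuming zstd-jni's heap-buffer trainer Zstd.trainFromBuffer (the patch uses the direct-buffer variant with a 4 MiB target); note that training reports an error if the sample set is too small or too uniform:

    import com.github.luben.zstd.Zstd
    import com.github.luben.zstd.ZstdDictCompress
    import com.github.luben.zstd.ZstdOutputStreamNoFinalizer
    import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream

    fun pack(data: ByteArray, dict: ZstdDictCompress?): ByteArray {
        val out = FastByteArrayOutputStream()

        ZstdOutputStreamNoFinalizer(out, 6).use {
            if (dict != null) it.setDict(dict)
            it.write(data)
        }

        return out.array.copyOf(out.length)
    }

    fun main() {
        // stand-in for sampled `system` rows: many small, structurally similar JSON documents
        val samples = (0 until 2000).map {
            """{"x":$it,"y":${it * 31 % 997},"parameters":{"typeName":"orbited","size":${it % 9}}}""".toByteArray()
        }

        val dictBuffer = ByteArray(16 * 1024) // demo-sized dictionary target
        val status = Zstd.trainFromBuffer(samples.toTypedArray(), dictBuffer)
        if (Zstd.isError(status)) throw IllegalStateException(Zstd.getErrorName(status))

        val dict = ZstdDictCompress(dictBuffer.copyOf(status.toInt()), 6)
        println("plain: ${pack(samples[0], null).size} B, with dictionary: ${pack(samples[0], dict).size} B")
    }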