From 513ef21926eba86beb38d8fb7a25463a1b477b81 Mon Sep 17 00:00:00 2001 From: DBotThePony Date: Tue, 14 Feb 2023 22:49:57 +0700 Subject: [PATCH] Idiot proof network read methods --- .../dbotthepony/mc/otm/core/math/Decimal.kt | 8 +- .../mc/otm/core/util/DataStreams.kt | 73 ++++++++++++------- .../mc/otm/core/util/FriendlyStreams.kt | 49 ++++++++----- .../mc/otm/data/SerializedFunctionRegistry.kt | 11 ++- .../mc/otm/network/FieldSynchronizer.kt | 17 +++-- 5 files changed, 100 insertions(+), 58 deletions(-) diff --git a/src/main/kotlin/ru/dbotthepony/mc/otm/core/math/Decimal.kt b/src/main/kotlin/ru/dbotthepony/mc/otm/core/math/Decimal.kt index 40f72c805..39f471d84 100644 --- a/src/main/kotlin/ru/dbotthepony/mc/otm/core/math/Decimal.kt +++ b/src/main/kotlin/ru/dbotthepony/mc/otm/core/math/Decimal.kt @@ -2,6 +2,7 @@ package ru.dbotthepony.mc.otm.core.math import net.minecraft.nbt.ByteArrayTag import net.minecraft.nbt.CompoundTag +import net.minecraft.nbt.NbtAccounter import net.minecraft.nbt.StringTag import net.minecraft.nbt.Tag import net.minecraft.network.FriendlyByteBuf @@ -859,8 +860,11 @@ class Decimal @JvmOverloads constructor(whole: BigInteger, decimal: Double = 0.0 fun FriendlyByteBuf.readDecimal() = Decimal.read(this) fun FriendlyByteBuf.writeDecimal(value: Decimal) = value.write(this) -fun InputStream.readDecimal(): Decimal { - val bytes = ByteArray(readVarIntLE()) +fun InputStream.readDecimal(sizeLimit: NbtAccounter = NbtAccounter(512L)): Decimal { + val size = readVarIntLE(sizeLimit) + require(size >= 0) { "Negative payload size: $size" } + sizeLimit.accountBytes(size.toLong() + 8L) + val bytes = ByteArray(size) read(bytes) return Decimal(BigInteger(bytes), readDouble()) } diff --git a/src/main/kotlin/ru/dbotthepony/mc/otm/core/util/DataStreams.kt b/src/main/kotlin/ru/dbotthepony/mc/otm/core/util/DataStreams.kt index 95c4308d3..b4ea5972d 100644 --- a/src/main/kotlin/ru/dbotthepony/mc/otm/core/util/DataStreams.kt +++ b/src/main/kotlin/ru/dbotthepony/mc/otm/core/util/DataStreams.kt @@ -1,9 +1,12 @@ package ru.dbotthepony.mc.otm.core.util +import net.minecraft.nbt.NbtAccounter import net.minecraft.world.item.ItemStack import ru.dbotthepony.mc.otm.core.math.readDecimal import ru.dbotthepony.mc.otm.core.math.writeDecimal +import java.io.DataInput import java.io.DataInputStream +import java.io.DataOutput import java.io.DataOutputStream import java.io.InputStream import java.io.OutputStream @@ -18,7 +21,7 @@ import kotlin.math.absoluteValue * Also provides [copy] and [compare] methods */ interface IStreamCodec { - fun read(stream: DataInputStream): V + fun read(stream: DataInputStream, sizeLimit: NbtAccounter = NbtAccounter(1L shl 18 /* 256 KiB */)): V fun write(stream: DataOutputStream, value: V) /** @@ -37,13 +40,21 @@ interface IStreamCodec { } class StreamCodec( - private val reader: (stream: DataInputStream) -> V, + private val reader: (stream: DataInputStream, sizeLimit: NbtAccounter) -> V, private val writer: (stream: DataOutputStream, value: V) -> Unit, private val copier: ((value: V) -> V) = { it }, private val comparator: ((a: V, b: V) -> Boolean) = { a, b -> a == b } ) : IStreamCodec { - override fun read(stream: DataInputStream): V { - return reader.invoke(stream) + constructor( + reader: (stream: DataInputStream) -> V, + payloadSize: Long, + writer: (stream: DataOutputStream, value: V) -> Unit, + copier: ((value: V) -> V) = { it }, + comparator: ((a: V, b: V) -> Boolean) = { a, b -> a == b } + ) : this({ stream, sizeLimit -> sizeLimit.accountBytes(payloadSize); reader.invoke(stream) }, writer, copier, comparator) + + override fun read(stream: DataInputStream, sizeLimit: NbtAccounter): V { + return reader.invoke(stream, sizeLimit) } override fun write(stream: DataOutputStream, value: V) { @@ -59,18 +70,18 @@ class StreamCodec( } } -val NullValueCodec = StreamCodec({ null }, { _, _ -> }) -val BooleanValueCodec = StreamCodec(DataInputStream::readBoolean, DataOutputStream::writeBoolean) -val ByteValueCodec = StreamCodec(DataInputStream::readByte, { s, v -> s.writeByte(v.toInt()) }) -val ShortValueCodec = StreamCodec(DataInputStream::readShort, { s, v -> s.writeShort(v.toInt()) }) -val IntValueCodec = StreamCodec(DataInputStream::readInt, DataOutputStream::writeInt) -val LongValueCodec = StreamCodec(DataInputStream::readLong, DataOutputStream::writeLong) -val FloatValueCodec = StreamCodec(DataInputStream::readFloat, DataOutputStream::writeFloat) -val DoubleValueCodec = StreamCodec(DataInputStream::readDouble, DataOutputStream::writeDouble) +val NullValueCodec = StreamCodec({ _, _ -> null }, { _, _ -> }) +val BooleanValueCodec = StreamCodec(DataInputStream::readBoolean, 1L, DataOutputStream::writeBoolean) +val ByteValueCodec = StreamCodec(DataInputStream::readByte, 1L, { s, v -> s.writeByte(v.toInt()) }) +val ShortValueCodec = StreamCodec(DataInputStream::readShort, 2L, { s, v -> s.writeShort(v.toInt()) }) +val IntValueCodec = StreamCodec(DataInputStream::readInt, 4L, DataOutputStream::writeInt) +val LongValueCodec = StreamCodec(DataInputStream::readLong, 8L, DataOutputStream::writeLong) +val FloatValueCodec = StreamCodec(DataInputStream::readFloat, 4L, DataOutputStream::writeFloat) +val DoubleValueCodec = StreamCodec(DataInputStream::readDouble, 8L, DataOutputStream::writeDouble) val ItemStackValueCodec = StreamCodec(DataInputStream::readItem, DataOutputStream::writeItem, ItemStack::copy) { a, b -> a.equals(b, true) } val ImpreciseFractionValueCodec = StreamCodec(DataInputStream::readDecimal, DataOutputStream::writeDecimal) val BigDecimalValueCodec = StreamCodec(DataInputStream::readBigDecimal, DataOutputStream::writeBigDecimal) -val UUIDValueCodec = StreamCodec({ s -> UUID(s.readLong(), s.readLong()) }, { s, v -> s.writeLong(v.mostSignificantBits); s.writeLong(v.leastSignificantBits) }) +val UUIDValueCodec = StreamCodec({ s, a -> a.accountBytes(8L); UUID(s.readLong(), s.readLong()) }, { s, v -> s.writeLong(v.mostSignificantBits); s.writeLong(v.leastSignificantBits) }) val VarIntValueCodec = StreamCodec(DataInputStream::readVarIntLE, DataOutputStream::writeVarIntLE) val VarLongValueCodec = StreamCodec(DataInputStream::readVarLongLE, DataOutputStream::writeVarLongLE) val BinaryStringCodec = StreamCodec(DataInputStream::readBinaryString, DataOutputStream::writeBinaryString) @@ -79,13 +90,13 @@ class EnumValueCodec>(clazz: Class, val writeByIndices: Boole val clazz = searchClass(clazz) private val values = searchClass(clazz).enumConstants!! - override fun read(stream: DataInputStream): V { + override fun read(stream: DataInputStream, sizeLimit: NbtAccounter): V { if (writeByIndices) { - val id = stream.readVarIntLE() + val id = stream.readVarIntLE(sizeLimit) return values.getOrNull(id) ?: throw NoSuchElementException("No such enum with index $id") } - val id = stream.readBinaryString() + val id = stream.readBinaryString(sizeLimit) return values.firstOrNull { id == it.name } ?: throw NoSuchElementException("No such enum $id") } @@ -127,7 +138,7 @@ class EnumValueCodec>(clazz: Class, val writeByIndices: Boole } fun OutputStream.writeInt(value: Int) { - if (this is DataOutputStream) { + if (this is DataOutput) { writeInt(value) return } @@ -139,7 +150,7 @@ fun OutputStream.writeInt(value: Int) { } fun InputStream.readInt(): Int { - if (this is DataInputStream) { + if (this is DataInput) { return readInt() } @@ -147,7 +158,7 @@ fun InputStream.readInt(): Int { } fun OutputStream.writeLong(value: Long) { - if (this is DataOutputStream) { + if (this is DataOutput) { writeLong(value) return } @@ -162,7 +173,7 @@ fun OutputStream.writeLong(value: Long) { } fun InputStream.readLong(): Long { - if (this is DataInputStream) { + if (this is DataInput) { return readLong() } @@ -180,7 +191,8 @@ fun InputStream.readFloat() = Float.fromBits(readInt()) fun OutputStream.writeDouble(value: Double) = writeLong(value.toBits()) fun InputStream.readDouble() = Double.fromBits(readLong()) -fun InputStream.readVarIntLE(): Int { +fun InputStream.readVarIntLE(sizeLimit: NbtAccounter? = null): Int { + sizeLimit?.accountBytes(1L) val readFirst = read() if (readFirst < 0) { @@ -198,6 +210,7 @@ fun InputStream.readVarIntLE(): Int { while (nextBit != 0) { result = result or (read shl i) + sizeLimit?.accountBytes(1L) read = read() if (read < 0) { @@ -227,7 +240,9 @@ fun OutputStream.writeVarIntLE(value: Int) { } } -fun InputStream.readVarLongLE(): Long { +fun InputStream.readVarLongLE(sizeLimit: NbtAccounter? = null): Long { + sizeLimit?.accountBytes(1L) + val readFirst = read() if (readFirst < 0) { @@ -245,6 +260,7 @@ fun InputStream.readVarLongLE(): Long { while (nextBit != 0) { result = result or (read shl i).toLong() + sizeLimit?.accountBytes(1L) read = read() if (read < 0) { @@ -274,8 +290,10 @@ fun OutputStream.writeVarLongLE(value: Long) { } } -fun InputStream.readBinaryString(): String { +fun InputStream.readBinaryString(sizeLimit: NbtAccounter = NbtAccounter(1L shl 18 /* 256 KiB */)): String { val size = readVarIntLE() + require(size >= 0) { "Negative payload size: $size" } + sizeLimit.accountBytes(size.toLong()) val bytes = ByteArray(size) read(bytes) return bytes.decodeToString() @@ -292,8 +310,8 @@ private data class IndexedStreamCodec( val id: Int, val codec: StreamCodec ) { - fun read(stream: DataInputStream): T { - return codec.read(stream) + fun read(stream: DataInputStream, sizeLimit: NbtAccounter = NbtAccounter(1L shl 18 /* 256 KiB */)): T { + return codec.read(stream, sizeLimit) } fun write(stream: DataOutputStream, value: Any?) { @@ -346,12 +364,13 @@ fun DataOutputStream.writeType(value: Any?) { /** * Read arbitrary data from this stream, in exploit-free way */ -fun DataInputStream.readType(): Any? { +fun DataInputStream.readType(sizeLimit: NbtAccounter = NbtAccounter(1L shl 18 /* 256 KiB */)): Any? { + sizeLimit.accountBytes(1L) val id = read() if (id >= codecs.size) { throw IndexOutOfBoundsException("No codec for network type $id") } - return codecs[id].read(this) + return codecs[id].read(this, sizeLimit) } diff --git a/src/main/kotlin/ru/dbotthepony/mc/otm/core/util/FriendlyStreams.kt b/src/main/kotlin/ru/dbotthepony/mc/otm/core/util/FriendlyStreams.kt index 01a6dc749..cb67e89bd 100644 --- a/src/main/kotlin/ru/dbotthepony/mc/otm/core/util/FriendlyStreams.kt +++ b/src/main/kotlin/ru/dbotthepony/mc/otm/core/util/FriendlyStreams.kt @@ -18,6 +18,7 @@ import net.minecraft.world.item.Item import net.minecraft.world.item.ItemStack import net.minecraftforge.registries.ForgeRegistries import net.minecraftforge.registries.ForgeRegistry +import org.apache.commons.lang3.mutable.MutableInt import java.io.* import java.math.BigDecimal import java.math.BigInteger @@ -36,7 +37,7 @@ fun OutputStream.writeNbt(value: CompoundTag) { } } -fun InputStream.readNbt(accounter: NbtAccounter = NbtAccounter.UNLIMITED): CompoundTag { +fun InputStream.readNbt(accounter: NbtAccounter = NbtAccounter(1L shl 18 /* 256 KiB */)): CompoundTag { return try { NbtIo.read(if (this is DataInputStream) this else DataInputStream(this), accounter) } catch (ioexception: IOException) { @@ -68,17 +69,19 @@ fun OutputStream.writeItem(itemStack: ItemStack, limitedTag: Boolean = true) { } } -fun InputStream.readItem(): ItemStack { +fun InputStream.readItem(sizeLimit: NbtAccounter = NbtAccounter(1L shl 18 /* 256 KiB */)): ItemStack { + sizeLimit.accountBytes(1L) + if (read() == 0) { return ItemStack.EMPTY } - + sizeLimit.accountBytes(9L) val item = (ForgeRegistries.ITEMS as ForgeRegistry).getValue(readInt()) val itemStack = ItemStack(item, readInt()) if (read() != 0) { - itemStack.readShareTag(readNbt()) + itemStack.readShareTag(readNbt(sizeLimit)) } return itemStack @@ -91,9 +94,11 @@ fun OutputStream.writeBigDecimal(value: BigDecimal) { write(bytes) } -fun InputStream.readBigDecimal(): BigDecimal { +fun InputStream.readBigDecimal(sizeLimit: NbtAccounter = NbtAccounter(512L)): BigDecimal { val scale = readInt() - val size = readVarIntLE() + val size = readVarIntLE(sizeLimit) + require(size >= 0) { "Negative payload size: $size" } + sizeLimit.accountBytes(size.toLong() + 4L) val bytes = ByteArray(size) read(bytes) return BigDecimal(BigInteger(bytes), scale) @@ -180,25 +185,33 @@ fun OutputStream.writeJson(element: JsonElement) { * * just copy pasted this code from my another project because i was lazy */ -fun InputStream.readJson(): JsonElement { +fun InputStream.readJson(sizeLimit: NbtAccounter = NbtAccounter(1L shl 18 /* 256 KiB */)): JsonElement { + sizeLimit.accountBytes(1L) + return when (val id = read()) { TYPE_NULL -> JsonNull.INSTANCE - TYPE_DOUBLE -> JsonPrimitive(readDouble()) - TYPE_BOOLEAN -> JsonPrimitive(read() > 1) - TYPE_INT -> JsonPrimitive(fixSignedInt(readVarLongLE())) - TYPE_STRING -> JsonPrimitive(readBinaryString()) + TYPE_DOUBLE -> { + sizeLimit.accountBytes(8L) + JsonPrimitive(readDouble()) + } + TYPE_BOOLEAN -> { + sizeLimit.accountBytes(1L) + JsonPrimitive(read() > 1) + } + TYPE_INT -> JsonPrimitive(fixSignedInt(readVarLongLE(sizeLimit))) + TYPE_STRING -> JsonPrimitive(readBinaryString(sizeLimit)) TYPE_ARRAY -> { - val values = readVarIntLE() + val values = readVarIntLE(sizeLimit) if (values == 0) return JsonArray() if (values < 0) throw JsonSyntaxException("Tried to read json array with $values elements in it") val build = JsonArray(values) - for (i in 0 until values) build.add(readJson()) + for (i in 0 until values) build.add(readJson(sizeLimit)) return build } TYPE_OBJECT -> { - val values = readVarIntLE() + val values = readVarIntLE(sizeLimit) if (values == 0) return JsonObject() if (values < 0) throw JsonSyntaxException("Tried to read json object with $values elements in it") @@ -208,13 +221,13 @@ fun InputStream.readJson(): JsonElement { val key: String try { - key = readBinaryString() + key = readBinaryString(sizeLimit) } catch(err: Throwable) { throw JsonSyntaxException("Reading json object at $i", err) } try { - build.add(key, readJson()) + build.add(key, readJson(sizeLimit)) } catch(err: Throwable) { throw JsonSyntaxException("Reading json object at $i with name $key", err) } @@ -226,8 +239,8 @@ fun InputStream.readJson(): JsonElement { } } -fun FriendlyByteBuf.readJson(): JsonElement { - return ByteBufInputStream(this).readJson() +fun FriendlyByteBuf.readJson(sizeLimit: NbtAccounter = NbtAccounter(1L shl 18)): JsonElement { + return ByteBufInputStream(this).readJson(sizeLimit) } fun FriendlyByteBuf.writeJson(value: JsonElement) { diff --git a/src/main/kotlin/ru/dbotthepony/mc/otm/data/SerializedFunctionRegistry.kt b/src/main/kotlin/ru/dbotthepony/mc/otm/data/SerializedFunctionRegistry.kt index 548d43200..0a916f07b 100644 --- a/src/main/kotlin/ru/dbotthepony/mc/otm/data/SerializedFunctionRegistry.kt +++ b/src/main/kotlin/ru/dbotthepony/mc/otm/data/SerializedFunctionRegistry.kt @@ -17,6 +17,7 @@ import io.netty.buffer.ByteBufInputStream import io.netty.buffer.ByteBufOutputStream import it.unimi.dsi.fastutil.io.FastByteArrayInputStream import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import net.minecraft.nbt.NbtAccounter import net.minecraft.network.FriendlyByteBuf import net.minecraft.resources.ResourceLocation import ru.dbotthepony.mc.otm.core.util.readType @@ -175,9 +176,10 @@ class SerializedFunctionRegistry(val gson: Gson = Gson()) : JsonSerializer val stream = DataInputStream(ByteBufInputStream(buff)) val arguments = LinkedList() + val sizeLimit = NbtAccounter(1L shl 18 /* 256 KiB */) - for (i in 0 until stream.readVarIntLE()) { - arguments.add(stream.readType()) + for (i in 0 until stream.readVarIntLE(sizeLimit)) { + arguments.add(stream.readType(sizeLimit)) } return map[id]?.bind(arguments) @@ -192,9 +194,10 @@ class SerializedFunctionRegistry(val gson: Gson = Gson()) : JsonSerializer val argumentString = value["arguments"]?.asString ?: return null val stream = DataInputStream(FastByteArrayInputStream(Base64.getDecoder().decode(argumentString))) val arguments = LinkedList() + val sizeLimit = NbtAccounter(1L shl 18 /* 256 KiB */) - for (i in 0 until stream.readVarIntLE()) { - arguments.add(stream.readType()) + for (i in 0 until stream.readVarIntLE(sizeLimit)) { + arguments.add(stream.readType(sizeLimit)) } return map[ResourceLocation(id)]?.bind(arguments) diff --git a/src/main/kotlin/ru/dbotthepony/mc/otm/network/FieldSynchronizer.kt b/src/main/kotlin/ru/dbotthepony/mc/otm/network/FieldSynchronizer.kt index d3b1b5cd7..af8e37281 100644 --- a/src/main/kotlin/ru/dbotthepony/mc/otm/network/FieldSynchronizer.kt +++ b/src/main/kotlin/ru/dbotthepony/mc/otm/network/FieldSynchronizer.kt @@ -8,6 +8,7 @@ import it.unimi.dsi.fastutil.objects.ObjectArraySet import it.unimi.dsi.fastutil.objects.Reference2ObjectFunction import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap import it.unimi.dsi.fastutil.objects.ReferenceArraySet +import net.minecraft.nbt.NbtAccounter import net.minecraft.world.item.ItemStack import org.apache.logging.log4j.LogManager import ru.dbotthepony.mc.otm.core.* @@ -1111,15 +1112,15 @@ class FieldSynchronizer(private val callback: Runnable, private val alwaysCallCa private val missingFields = ObjectArraySet() private val missingFieldsMap = Int2ObjectArrayMap() - fun applyNetworkPayload(stream: InputStream): Int { + fun applyNetworkPayload(stream: InputStream, sizeLimit: NbtAccounter = NbtAccounter(1L shl 21 /* 2 MiB */)): Int { if (stream.read() > 0) { idToField.clear() missingFieldsMap.clear() - var fieldId = stream.readVarIntLE() + var fieldId = stream.readVarIntLE(sizeLimit) while (fieldId != 0) { - val size = stream.readVarIntLE() + val size = stream.readVarIntLE(sizeLimit) val nameBytes = ByteArray(size) stream.read(nameBytes) val name = String(nameBytes, Charsets.UTF_8) @@ -1137,27 +1138,29 @@ class FieldSynchronizer(private val callback: Runnable, private val alwaysCallCa findField.id = fieldId } - fieldId = stream.readVarIntLE() + fieldId = stream.readVarIntLE(sizeLimit) } } - var fieldId = stream.readVarIntLE() + var fieldId = stream.readVarIntLE(sizeLimit) var i = 0 while (fieldId != 0) { val field = idToField[fieldId] - val payloadSize = stream.readVarIntLE() + val payloadSize = stream.readVarIntLE(sizeLimit) if (field == null) { LOGGER.error("Unable to read field $fieldId (${missingFieldsMap[fieldId]}) because we don't know anything about it! Skipping $payloadSize bytes", IllegalStateException("Unknown field $fieldId")) + sizeLimit.accountBytes(payloadSize.toLong()) stream.skipNBytes(payloadSize.toLong()) continue } + sizeLimit.accountBytes(payloadSize.toLong()) val bytes = ByteArray(payloadSize) stream.read(bytes) field.read(DataInputStream(FastByteArrayInputStream(bytes)), payloadSize) - fieldId = stream.readVarIntLE() + fieldId = stream.readVarIntLE(sizeLimit) i++ }