Idiot proof network read methods

This commit is contained in:
DBotThePony 2023-02-14 22:49:57 +07:00
parent a7266ec01e
commit 513ef21926
Signed by: DBot
GPG Key ID: DCC23B5715498507
5 changed files with 100 additions and 58 deletions

View File

@ -2,6 +2,7 @@ package ru.dbotthepony.mc.otm.core.math
import net.minecraft.nbt.ByteArrayTag
import net.minecraft.nbt.CompoundTag
import net.minecraft.nbt.NbtAccounter
import net.minecraft.nbt.StringTag
import net.minecraft.nbt.Tag
import net.minecraft.network.FriendlyByteBuf
@ -859,8 +860,11 @@ class Decimal @JvmOverloads constructor(whole: BigInteger, decimal: Double = 0.0
fun FriendlyByteBuf.readDecimal() = Decimal.read(this)
fun FriendlyByteBuf.writeDecimal(value: Decimal) = value.write(this)
fun InputStream.readDecimal(): Decimal {
val bytes = ByteArray(readVarIntLE())
fun InputStream.readDecimal(sizeLimit: NbtAccounter = NbtAccounter(512L)): Decimal {
val size = readVarIntLE(sizeLimit)
require(size >= 0) { "Negative payload size: $size" }
sizeLimit.accountBytes(size.toLong() + 8L)
val bytes = ByteArray(size)
read(bytes)
return Decimal(BigInteger(bytes), readDouble())
}

View File

@ -1,9 +1,12 @@
package ru.dbotthepony.mc.otm.core.util
import net.minecraft.nbt.NbtAccounter
import net.minecraft.world.item.ItemStack
import ru.dbotthepony.mc.otm.core.math.readDecimal
import ru.dbotthepony.mc.otm.core.math.writeDecimal
import java.io.DataInput
import java.io.DataInputStream
import java.io.DataOutput
import java.io.DataOutputStream
import java.io.InputStream
import java.io.OutputStream
@ -18,7 +21,7 @@ import kotlin.math.absoluteValue
* Also provides [copy] and [compare] methods
*/
interface IStreamCodec<V> {
fun read(stream: DataInputStream): V
fun read(stream: DataInputStream, sizeLimit: NbtAccounter = NbtAccounter(1L shl 18 /* 256 KiB */)): V
fun write(stream: DataOutputStream, value: V)
/**
@ -37,13 +40,21 @@ interface IStreamCodec<V> {
}
class StreamCodec<V>(
private val reader: (stream: DataInputStream) -> V,
private val reader: (stream: DataInputStream, sizeLimit: NbtAccounter) -> V,
private val writer: (stream: DataOutputStream, value: V) -> Unit,
private val copier: ((value: V) -> V) = { it },
private val comparator: ((a: V, b: V) -> Boolean) = { a, b -> a == b }
) : IStreamCodec<V> {
override fun read(stream: DataInputStream): V {
return reader.invoke(stream)
constructor(
reader: (stream: DataInputStream) -> V,
payloadSize: Long,
writer: (stream: DataOutputStream, value: V) -> Unit,
copier: ((value: V) -> V) = { it },
comparator: ((a: V, b: V) -> Boolean) = { a, b -> a == b }
) : this({ stream, sizeLimit -> sizeLimit.accountBytes(payloadSize); reader.invoke(stream) }, writer, copier, comparator)
override fun read(stream: DataInputStream, sizeLimit: NbtAccounter): V {
return reader.invoke(stream, sizeLimit)
}
override fun write(stream: DataOutputStream, value: V) {
@ -59,18 +70,18 @@ class StreamCodec<V>(
}
}
val NullValueCodec = StreamCodec({ null }, { _, _ -> })
val BooleanValueCodec = StreamCodec(DataInputStream::readBoolean, DataOutputStream::writeBoolean)
val ByteValueCodec = StreamCodec(DataInputStream::readByte, { s, v -> s.writeByte(v.toInt()) })
val ShortValueCodec = StreamCodec(DataInputStream::readShort, { s, v -> s.writeShort(v.toInt()) })
val IntValueCodec = StreamCodec(DataInputStream::readInt, DataOutputStream::writeInt)
val LongValueCodec = StreamCodec(DataInputStream::readLong, DataOutputStream::writeLong)
val FloatValueCodec = StreamCodec(DataInputStream::readFloat, DataOutputStream::writeFloat)
val DoubleValueCodec = StreamCodec(DataInputStream::readDouble, DataOutputStream::writeDouble)
val NullValueCodec = StreamCodec({ _, _ -> null }, { _, _ -> })
val BooleanValueCodec = StreamCodec(DataInputStream::readBoolean, 1L, DataOutputStream::writeBoolean)
val ByteValueCodec = StreamCodec(DataInputStream::readByte, 1L, { s, v -> s.writeByte(v.toInt()) })
val ShortValueCodec = StreamCodec(DataInputStream::readShort, 2L, { s, v -> s.writeShort(v.toInt()) })
val IntValueCodec = StreamCodec(DataInputStream::readInt, 4L, DataOutputStream::writeInt)
val LongValueCodec = StreamCodec(DataInputStream::readLong, 8L, DataOutputStream::writeLong)
val FloatValueCodec = StreamCodec(DataInputStream::readFloat, 4L, DataOutputStream::writeFloat)
val DoubleValueCodec = StreamCodec(DataInputStream::readDouble, 8L, DataOutputStream::writeDouble)
val ItemStackValueCodec = StreamCodec(DataInputStream::readItem, DataOutputStream::writeItem, ItemStack::copy) { a, b -> a.equals(b, true) }
val ImpreciseFractionValueCodec = StreamCodec(DataInputStream::readDecimal, DataOutputStream::writeDecimal)
val BigDecimalValueCodec = StreamCodec(DataInputStream::readBigDecimal, DataOutputStream::writeBigDecimal)
val UUIDValueCodec = StreamCodec({ s -> UUID(s.readLong(), s.readLong()) }, { s, v -> s.writeLong(v.mostSignificantBits); s.writeLong(v.leastSignificantBits) })
val UUIDValueCodec = StreamCodec({ s, a -> a.accountBytes(8L); UUID(s.readLong(), s.readLong()) }, { s, v -> s.writeLong(v.mostSignificantBits); s.writeLong(v.leastSignificantBits) })
val VarIntValueCodec = StreamCodec(DataInputStream::readVarIntLE, DataOutputStream::writeVarIntLE)
val VarLongValueCodec = StreamCodec(DataInputStream::readVarLongLE, DataOutputStream::writeVarLongLE)
val BinaryStringCodec = StreamCodec(DataInputStream::readBinaryString, DataOutputStream::writeBinaryString)
@ -79,13 +90,13 @@ class EnumValueCodec<V : Enum<V>>(clazz: Class<out V>, val writeByIndices: Boole
val clazz = searchClass(clazz)
private val values = searchClass(clazz).enumConstants!!
override fun read(stream: DataInputStream): V {
override fun read(stream: DataInputStream, sizeLimit: NbtAccounter): V {
if (writeByIndices) {
val id = stream.readVarIntLE()
val id = stream.readVarIntLE(sizeLimit)
return values.getOrNull(id) ?: throw NoSuchElementException("No such enum with index $id")
}
val id = stream.readBinaryString()
val id = stream.readBinaryString(sizeLimit)
return values.firstOrNull { id == it.name } ?: throw NoSuchElementException("No such enum $id")
}
@ -127,7 +138,7 @@ class EnumValueCodec<V : Enum<V>>(clazz: Class<out V>, val writeByIndices: Boole
}
fun OutputStream.writeInt(value: Int) {
if (this is DataOutputStream) {
if (this is DataOutput) {
writeInt(value)
return
}
@ -139,7 +150,7 @@ fun OutputStream.writeInt(value: Int) {
}
fun InputStream.readInt(): Int {
if (this is DataInputStream) {
if (this is DataInput) {
return readInt()
}
@ -147,7 +158,7 @@ fun InputStream.readInt(): Int {
}
fun OutputStream.writeLong(value: Long) {
if (this is DataOutputStream) {
if (this is DataOutput) {
writeLong(value)
return
}
@ -162,7 +173,7 @@ fun OutputStream.writeLong(value: Long) {
}
fun InputStream.readLong(): Long {
if (this is DataInputStream) {
if (this is DataInput) {
return readLong()
}
@ -180,7 +191,8 @@ fun InputStream.readFloat() = Float.fromBits(readInt())
fun OutputStream.writeDouble(value: Double) = writeLong(value.toBits())
fun InputStream.readDouble() = Double.fromBits(readLong())
fun InputStream.readVarIntLE(): Int {
fun InputStream.readVarIntLE(sizeLimit: NbtAccounter? = null): Int {
sizeLimit?.accountBytes(1L)
val readFirst = read()
if (readFirst < 0) {
@ -198,6 +210,7 @@ fun InputStream.readVarIntLE(): Int {
while (nextBit != 0) {
result = result or (read shl i)
sizeLimit?.accountBytes(1L)
read = read()
if (read < 0) {
@ -227,7 +240,9 @@ fun OutputStream.writeVarIntLE(value: Int) {
}
}
fun InputStream.readVarLongLE(): Long {
fun InputStream.readVarLongLE(sizeLimit: NbtAccounter? = null): Long {
sizeLimit?.accountBytes(1L)
val readFirst = read()
if (readFirst < 0) {
@ -245,6 +260,7 @@ fun InputStream.readVarLongLE(): Long {
while (nextBit != 0) {
result = result or (read shl i).toLong()
sizeLimit?.accountBytes(1L)
read = read()
if (read < 0) {
@ -274,8 +290,10 @@ fun OutputStream.writeVarLongLE(value: Long) {
}
}
fun InputStream.readBinaryString(): String {
fun InputStream.readBinaryString(sizeLimit: NbtAccounter = NbtAccounter(1L shl 18 /* 256 KiB */)): String {
val size = readVarIntLE()
require(size >= 0) { "Negative payload size: $size" }
sizeLimit.accountBytes(size.toLong())
val bytes = ByteArray(size)
read(bytes)
return bytes.decodeToString()
@ -292,8 +310,8 @@ private data class IndexedStreamCodec<T>(
val id: Int,
val codec: StreamCodec<T>
) {
fun read(stream: DataInputStream): T {
return codec.read(stream)
fun read(stream: DataInputStream, sizeLimit: NbtAccounter = NbtAccounter(1L shl 18 /* 256 KiB */)): T {
return codec.read(stream, sizeLimit)
}
fun write(stream: DataOutputStream, value: Any?) {
@ -346,12 +364,13 @@ fun DataOutputStream.writeType(value: Any?) {
/**
* Read arbitrary data from this stream, in exploit-free way
*/
fun DataInputStream.readType(): Any? {
fun DataInputStream.readType(sizeLimit: NbtAccounter = NbtAccounter(1L shl 18 /* 256 KiB */)): Any? {
sizeLimit.accountBytes(1L)
val id = read()
if (id >= codecs.size) {
throw IndexOutOfBoundsException("No codec for network type $id")
}
return codecs[id].read(this)
return codecs[id].read(this, sizeLimit)
}

View File

@ -18,6 +18,7 @@ import net.minecraft.world.item.Item
import net.minecraft.world.item.ItemStack
import net.minecraftforge.registries.ForgeRegistries
import net.minecraftforge.registries.ForgeRegistry
import org.apache.commons.lang3.mutable.MutableInt
import java.io.*
import java.math.BigDecimal
import java.math.BigInteger
@ -36,7 +37,7 @@ fun OutputStream.writeNbt(value: CompoundTag) {
}
}
fun InputStream.readNbt(accounter: NbtAccounter = NbtAccounter.UNLIMITED): CompoundTag {
fun InputStream.readNbt(accounter: NbtAccounter = NbtAccounter(1L shl 18 /* 256 KiB */)): CompoundTag {
return try {
NbtIo.read(if (this is DataInputStream) this else DataInputStream(this), accounter)
} catch (ioexception: IOException) {
@ -68,17 +69,19 @@ fun OutputStream.writeItem(itemStack: ItemStack, limitedTag: Boolean = true) {
}
}
fun InputStream.readItem(): ItemStack {
fun InputStream.readItem(sizeLimit: NbtAccounter = NbtAccounter(1L shl 18 /* 256 KiB */)): ItemStack {
sizeLimit.accountBytes(1L)
if (read() == 0) {
return ItemStack.EMPTY
}
sizeLimit.accountBytes(9L)
val item = (ForgeRegistries.ITEMS as ForgeRegistry<Item>).getValue(readInt())
val itemStack = ItemStack(item, readInt())
if (read() != 0) {
itemStack.readShareTag(readNbt())
itemStack.readShareTag(readNbt(sizeLimit))
}
return itemStack
@ -91,9 +94,11 @@ fun OutputStream.writeBigDecimal(value: BigDecimal) {
write(bytes)
}
fun InputStream.readBigDecimal(): BigDecimal {
fun InputStream.readBigDecimal(sizeLimit: NbtAccounter = NbtAccounter(512L)): BigDecimal {
val scale = readInt()
val size = readVarIntLE()
val size = readVarIntLE(sizeLimit)
require(size >= 0) { "Negative payload size: $size" }
sizeLimit.accountBytes(size.toLong() + 4L)
val bytes = ByteArray(size)
read(bytes)
return BigDecimal(BigInteger(bytes), scale)
@ -180,25 +185,33 @@ fun OutputStream.writeJson(element: JsonElement) {
*
* just copy pasted this code from my another project because i was lazy
*/
fun InputStream.readJson(): JsonElement {
fun InputStream.readJson(sizeLimit: NbtAccounter = NbtAccounter(1L shl 18 /* 256 KiB */)): JsonElement {
sizeLimit.accountBytes(1L)
return when (val id = read()) {
TYPE_NULL -> JsonNull.INSTANCE
TYPE_DOUBLE -> JsonPrimitive(readDouble())
TYPE_BOOLEAN -> JsonPrimitive(read() > 1)
TYPE_INT -> JsonPrimitive(fixSignedInt(readVarLongLE()))
TYPE_STRING -> JsonPrimitive(readBinaryString())
TYPE_DOUBLE -> {
sizeLimit.accountBytes(8L)
JsonPrimitive(readDouble())
}
TYPE_BOOLEAN -> {
sizeLimit.accountBytes(1L)
JsonPrimitive(read() > 1)
}
TYPE_INT -> JsonPrimitive(fixSignedInt(readVarLongLE(sizeLimit)))
TYPE_STRING -> JsonPrimitive(readBinaryString(sizeLimit))
TYPE_ARRAY -> {
val values = readVarIntLE()
val values = readVarIntLE(sizeLimit)
if (values == 0) return JsonArray()
if (values < 0) throw JsonSyntaxException("Tried to read json array with $values elements in it")
val build = JsonArray(values)
for (i in 0 until values) build.add(readJson())
for (i in 0 until values) build.add(readJson(sizeLimit))
return build
}
TYPE_OBJECT -> {
val values = readVarIntLE()
val values = readVarIntLE(sizeLimit)
if (values == 0) return JsonObject()
if (values < 0) throw JsonSyntaxException("Tried to read json object with $values elements in it")
@ -208,13 +221,13 @@ fun InputStream.readJson(): JsonElement {
val key: String
try {
key = readBinaryString()
key = readBinaryString(sizeLimit)
} catch(err: Throwable) {
throw JsonSyntaxException("Reading json object at $i", err)
}
try {
build.add(key, readJson())
build.add(key, readJson(sizeLimit))
} catch(err: Throwable) {
throw JsonSyntaxException("Reading json object at $i with name $key", err)
}
@ -226,8 +239,8 @@ fun InputStream.readJson(): JsonElement {
}
}
fun FriendlyByteBuf.readJson(): JsonElement {
return ByteBufInputStream(this).readJson()
fun FriendlyByteBuf.readJson(sizeLimit: NbtAccounter = NbtAccounter(1L shl 18)): JsonElement {
return ByteBufInputStream(this).readJson(sizeLimit)
}
fun FriendlyByteBuf.writeJson(value: JsonElement) {

View File

@ -17,6 +17,7 @@ import io.netty.buffer.ByteBufInputStream
import io.netty.buffer.ByteBufOutputStream
import it.unimi.dsi.fastutil.io.FastByteArrayInputStream
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import net.minecraft.nbt.NbtAccounter
import net.minecraft.network.FriendlyByteBuf
import net.minecraft.resources.ResourceLocation
import ru.dbotthepony.mc.otm.core.util.readType
@ -175,9 +176,10 @@ class SerializedFunctionRegistry<R, T>(val gson: Gson = Gson()) : JsonSerializer
val stream = DataInputStream(ByteBufInputStream(buff))
val arguments = LinkedList<Any?>()
val sizeLimit = NbtAccounter(1L shl 18 /* 256 KiB */)
for (i in 0 until stream.readVarIntLE()) {
arguments.add(stream.readType())
for (i in 0 until stream.readVarIntLE(sizeLimit)) {
arguments.add(stream.readType(sizeLimit))
}
return map[id]?.bind(arguments)
@ -192,9 +194,10 @@ class SerializedFunctionRegistry<R, T>(val gson: Gson = Gson()) : JsonSerializer
val argumentString = value["arguments"]?.asString ?: return null
val stream = DataInputStream(FastByteArrayInputStream(Base64.getDecoder().decode(argumentString)))
val arguments = LinkedList<Any?>()
val sizeLimit = NbtAccounter(1L shl 18 /* 256 KiB */)
for (i in 0 until stream.readVarIntLE()) {
arguments.add(stream.readType())
for (i in 0 until stream.readVarIntLE(sizeLimit)) {
arguments.add(stream.readType(sizeLimit))
}
return map[ResourceLocation(id)]?.bind(arguments)

View File

@ -8,6 +8,7 @@ import it.unimi.dsi.fastutil.objects.ObjectArraySet
import it.unimi.dsi.fastutil.objects.Reference2ObjectFunction
import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap
import it.unimi.dsi.fastutil.objects.ReferenceArraySet
import net.minecraft.nbt.NbtAccounter
import net.minecraft.world.item.ItemStack
import org.apache.logging.log4j.LogManager
import ru.dbotthepony.mc.otm.core.*
@ -1111,15 +1112,15 @@ class FieldSynchronizer(private val callback: Runnable, private val alwaysCallCa
private val missingFields = ObjectArraySet<String>()
private val missingFieldsMap = Int2ObjectArrayMap<String>()
fun applyNetworkPayload(stream: InputStream): Int {
fun applyNetworkPayload(stream: InputStream, sizeLimit: NbtAccounter = NbtAccounter(1L shl 21 /* 2 MiB */)): Int {
if (stream.read() > 0) {
idToField.clear()
missingFieldsMap.clear()
var fieldId = stream.readVarIntLE()
var fieldId = stream.readVarIntLE(sizeLimit)
while (fieldId != 0) {
val size = stream.readVarIntLE()
val size = stream.readVarIntLE(sizeLimit)
val nameBytes = ByteArray(size)
stream.read(nameBytes)
val name = String(nameBytes, Charsets.UTF_8)
@ -1137,27 +1138,29 @@ class FieldSynchronizer(private val callback: Runnable, private val alwaysCallCa
findField.id = fieldId
}
fieldId = stream.readVarIntLE()
fieldId = stream.readVarIntLE(sizeLimit)
}
}
var fieldId = stream.readVarIntLE()
var fieldId = stream.readVarIntLE(sizeLimit)
var i = 0
while (fieldId != 0) {
val field = idToField[fieldId]
val payloadSize = stream.readVarIntLE()
val payloadSize = stream.readVarIntLE(sizeLimit)
if (field == null) {
LOGGER.error("Unable to read field $fieldId (${missingFieldsMap[fieldId]}) because we don't know anything about it! Skipping $payloadSize bytes", IllegalStateException("Unknown field $fieldId"))
sizeLimit.accountBytes(payloadSize.toLong())
stream.skipNBytes(payloadSize.toLong())
continue
}
sizeLimit.accountBytes(payloadSize.toLong())
val bytes = ByteArray(payloadSize)
stream.read(bytes)
field.read(DataInputStream(FastByteArrayInputStream(bytes)), payloadSize)
fieldId = stream.readVarIntLE()
fieldId = stream.readVarIntLE(sizeLimit)
i++
}