diff --git a/src/main/kotlin/ru/dbotthepony/mc/otm/FriendlyStreams.kt b/src/main/kotlin/ru/dbotthepony/mc/otm/FriendlyStreams.kt index 605cb976c..6954f0534 100644 --- a/src/main/kotlin/ru/dbotthepony/mc/otm/FriendlyStreams.kt +++ b/src/main/kotlin/ru/dbotthepony/mc/otm/FriendlyStreams.kt @@ -1,32 +1,83 @@ package ru.dbotthepony.mc.otm -import io.netty.buffer.ByteBufInputStream import io.netty.handler.codec.EncoderException import net.minecraft.nbt.CompoundTag import net.minecraft.nbt.NbtAccounter import net.minecraft.nbt.NbtIo -import java.io.DataInputStream -import java.io.DataOutputStream -import java.io.IOException -import java.io.InputStream -import java.io.OutputStream +import net.minecraft.world.item.Item +import net.minecraft.world.item.ItemStack +import net.minecraftforge.registries.ForgeRegistries +import net.minecraftforge.registries.ForgeRegistry +import java.io.* // But seriously, Mojang, why would you need to derive from ByteBuf directly, when you can implement // your own InputStream and OutputStream, since ByteBuf is meant to be operated on most time like a stream anyway? // netty ByteBuf -> netty ByteBufInputStream -> Minecraft FriendlyInputStream -fun T.writeNbt(value: CompoundTag): T { +fun OutputStream.writeInt(value: Int) { + if (this is DataOutputStream) { + writeInt(value) + return + } + + write(value ushr 24) + write(value ushr 16) + write(value ushr 8) + write(value) +} + +fun InputStream.readInt(): Int { + if (this is DataInputStream) { + return readInt() + } + + return (read() shl 24) or (read() shl 16) or (read() shl 8) or read() +} + +fun OutputStream.writeLong(value: Long) { + if (this is DataOutputStream) { + writeLong(value) + return + } + + write((value ushr 48).toInt()) + write((value ushr 40).toInt()) + write((value ushr 32).toInt()) + write((value ushr 24).toInt()) + write((value ushr 16).toInt()) + write((value ushr 8).toInt()) + write(value.toInt()) +} + +fun InputStream.readLong(): Long { + if (this is DataInputStream) { + return readLong() + } + + return (read().toLong() shl 48) or + (read().toLong() shl 40) or + (read().toLong() shl 32) or + (read().toLong() shl 24) or + (read().toLong() shl 16) or + (read().toLong() shl 8) or + read().toLong() +} + +fun OutputStream.writeFloat(value: Float) = writeInt(value.toBits()) +fun InputStream.readFloat() = Float.fromBits(readInt()) +fun OutputStream.writeDouble(value: Double) = writeLong(value.toBits()) +fun InputStream.readDouble() = Double.fromBits(readLong()) + +fun OutputStream.writeNbt(value: CompoundTag) { try { NbtIo.write(value, if (this is DataOutputStream) this else DataOutputStream(this)) } catch (ioexception: IOException) { throw EncoderException(ioexception) } - - return this } -fun T.readNbt(accounter: NbtAccounter = NbtAccounter.UNLIMITED): CompoundTag { +fun InputStream.readNbt(accounter: NbtAccounter = NbtAccounter.UNLIMITED): CompoundTag { return try { NbtIo.read(if (this is DataInputStream) this else DataInputStream(this), accounter) } catch (ioexception: IOException) { @@ -34,4 +85,82 @@ fun T.readNbt(accounter: NbtAccounter = NbtAccounter.UNLIMITED } } +fun OutputStream.writeItem(itemStack: ItemStack, limitedTag: Boolean = true) { + if (itemStack.isEmpty) { + write(0) + } else { + write(1) + val id = (ForgeRegistries.ITEMS as ForgeRegistry).getID(itemStack.item) + writeInt(id) + writeInt(itemStack.count) + + var compoundtag: CompoundTag? = null + + if (itemStack.item.isDamageable(itemStack) || itemStack.item.shouldOverrideMultiplayerNbt()) { + compoundtag = if (limitedTag) itemStack.shareTag else itemStack.tag + } + + write(if (compoundtag != null) 1 else 0) + + if (compoundtag != null) { + writeNbt(compoundtag) + } + } +} + +fun InputStream.readItem(): ItemStack { + if (read() == 0) { + return ItemStack.EMPTY + } + + + val item = (ForgeRegistries.ITEMS as ForgeRegistry).getValue(readInt()) + val itemStack = ItemStack(item, readInt()) + + if (read() != 0) { + itemStack.readShareTag(readNbt()) + } + + return itemStack +} + +fun InputStream.readVarIntBE(): Int { + var result = 0 + var read = read() + + do { + result = (result shl 7) or (read and 127) + read = read() + } while (read and 128 != 0) + + return result +} + +fun InputStream.readVarIntLE(): Int { + var result = 0 + var read = read() + var i = 0 + + while (read and 128 != 0) { + result = result or ((read and 127) shl i) + read = read() + i += 7 + } + + result = result or ((read and 127) shl i) + + return result +} + +fun OutputStream.writeVarIntLE(value: Int) { + require(value >= 0) { "Negative number provided: $value" } + var written = value + + while (written >= 128) { + write((written and 127) or 128) + written = written ushr 7 + } + + write(written) +} diff --git a/src/main/kotlin/ru/dbotthepony/mc/otm/core/ImpreciseFraction.kt b/src/main/kotlin/ru/dbotthepony/mc/otm/core/ImpreciseFraction.kt index 48633d0a9..fe3c159c2 100644 --- a/src/main/kotlin/ru/dbotthepony/mc/otm/core/ImpreciseFraction.kt +++ b/src/main/kotlin/ru/dbotthepony/mc/otm/core/ImpreciseFraction.kt @@ -4,6 +4,12 @@ import net.minecraft.nbt.ByteArrayTag import net.minecraft.nbt.StringTag import net.minecraft.nbt.Tag import net.minecraft.network.FriendlyByteBuf +import ru.dbotthepony.mc.otm.readDouble +import ru.dbotthepony.mc.otm.readVarIntLE +import ru.dbotthepony.mc.otm.writeDouble +import ru.dbotthepony.mc.otm.writeVarIntLE +import java.io.InputStream +import java.io.OutputStream import java.math.BigDecimal import java.math.BigInteger import java.math.MathContext @@ -641,3 +647,16 @@ class ImpreciseFraction @JvmOverloads constructor(whole: BigInteger, decimal: Do fun FriendlyByteBuf.readImpreciseFraction() = ImpreciseFraction.read(this) fun FriendlyByteBuf.writeImpreciseFraction(value: ImpreciseFraction) = value.write(this) + +fun InputStream.readImpreciseFraction(): ImpreciseFraction { + val bytes = ByteArray(readVarIntLE()) + read(bytes) + return ImpreciseFraction(BigInteger(bytes), readDouble()) +} + +fun OutputStream.writeImpreciseFraction(value: ImpreciseFraction) { + val bytes = value.whole.toByteArray() + writeVarIntLE(bytes.size) + write(bytes) + writeDouble(value.decimal) +} diff --git a/src/main/kotlin/ru/dbotthepony/mc/otm/network/FieldSynchronizer.kt b/src/main/kotlin/ru/dbotthepony/mc/otm/network/FieldSynchronizer.kt new file mode 100644 index 000000000..49d848438 --- /dev/null +++ b/src/main/kotlin/ru/dbotthepony/mc/otm/network/FieldSynchronizer.kt @@ -0,0 +1,234 @@ +package ru.dbotthepony.mc.otm.network + +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import net.minecraft.world.item.ItemStack +import ru.dbotthepony.mc.otm.core.ImpreciseFraction +import ru.dbotthepony.mc.otm.core.readImpreciseFraction +import ru.dbotthepony.mc.otm.core.writeImpreciseFraction +import ru.dbotthepony.mc.otm.readItem +import ru.dbotthepony.mc.otm.writeItem +import java.io.DataInputStream +import java.io.DataOutputStream +import java.io.InputStream +import kotlin.properties.ReadWriteProperty +import kotlin.reflect.KProperty + +fun interface FieldReader { + fun invoke(): V +} + +fun interface FieldGetter { + fun invoke(read: FieldReader): V +} + +fun interface FieldWriter { + fun invoke(value: V) +} + +fun interface FieldSetter { + fun invoke(value: V, write: FieldWriter, setByRemote: Boolean): V +} + +class NetworkValueCodec( + val reader: (stream: DataInputStream) -> V, + val writer: (stream: DataOutputStream, value: V) -> Unit, + val copier: ((value: V) -> V) = { it }, + val comparator: ((a: V, b: V) -> Boolean) = { a, b -> a == b } +) + +val BooleanValueCodec = NetworkValueCodec(DataInputStream::readBoolean, DataOutputStream::writeBoolean) +val IntValueCodec = NetworkValueCodec(DataInputStream::readInt, DataOutputStream::writeInt) +val ItemStackValueCodec = NetworkValueCodec(DataInputStream::readItem, DataOutputStream::writeItem, ItemStack::copy, ItemStack::isSameItemSameTags) +val ImpreciseFractionValueCodec = NetworkValueCodec(DataInputStream::readImpreciseFraction, DataOutputStream::writeImpreciseFraction) + +class FieldSynchronizer { + private val fields = ArrayList>() + private val observers = ArrayList>() + private val dirtyFields = ArrayList>() + + fun bool( + value: Boolean = false, + getter: FieldGetter? = null, + setter: FieldSetter? = null, + ): Field { + return Field(value, BooleanValueCodec, getter, setter) + } + + fun int( + value: Int = 0, + getter: FieldGetter? = null, + setter: FieldSetter? = null, + ): Field { + return Field(value, IntValueCodec, getter, setter) + } + + fun fraction( + value: ImpreciseFraction = ImpreciseFraction.ZERO, + getter: FieldGetter? = null, + setter: FieldSetter? = null, + ): Field { + return Field(value, ImpreciseFractionValueCodec, getter, setter) + } + + fun item( + value: ItemStack = ItemStack.EMPTY, + getter: FieldGetter? = null, + setter: FieldSetter? = null, + ): Field { + return Field(value, ItemStackValueCodec, getter, setter, isObserver = true) + } + + inner class Field( + private var value: V, + private val dispatcher: NetworkValueCodec, + private val getter: FieldGetter? = null, + private val setter: FieldSetter? = null, + isObserver: Boolean = false, + ) : ReadWriteProperty { + private var remote: V = dispatcher.copier.invoke(value) + + val id = fields.size + 1 + + init { + fields.add(this) + + if (isObserver) { + observers.add(this) + } + } + + private var isDirty = false + + private val write: FieldWriter = FieldWriter { + if (!isDirty && !dispatcher.comparator.invoke(remote, value)) { + dirtyFields.add(this) + isDirty = true + } + + this.value = it + } + + private val read: FieldReader = FieldReader { + return@FieldReader this.value + } + + fun observe() { + if (!isDirty && !dispatcher.comparator.invoke(remote, value)) { + dirtyFields.add(this) + isDirty = true + } + } + + fun getValue(): V { + val getter = this.getter + + if (getter != null) { + return getter.invoke(read) + } + + return value + } + + override fun getValue(thisRef: Any, property: KProperty<*>): V { + return getValue() + } + + fun setValue(value: V) { + val setter = this.setter + + if (setter != null) { + setter.invoke(value, write, false) + return + } + + if (this.value == value) { + return + } + + if (!isDirty && !dispatcher.comparator.invoke(remote, value)) { + dirtyFields.add(this) + isDirty = true + } + + this.value = value + } + + override fun setValue(thisRef: Any, property: KProperty<*>, value: V) { + setValue(value) + } + + fun markDirty() { + if (!isDirty) { + dirtyFields.add(this) + isDirty = true + } + } + + fun write(stream: DataOutputStream) { + stream.write(id) + dispatcher.writer.invoke(stream, value) + isDirty = false + remote = dispatcher.copier.invoke(value) + } + + fun read(stream: DataInputStream) { + val value = dispatcher.reader.invoke(stream) + val setter = this.setter + + if (setter != null) { + setter.invoke(value, write, true) + return + } + + this.value = value + } + } + + fun invalidate() { + for (field in fields) { + field.markDirty() + } + } + + fun collectNetworkPayload(): FastByteArrayOutputStream? { + if (observers.isNotEmpty()) { + for (field in observers) { + field.observe() + } + } + + if (dirtyFields.isEmpty()) { + return null + } + + val stream = FastByteArrayOutputStream() + val dataStream = DataOutputStream(stream) + + for (field in dirtyFields) { + field.write(dataStream) + } + + dirtyFields.clear() + dataStream.write(0) + + return stream + } + + fun applyNetworkPayload(stream: DataInputStream): Int { + var fieldId = stream.read() + var i = 0 + + while (fieldId != 0) { + val field = fields.getOrNull(fieldId - 1) ?: throw IndexOutOfBoundsException("Invalid field id $fieldId") + field.read(stream) + fieldId = stream.read() + i++ + } + + return i + } + + fun applyNetworkPayload(stream: InputStream): Int { + return applyNetworkPayload(DataInputStream(stream)) + } +} diff --git a/src/test/kotlin/ru/dbotthepony/mc/otm/tests/FieldSynchronizerTests.kt b/src/test/kotlin/ru/dbotthepony/mc/otm/tests/FieldSynchronizerTests.kt new file mode 100644 index 000000000..91f47a8a3 --- /dev/null +++ b/src/test/kotlin/ru/dbotthepony/mc/otm/tests/FieldSynchronizerTests.kt @@ -0,0 +1,71 @@ +package ru.dbotthepony.mc.otm.tests + +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.DisplayName +import org.junit.jupiter.api.Test +import ru.dbotthepony.mc.otm.network.FieldSynchronizer +import java.io.ByteArrayInputStream + +object FieldSynchronizerTests { + @Test + @DisplayName("Field Synchronizer full read/write test") + fun test() { + val a = FieldSynchronizer() + val b = FieldSynchronizer() + + val boolA = a.bool() + val boolB = b.bool() + + val intA = a.int() + val intB = b.int() + + val intA2 = a.int() + val intB2 = b.int() + + val intA3 = a.int() + val intB3 = b.int() + + boolA.setValue(true) + intA.setValue(8384) + intA2.setValue(348488) + intA3.setValue(-4) + + b.applyNetworkPayload(ByteArrayInputStream(a.collectNetworkPayload()!!.array)) + + assertEquals(boolA.getValue(), boolB.getValue()) + assertEquals(intA.getValue(), intB.getValue()) + assertEquals(intA2.getValue(), intB2.getValue()) + assertEquals(intA3.getValue(), intB3.getValue()) + } + + @Test + @DisplayName("Field Synchronizer partial read/write test") + fun testPartial() { + val a = FieldSynchronizer() + val b = FieldSynchronizer() + + val boolA = a.bool() + val boolB = b.bool() + + val intA = a.int() + val intB = b.int() + + val intA2 = a.int() + val intB2 = b.int() + + val intA3 = a.int() + val intB3 = b.int() + + boolA.setValue(true) + //intA.setValue(8384) + //intA2.setValue(348488) + intA3.setValue(-4) + + b.applyNetworkPayload(ByteArrayInputStream(a.collectNetworkPayload()!!.array)) + + assertEquals(boolA.getValue(), boolB.getValue()) + assertEquals(intA.getValue(), intB.getValue()) + assertEquals(intA2.getValue(), intB2.getValue()) + assertEquals(intA3.getValue(), intB3.getValue()) + } +} diff --git a/src/test/kotlin/ru/dbotthepony/mc/otm/tests/FriendlyStreams.kt b/src/test/kotlin/ru/dbotthepony/mc/otm/tests/FriendlyStreams.kt new file mode 100644 index 000000000..1d5fcc692 --- /dev/null +++ b/src/test/kotlin/ru/dbotthepony/mc/otm/tests/FriendlyStreams.kt @@ -0,0 +1,53 @@ +package ru.dbotthepony.mc.otm.tests + +import it.unimi.dsi.fastutil.io.FastByteArrayInputStream +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.DisplayName +import org.junit.jupiter.api.Test +import ru.dbotthepony.mc.otm.readInt +import ru.dbotthepony.mc.otm.readVarIntLE +import ru.dbotthepony.mc.otm.writeInt +import ru.dbotthepony.mc.otm.writeVarIntLE + +object FriendlyStreams { + @Test + @DisplayName("Stream extension functions") + fun test() { + val output = FastByteArrayOutputStream() + + output.writeInt(4) + output.writeInt(16) + output.writeInt(-1) + output.writeInt(1000000) + + output.writeVarIntLE(0) + output.writeVarIntLE(1) + output.writeVarIntLE(4) + output.writeVarIntLE(15) + output.writeVarIntLE(16) + output.writeVarIntLE(127) + output.writeVarIntLE(128) + output.writeVarIntLE(129) + output.writeVarIntLE(10023) + output.writeVarIntLE(100000) + + val input = FastByteArrayInputStream(output.array, 0, output.length) + + assertEquals(input.readInt(), 4) + assertEquals(input.readInt(), 16) + assertEquals(input.readInt(), -1) + assertEquals(input.readInt(), 1000000) + + assertEquals(input.readVarIntLE(), 0) + assertEquals(input.readVarIntLE(), 1) + assertEquals(input.readVarIntLE(), 4) + assertEquals(input.readVarIntLE(), 15) + assertEquals(input.readVarIntLE(), 16) + assertEquals(input.readVarIntLE(), 127) + assertEquals(input.readVarIntLE(), 128) + assertEquals(input.readVarIntLE(), 129) + assertEquals(input.readVarIntLE(), 10023) + assertEquals(input.readVarIntLE(), 100000) + } +} \ No newline at end of file