diff --git a/src/main/kotlin/ru/dbotthepony/mc/otm/network/FieldSynchronizer.kt b/src/main/kotlin/ru/dbotthepony/mc/otm/network/FieldSynchronizer.kt index 0c122db40..d4dcc6cf9 100644 --- a/src/main/kotlin/ru/dbotthepony/mc/otm/network/FieldSynchronizer.kt +++ b/src/main/kotlin/ru/dbotthepony/mc/otm/network/FieldSynchronizer.kt @@ -1,13 +1,17 @@ package ru.dbotthepony.mc.otm.network import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import it.unimi.dsi.fastutil.objects.Reference2ObjectFunction +import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap import net.minecraft.world.item.ItemStack import ru.dbotthepony.mc.otm.core.* import java.io.DataInputStream import java.io.DataOutputStream import java.io.InputStream +import java.lang.ref.WeakReference import java.math.BigDecimal import java.util.* +import java.util.function.Consumer import kotlin.ConcurrentModificationException import kotlin.collections.ArrayList import kotlin.collections.HashMap @@ -32,9 +36,10 @@ fun interface FieldSetter { sealed interface IField : ReadOnlyProperty { fun observe() fun markDirty() + fun markDirty(endpoint: FieldSynchronizer.Endpoint) val value: V - fun write(stream: DataOutputStream) + fun write(stream: DataOutputStream, endpoint: FieldSynchronizer.Endpoint) fun read(stream: DataInputStream) override fun getValue(thisRef: Any, property: KProperty<*>): V { @@ -66,8 +71,7 @@ enum class MapAction { class FieldSynchronizer { private val fields = ArrayList>() - private val observers = ArrayList>() - private val dirtyFields = ArrayList>() + private val observers = LinkedList>() fun byte( value: Byte = 0, @@ -212,6 +216,104 @@ class FieldSynchronizer { ) } + private val endpoints = LinkedList>() + val defaultEndpoint = Endpoint() + + private var lastEndpointsCleanup = System.nanoTime() + + private fun notifyEndpoints(dirtyField: IField<*>) { + forEachEndpoint { + it.addDirtyField(dirtyField) + } + } + + private inline fun forEachEndpoint(execute: (Endpoint) -> Unit) { + lastEndpointsCleanup = System.nanoTime() + + synchronized(endpoints) { + val iterator = endpoints.listIterator() + + for (value in iterator) { + val endpoint = value.get() + + if (endpoint == null) { + iterator.remove() + } else { + execute.invoke(endpoint) + } + } + } + } + + inner class Endpoint { + init { + endpoints.addLast(WeakReference(this)) + + if (System.nanoTime() - lastEndpointsCleanup >= 60_000_000_000) { + lastEndpointsCleanup = System.nanoTime() + + synchronized(endpoints) { + val iterator = endpoints.listIterator() + + for (value in iterator) { + if (value.get() == null) { + iterator.remove() + } + } + } + } + } + + private val dirtyFields = LinkedList>() + private val mapBacklogs = Reference2ObjectOpenHashMap, LinkedList Unit>>>() + + init { + for (field in fields) { + field.markDirty(this) + } + } + + internal fun addDirtyField(field: IField<*>) { + if (field !in dirtyFields) { + dirtyFields.addLast(field) + } + } + + internal fun getMapBacklog(map: Map): LinkedList Unit>> { + return mapBacklogs.computeIfAbsent(map, Reference2ObjectFunction { + LinkedList() + }) + } + + fun collectNetworkPayload(): FastByteArrayOutputStream? { + if (dirtyFields.isEmpty()) { + return null + } + + val stream = FastByteArrayOutputStream() + val dataStream = DataOutputStream(stream) + + for (field in dirtyFields) { + field.write(dataStream, this) + } + + dirtyFields.clear() + dataStream.write(0) + + return stream + } + } + + private val boundEndpoints = WeakHashMap() + + fun computeEndpointFor(obj: Any): Endpoint { + return boundEndpoints.computeIfAbsent(obj) { Endpoint() } + } + + fun endpointFor(obj: Any): Endpoint? { + return boundEndpoints[obj] + } + inner class Field( private var field: V, private val codec: IStreamCodec, @@ -240,7 +342,7 @@ class FieldSynchronizer { override fun write(value: V) { if (!isDirty && !codec.compare(remote, value)) { - dirtyFields.add(this@Field) + notifyEndpoints(this@Field) isDirty = true } @@ -250,7 +352,7 @@ class FieldSynchronizer { override fun observe() { if (!isDirty && !codec.compare(remote, field)) { - dirtyFields.add(this) + notifyEndpoints(this@Field) isDirty = true } } @@ -278,7 +380,7 @@ class FieldSynchronizer { } if (!isDirty && !codec.compare(remote, value)) { - dirtyFields.add(this) + notifyEndpoints(this@Field) isDirty = true } @@ -287,12 +389,16 @@ class FieldSynchronizer { override fun markDirty() { if (!isDirty) { - dirtyFields.add(this) + notifyEndpoints(this@Field) isDirty = true } } - override fun write(stream: DataOutputStream) { + override fun markDirty(endpoint: Endpoint) { + endpoint.addDirtyField(this) + } + + override fun write(stream: DataOutputStream, endpoint: Endpoint) { stream.write(id) codec.write(stream, field) isDirty = false @@ -348,20 +454,23 @@ class FieldSynchronizer { override fun observe() { if (!isDirty && !codec.compare(remote, value)) { - dirtyFields.add(this) + notifyEndpoints(this) isDirty = true } } override fun markDirty() { if (!isDirty) { - dirtyFields.add(this) + notifyEndpoints(this) isDirty = true } } + override fun markDirty(endpoint: Endpoint) { + endpoint.addDirtyField(this) + } - override fun write(stream: DataOutputStream) { + override fun write(stream: DataOutputStream, endpoint: Endpoint) { stream.write(id) val value = value codec.write(stream, value) @@ -400,7 +509,26 @@ class FieldSynchronizer { observers.add(this) } - private val backlog = LinkedList<(DataOutputStream) -> Unit>() + private fun pushBacklog(key: Any?, value: (DataOutputStream) -> Unit) { + forEachEndpoint { + val list = it.getMapBacklog(this) + val iterator = list.listIterator() + + for (pair in iterator) { + if (pair.first == key) { + iterator.remove() + } + } + + list.addLast(key to value) + } + } + + private fun clearBacklog() { + forEachEndpoint { + it.getMapBacklog(this).clear() + } + } override fun observe() { if (isRemote) { @@ -416,7 +544,7 @@ class FieldSynchronizer { if (!valueCodec.compare(value, remoteValue)) { val valueCopy = valueCodec.copy(value) - backlog.add { + pushBacklog(key) { it.write(MapAction.ADD.ordinal + 1) keyCodec.write(it, key) valueCodec.write(it, valueCopy) @@ -425,7 +553,7 @@ class FieldSynchronizer { observingBackingMap[key] = valueCopy if (!isDirty) { - dirtyFields.add(this) + notifyEndpoints(this) isDirty = true } } @@ -438,25 +566,52 @@ class FieldSynchronizer { return } - if (!isDirty) { - dirtyFields.add(this) - isDirty = true + isDirty = true + val backlogs = LinkedList Unit>>>() + + forEachEndpoint { + it.addDirtyField(this) + val value = it.getMapBacklog(this) + backlogs.add(value) + value.clear() + value.add(null to ClearBacklogEntry) } - if (!sentAllValues) { - for ((key, value) in backingMap) { - val valueCopy = valueCodec.copy(value) + for ((key, value) in backingMap) { + val valueCopy = valueCodec.copy(value) - backlog.add { - it.write(MapAction.ADD.ordinal + 1) - keyCodec.write(it, key) - valueCodec.write(it, valueCopy) - } - - observingBackingMap?.put(key, valueCopy) + val action = { it: DataOutputStream -> + it.write(MapAction.ADD.ordinal + 1) + keyCodec.write(it, key) + valueCodec.write(it, valueCopy) } - sentAllValues = true + for (backlog in backlogs) { + backlog.add(key to action) + } + + observingBackingMap?.put(key, valueCopy) + } + } + + override fun markDirty(endpoint: Endpoint) { + if (isRemote) { + return + } + + val backlog = endpoint.getMapBacklog(this) + + backlog.clear() + backlog.add(null to ClearBacklogEntry) + + for ((key, value) in backingMap) { + val valueCopy = valueCodec.copy(value) + + backlog.add(key to { + it.write(MapAction.ADD.ordinal + 1) + keyCodec.write(it, key) + valueCodec.write(it, valueCopy) + }) } } @@ -466,12 +621,17 @@ class FieldSynchronizer { return } - backlog.clear() observingBackingMap?.clear() - backlog.add(ClearBacklogEntry) + + forEachEndpoint { + it.getMapBacklog(this@Map).also { + it.clear() + it.add(null to ClearBacklogEntry) + } + } if (!isDirty) { - dirtyFields.add(this@Map) + notifyEndpoints(this@Map) isDirty = true } } @@ -483,7 +643,7 @@ class FieldSynchronizer { val valueCopy = valueCodec.copy(value) - backlog.add { + pushBacklog(key) { it.write(MapAction.ADD.ordinal + 1) keyCodec.write(it, key) valueCodec.write(it, valueCopy) @@ -492,7 +652,7 @@ class FieldSynchronizer { observingBackingMap?.put(key, valueCopy) if (!isDirty) { - dirtyFields.add(this@Map) + notifyEndpoints(this@Map) isDirty = true } } @@ -504,7 +664,7 @@ class FieldSynchronizer { val keyCopy = keyCodec.copy(key) - backlog.add { + pushBacklog(key) { it.write(MapAction.REMOVE.ordinal + 1) keyCodec.write(it, keyCopy) } @@ -512,23 +672,24 @@ class FieldSynchronizer { observingBackingMap?.remove(key) if (!isDirty) { - dirtyFields.add(this@Map) + notifyEndpoints(this@Map) isDirty = true } } } - override fun write(stream: DataOutputStream) { + override fun write(stream: DataOutputStream, endpoint: Endpoint) { stream.write(id) sentAllValues = false isDirty = false - for (entry in backlog) { - entry.invoke(stream) - } + val iterator = endpoint.getMapBacklog(this).listIterator() - backlog.clear() + for (entry in iterator) { + entry.second.invoke(stream) + iterator.remove() + } stream.write(0) } @@ -536,7 +697,7 @@ class FieldSynchronizer { override fun read(stream: DataInputStream) { if (!isRemote) { isRemote = true - backlog.clear() + clearBacklog() observingBackingMap?.clear() } @@ -585,28 +746,20 @@ class FieldSynchronizer { } } - fun collectNetworkPayload(): FastByteArrayOutputStream? { + fun observe() { if (observers.isNotEmpty()) { for (field in observers) { field.observe() } } + } - if (dirtyFields.isEmpty()) { - return null - } - - val stream = FastByteArrayOutputStream() - val dataStream = DataOutputStream(stream) - - for (field in dirtyFields) { - field.write(dataStream) - } - - dirtyFields.clear() - dataStream.write(0) - - return stream + /** + * [defaultEndpoint]#collectNetworkPayload + */ + fun collectNetworkPayload(): FastByteArrayOutputStream? { + observe() + return defaultEndpoint.collectNetworkPayload() } fun applyNetworkPayload(stream: DataInputStream): Int { diff --git a/src/test/kotlin/ru/dbotthepony/mc/otm/tests/FieldSynchronizerTests.kt b/src/test/kotlin/ru/dbotthepony/mc/otm/tests/FieldSynchronizerTests.kt index 592b23956..0ed4494ab 100644 --- a/src/test/kotlin/ru/dbotthepony/mc/otm/tests/FieldSynchronizerTests.kt +++ b/src/test/kotlin/ru/dbotthepony/mc/otm/tests/FieldSynchronizerTests.kt @@ -68,4 +68,41 @@ object FieldSynchronizerTests { assertEquals(intA2.value, intB2.value) assertEquals(intA3.value, intB3.value) } + + @Test + @DisplayName("Field Synchronizer multiple endpoints") + fun multipleEndpoints() { + val a = FieldSynchronizer() + val b = FieldSynchronizer() + val c = FieldSynchronizer() + + val f1 = a.bool() + val f2 = a.bool() + val f3 = a.int() + val f4 = a.long() + + val aFields = listOf(f1, f2, f3, f4) + val bFields = listOf(b.bool(), b.bool(), b.int(), b.long()) + val cFields = listOf(c.bool(), c.bool(), c.int(), c.long()) + + val bEndpoint = a.Endpoint() + + f2.value = true + f3.value = -15 + + val cEndpoint = a.Endpoint() + + f4.value = 80L + + b.applyNetworkPayload(bEndpoint.collectNetworkPayload()!!.let { ByteArrayInputStream(it.array, 0, it.length) }) + c.applyNetworkPayload(cEndpoint.collectNetworkPayload()!!.let { ByteArrayInputStream(it.array, 0, it.length) }) + + for ((i, field) in bFields.withIndex()) { + assertEquals(aFields[i].value, field.value) + } + + for ((i, field) in cFields.withIndex()) { + assertEquals(aFields[i].value, field.value) + } + } }