diff --git a/src/main/kotlin/ru/dbotthepony/mc/otm/core/math/Clustering.kt b/src/main/kotlin/ru/dbotthepony/mc/otm/core/math/Clustering.kt index 22b244067..a4faa615f 100644 --- a/src/main/kotlin/ru/dbotthepony/mc/otm/core/math/Clustering.kt +++ b/src/main/kotlin/ru/dbotthepony/mc/otm/core/math/Clustering.kt @@ -13,11 +13,11 @@ interface Cluster { } private class ClusterValue>(val value: V, var cluster: MutableCluster, var error: V) { - inline fun updateError(abs: (V) -> V, minus: (V, V) -> V) { + fun updateError(abs: (V) -> V, minus: (V, V) -> V) { error = abs(minus(cluster.center, value)) } - inline fun maybeSwitchCluster(cluster: MutableCluster, abs: (V) -> V, minus: (V, V) -> V): Boolean { + fun maybeSwitchCluster(cluster: MutableCluster, abs: (V) -> V, minus: (V, V) -> V): Boolean { if (cluster == this.cluster) return false val newError = abs(minus(cluster.center, value)) @@ -26,6 +26,8 @@ private class ClusterValue>(val value: V, var cluster: Mutable error = newError this.cluster.values.remove(this) cluster.values.add(this) + cluster.generation++ + this.cluster.generation++ this.cluster = cluster return true } else { @@ -34,12 +36,16 @@ private class ClusterValue>(val value: V, var cluster: Mutable } } -private class MutableCluster>(var center: V, expectedSize: Int) { - val values = ObjectArrayList>(min(100, expectedSize)) +private class MutableCluster>(var center: V) { + val values = ObjectArrayList>() + var generation = 0 + var centerCalculatedAtGeneration = -1 - inline fun calculateCenter(identity: V, plus: (V, V) -> V, minus: (V, V) -> V, divInt: (V, Int) -> V, abs: (V) -> V): Boolean { - if (values.isEmpty()) return false + fun calculateCenter(identity: V, plus: (V, V) -> V, minus: (V, V) -> V, divInt: (V, Int) -> V, abs: (V) -> V): Boolean { + if (values.isEmpty) return false + if (centerCalculatedAtGeneration == generation) return true + centerCalculatedAtGeneration = generation var value = identity values.forEach { value = plus(value, it.value) } val old = center @@ -51,9 +57,37 @@ private class MutableCluster>(var center: V, expectedSize: Int return true } + + fun pullFrom(clusters: MutableList>, identity: V): Boolean { + var candidate: ClusterValue? = null + + for (eCluser in clusters) { + if (eCluser.values.size > 1) { + for (value in eCluser.values) { + if (candidate == null || candidate.error < value.error) { + candidate = value + } + } + } + } + + if (candidate != null) { + values.add(candidate) + candidate.cluster.values.remove(candidate) + this.generation++ + candidate.cluster.generation++ + candidate.cluster = this + center = candidate.value + candidate.error = identity + clusters.add(this) + return true + } else { + return false + } + } } -private inline fun > Iterable.clusterize( +private fun > Iterable.clusterize( random: RandomGenerator, initialClusters: Int = 1, identity: V, @@ -63,49 +97,43 @@ private inline fun > Iterable.clusterize( abs: (V) -> V, heuristics: (min: V, max: V, error: V) -> Boolean, ): List> { + require(initialClusters > 0) { "Invalid amount of initial clusters: $initialClusters" } val itr = iterator() if (!itr.hasNext()) return listOf() - var expectedSize = 1 + val clusters = ObjectArrayList>(initialClusters) + + for (i in 0 until initialClusters) { + clusters.add(MutableCluster(identity)) + } + + val values = ObjectArrayList>() var min = itr.next() var max = min - while (itr.hasNext()) { - val value = itr.next() + values.add(ClusterValue(min, clusters.random(random), min)) + values[0].cluster.values.add(values[0]) + for (value in itr) { min = minOf(min, value) max = maxOf(max, value) - expectedSize++ + + val cluster = clusters.random(random) + val wrapped = ClusterValue(value, cluster, value) + values.add(wrapped) + cluster.values.add(wrapped) } if (min == max) { return listOf(Cluster.Impl(listOf(min), min)) } - var targetClusters = initialClusters - val clusters = ObjectArrayList>(initialClusters) - val values = ObjectArrayList>(expectedSize) - while (true) { - clusters.clear() - values.clear() - var converged = false var oversaturation = false - for (i in 0 until targetClusters) { - clusters.add(MutableCluster(identity, expectedSize)) - } - - for (value in this) { - val cluster = clusters.random(random) - val wrapped = ClusterValue(value, cluster, value) - values.add(wrapped) - cluster.values.add(wrapped) - } - clusters.forEach { it.calculateCenter(identity, plus, minus, divInt, abs) } while (!converged) { @@ -127,26 +155,7 @@ private inline fun > Iterable.clusterize( } for (cluster in emptyClusters) { - var candidate: ClusterValue? = null - - for (eCluser in clusters) { - if (eCluser.values.size > 1) { - for (value in eCluser.values) { - if (candidate == null || candidate.error < value.error) { - candidate = value - } - } - } - } - - if (candidate != null) { - cluster.values.add(candidate) - candidate.cluster.values.remove(candidate) - candidate.cluster = cluster - cluster.center = candidate.value - candidate.error = identity - clusters.add(cluster) - } else { + if (!cluster.pullFrom(clusters, identity)) { oversaturation = true break } @@ -156,8 +165,10 @@ private inline fun > Iterable.clusterize( val maxError = values.maxOf { it.error } - if (!oversaturation && targetClusters < values.size && heuristics(min, max, maxError)) { - targetClusters++ + if (!oversaturation && clusters.size < values.size && heuristics(min, max, maxError)) { + val cluster = MutableCluster(identity) + check(cluster.pullFrom(clusters, identity)) { "Newly created cluster couldn't pull first value from other clusters (expected at least one cluster with 2 values in it)" } + clusters.add(cluster) } else { return clusters.stream() .filter { it.values.isNotEmpty() }