KStarbound/src/main/kotlin/ru/dbotthepony/kstarbound/defs/MarkovTextGenerator.kt

85 lines
2.8 KiB
Kotlin

package ru.dbotthepony.kstarbound.defs
import com.google.common.collect.ImmutableList
import com.google.common.collect.ImmutableMap
import it.unimi.dsi.fastutil.objects.Object2ObjectArrayMap
import it.unimi.dsi.fastutil.objects.Object2ObjectFunction
import it.unimi.dsi.fastutil.objects.ObjectArraySet
import org.apache.logging.log4j.LogManager
import ru.dbotthepony.kstarbound.Globals
import ru.dbotthepony.kstarbound.json.builder.JsonFactory
import ru.dbotthepony.kstarbound.util.AssetPathStack
import ru.dbotthepony.kstarbound.util.random.random
import java.util.random.RandomGenerator
/**
 * Character-level Markov chain text generator built from a list of sample names.
 *
 * On construction every entry of [sourceNames] is lowercased and scanned with a sliding
 * window of [prefixSize] characters:
 *  - the first window of each name is recorded in [starts];
 *  - the last [endSize] characters of each name are recorded in [ends];
 *  - for every window that is followed by another character, that character is recorded
 *    in [chains] under the window's text.
 *
 * Duplicates are collapsed (sets are used while collecting), so each distinct start, end and
 * follower appears once regardless of how often it occurs in the samples.
 *
 * Only the primary constructor properties take part in the generated `equals`/`hashCode`/`copy`;
 * [ends], [starts] and [chains] are derived from them.
 *
 * @property name identifier of this generator; used only in the "too short" warning here
 * @property prefixSize length of the sliding window used as the key of [chains]; must be positive
 * @property endSize length of the suffix a generated string must end with; must be positive
 * @property sourceNames sample names the chain is learned from
 * @throws IllegalArgumentException if [prefixSize] or [endSize] is not positive
 */
@JsonFactory
data class MarkovTextGenerator(
    val name: String,
    val prefixSize: Int = 1,
    val endSize: Int = 1,
    val sourceNames: ImmutableList<String>,
) {
    /** Distinct suffixes (each [endSize] characters long) of the accepted sample names. */
    val ends: ImmutableList<String>

    /** Distinct leading windows (each [prefixSize] characters long) of the accepted sample names. */
    val starts: ImmutableList<String>

    /** Maps a [prefixSize]-character window to the distinct single characters observed right after it. */
    val chains: ImmutableMap<String, ImmutableList<String>>

    init {
        require(prefixSize > 0) { "Invalid prefix size: $prefixSize" }
        require(endSize > 0) { "Invalid suffix size: $endSize" }

        val collectedEnds = ObjectArraySet<String>()
        val collectedStarts = ObjectArraySet<String>()
        val collectedChains = Object2ObjectArrayMap<String, ObjectArraySet<String>>()

        for (rawName in sourceNames) {
            // All slicing below happens on the lowercased text, so the length guard must look at it
            // as well: case conversion is not guaranteed to preserve the number of chars.
            val lowered = rawName.lowercase()

            if (lowered.length < prefixSize || lowered.length < endSize) {
                LOGGER.warn("Name $rawName is too short for Markov name generator with prefix size of $prefixSize and suffix size $endSize; it will be ignored (generator: ${AssetPathStack.remap(name)})")
                continue
            }

            collectedEnds.add(lowered.takeLast(endSize))

            for (i in 0 .. lowered.length - prefixSize) {
                val prefix = lowered.substring(i, i + prefixSize)

                if (i == 0)
                    collectedStarts.add(prefix)

                // The final window has no following character and therefore contributes no link.
                if (i + prefixSize < lowered.length) {
                    collectedChains
                        .computeIfAbsent(prefix, Object2ObjectFunction { ObjectArraySet() })
                        .add(lowered[i + prefixSize].toString())
                }
            }
        }

        this.ends = ImmutableList.copyOf(collectedEnds)
        this.starts = ImmutableList.copyOf(collectedStarts)

        val frozenChains = ImmutableMap.builder<String, ImmutableList<String>>()

        for ((prefix, followers) in collectedChains)
            frozenChains.put(prefix, ImmutableList.copyOf(followers))

        this.chains = frozenChains.build()
    }

    /**
     * Generates a string by walking the chain from a randomly picked entry of [starts].
     *
     * A single attempt keeps appending a random follower of the last [prefixSize] characters
     * until the string is at least [targetLength] long *and* its last [endSize] characters are
     * one of [ends], or until the current tail has no entry in [chains] (dead end).
     *
     * An attempt is rejected and redone when the result is longer than [maxLength] or when
     * `piece in Globals.profanityFilter` holds. After the first attempt at most [maxTries]
     * retries are made; once they are used up the last attempt is returned as-is, even if it
     * would have been rejected.
     *
     * NOTE(review): the inner walk has no upper bound of its own; termination relies on the
     * chain eventually reaching a known ending or a dead end.
     * NOTE(review): [starts] is empty when no sample name was accepted; the behaviour of the
     * `random` extension on an empty list is not visible here — confirm before relying on it.
     *
     * @param random source of randomness for picking the start and every follower
     * @param targetLength minimum length the walk tries to reach before it may stop on an ending
     * @param maxLength longest acceptable result; defaults to [targetLength]
     * @param maxTries number of retries allowed after the first attempt
     * @return the generated (lowercase) string
     */
    fun generate(random: RandomGenerator, targetLength: Int, maxLength: Int = targetLength, maxTries: Int = 50): String {
        var tries = 0
        var piece: String

        do {
            piece = starts.random(random)

            // takeLast never throws on a string shorter than endSize; such a tail simply is not
            // a known ending (all entries of `ends` are exactly endSize long), so the walk goes on.
            while (piece.length < targetLength || piece.takeLast(endSize) !in ends) {
                // Chains are keyed by prefixSize-character windows, so the lookup key must have
                // that length as well (it previously used endSize and never matched when the
                // two sizes differed).
                val followers = chains[piece.takeLast(prefixSize)] ?: break
                piece += followers.random(random)
            }
        } while (tries++ < maxTries && (piece.length > maxLength || piece in Globals.profanityFilter))

        return piece
    }

    companion object {
        private val LOGGER = LogManager.getLogger()
    }
}