KStarbound/src/main/kotlin/ru/dbotthepony/kstarbound/server/ServerChannels.kt

200 lines
5.3 KiB
Kotlin

package ru.dbotthepony.kstarbound.server
import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.Channel
import io.netty.channel.ChannelFuture
import io.netty.channel.ChannelInitializer
import io.netty.channel.local.LocalAddress
import io.netty.channel.local.LocalServerChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet
import org.apache.logging.log4j.LogManager
import ru.dbotthepony.kstarbound.network.Connection
import ru.dbotthepony.kstarbound.network.ConnectionType
import ru.dbotthepony.kstarbound.network.IPacket
import ru.dbotthepony.kstarbound.network.packets.clientbound.ServerInfoPacket
import java.io.Closeable
import java.net.SocketAddress
import java.util.*
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
/**
 * Owns this server's listening channels (one in-memory listener plus any number of network listeners),
 * the [ServerConnection]s accepted on them, the connection-ID allocator and the player counter.
 */
class ServerChannels(val server: StarboundServer) : Closeable {
    // Listening channels created here; each entry removes itself when its channel closes.
    private val channels = CopyOnWriteArrayList<ChannelFuture>()

    /** Connections accepted on this instance's listeners; cleared by [close]. */
    val connections = CopyOnWriteArrayList<ServerConnection>()

    // Cached in-memory listener. The fast path of createLocalChannel reads it without holding [lock],
    // so it must be volatile for that read to observe the write made under the lock.
    @Volatile
    private var localChannel: Channel? = null

    // Guards listener creation, connection acceptance and the isClosed transition.
    private val lock = ReentrantLock()
    private var isClosed = false

    // Rolling counter feeding cycleConnectionID; guarded by connectionIDLock.
    private var nextConnectionID = 0
    private val occupiedConnectionIDs = IntAVLTreeSet()
    private val connectionIDLock = Any()

    private val playerCount = AtomicInteger()

    /** Increments the player count and broadcasts the new value to every ready connection. */
    fun incrementPlayerCount() {
        val count = playerCount.incrementAndGet()
        broadcast(ServerInfoPacket(count, server.settings.maxPlayers))
    }

    /**
     * Decrements the player count and broadcasts the new value to every ready connection.
     *
     * @throws IllegalStateException if the count is already zero; the counter is left unchanged in that case.
     */
    fun decrementPlayerCount() {
        // Throwing inside updateAndGet aborts the update, so an unbalanced call cannot leave the counter negative.
        val count = playerCount.updateAndGet { current ->
            check(current > 0) { "Player count turned negative" }
            current - 1
        }
        broadcast(ServerInfoPacket(count, server.settings.maxPlayers))
    }

    /** Returns the connection whose `connectionID` equals [id], or `null` if none matches. */
    fun connectionByID(id: Int): ServerConnection? {
        return connections.firstOrNull { it.connectionID == id }
    }

    /** Returns the connection whose `uuid` equals [id], or `null` if none matches. */
    fun connectionByUUID(id: UUID): ServerConnection? {
        return connections.firstOrNull { it.uuid == id }
    }

    // Advances the rolling counter and maps it into 1..MAX_PLAYERS. Masking with MAX_PLAYERS only works because
    // it is 2^15 - 1. A masked value of 0 is never returned: 1 is returned instead and the counter is bumped past it.
    // Caller must hold connectionIDLock.
    private fun cycleConnectionID(): Int {
        val v = ++nextConnectionID and MAX_PLAYERS
        if (v == 0) {
            nextConnectionID++
            return 1
        }
        return v
    }

    /**
     * Allocates a connection ID in `1..MAX_PLAYERS` that is not currently occupied and marks it occupied.
     * Release it with [freeConnectionID].
     *
     * @throws IllegalStateException if every ID is occupied.
     */
    fun nextConnectionID(): Int {
        synchronized(connectionIDLock) {
            // MAX_PLAYERS + 1 draws cover a full cycle of the counter, so every candidate ID is tried.
            repeat(MAX_PLAYERS + 1) {
                val id = cycleConnectionID()
                // add() returns false when the ID is already occupied.
                if (occupiedConnectionIDs.add(id)) return id
            }
        }

        error("No more free connection IDs, how did we end up here?")
    }

    /**
     * Returns [id] to the pool of free connection IDs.
     *
     * @return `true` if the ID was occupied, `false` otherwise.
     */
    fun freeConnectionID(id: Int): Boolean {
        return synchronized(connectionIDLock) {
            occupiedConnectionIDs.remove(id)
        }
    }

    /** Sends [packet] to every connection whose `isReady` is true, optionally flushing each one. */
    fun broadcast(packet: IPacket, flush: Boolean = true) {
        connections.forEach {
            if (it.isReady) {
                it.send(packet, flush)
            }
        }
    }

    // Child-channel handler shared by both listener kinds. It registers a ServerConnection of [type] for [ch],
    // or closes [ch] if this instance is shutting down or the connection could not be set up.
    private fun acceptConnection(ch: Channel, type: ConnectionType) {
        lock.withLock {
            if (isClosed) {
                // Otherwise the accepted socket would be left open with nothing servicing it.
                ch.close()
                return
            }

            LOGGER.info("Incoming connection from ${ch.remoteAddress()}")

            var connection: ServerConnection? = null

            try {
                connection = ServerConnection(server, type)
                connections.add(connection)
                connection.bind(ch)
            } catch (err: Throwable) {
                LOGGER.error("Error while accepting new connection from ${ch.remoteAddress()}", err)
                // Do not keep a connection whose bind failed; close() would otherwise wait on it.
                if (connection != null) connections.remove(connection)
                ch.close()
            }
        }
    }

    // Builds the ChannelInitializer that hands every accepted child channel to acceptConnection.
    private fun initializer(type: ConnectionType): ChannelInitializer<Channel> {
        return object : ChannelInitializer<Channel>() {
            override fun initChannel(ch: Channel) {
                acceptConnection(ch, type)
            }
        }
    }

    /**
     * Returns the in-memory ([LocalServerChannel]) listener, binding it on first use.
     * The same channel is returned until it closes, after which the next call binds a new one.
     *
     * @throws IllegalStateException if this instance has been closed and no listener is cached.
     */
    fun createLocalChannel(): Channel {
        // Lock-free fast path; localChannel is volatile.
        localChannel?.let { return it }

        lock.withLock {
            localChannel?.let { return it }
            check(!isClosed) { "ServerChannels is closed" }

            val future = ServerBootstrap()
                .channel(LocalServerChannel::class.java)
                .group(Connection.NIO_POOL)
                .childHandler(initializer(ConnectionType.MEMORY))
                .bind(LocalAddress.ANY)
                .syncUninterruptibly()

            val channel = future.channel()
            channels.add(future)
            localChannel = channel

            channel.closeFuture().addListener {
                channels.remove(future)
                // Only drop the cache entry if it still refers to this listener.
                if (localChannel === channel) localChannel = null
            }

            return channel
        }
    }

    /**
     * Binds a network ([NioServerSocketChannel]) listener on [localAddress] and returns it.
     *
     * @throws IllegalStateException if this instance has been closed.
     */
    fun createChannel(localAddress: SocketAddress): Channel {
        lock.withLock {
            check(!isClosed) { "ServerChannels is closed" }

            val future = ServerBootstrap()
                .channel(NioServerSocketChannel::class.java)
                .group(Connection.NIO_POOL)
                .childHandler(initializer(ConnectionType.NETWORK))
                .bind(localAddress)
                .syncUninterruptibly()

            val channel = future.channel()
            channels.add(future)
            // Deliberately NOT cached in localChannel: that field is reserved for the in-memory listener.
            channel.closeFuture().addListener { channels.remove(future) }

            return channel
        }
    }

    /**
     * Closes every listener, asks every connection to disconnect and waits up to 2 seconds per connection.
     * Only the first call does any work; later calls return immediately.
     */
    override fun close() {
        lock.withLock {
            if (isClosed) return
            isClosed = true
        }

        // From this point nothing can be added: acceptConnection and both create* methods check isClosed under
        // lock. The lists can therefore be drained without holding it. Waiting outside the lock keeps Netty
        // event-loop threads, which take lock in acceptConnection, from being blocked for the whole shutdown.
        channels.forEach { it.channel().close() }
        connections.forEach { it.disconnect("Server shutting down") }

        // aeugh
        connections.forEach {
            if (!it.waitDisconnect(2L, TimeUnit.SECONDS)) {
                LOGGER.warn("Giving up waiting for $it to disconnect")
            }
        }

        channels.clear()
        connections.clear()
    }

    companion object {
        private val LOGGER = LogManager.getLogger()

        /** Highest connection ID handed out by [nextConnectionID]; must stay of the form 2^n - 1 (see cycleConnectionID). */
        const val MAX_PLAYERS = 32767
    }
}