package ru.dbotthepony.kstarbound.server

import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.Channel
import io.netty.channel.ChannelFuture
import io.netty.channel.ChannelInitializer
import io.netty.channel.local.LocalAddress
import io.netty.channel.local.LocalServerChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet
import org.apache.logging.log4j.LogManager
import ru.dbotthepony.kstarbound.network.Connection
import ru.dbotthepony.kstarbound.network.ConnectionType
import ru.dbotthepony.kstarbound.network.IPacket
import ru.dbotthepony.kstarbound.network.packets.clientbound.ServerInfoPacket
import java.io.Closeable
import java.net.SocketAddress
import java.util.*
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock

/**
 * Server-side listening channels of a [StarboundServer] and the [ServerConnection]s accepted on them.
 *
 * Supports:
 * - one in-memory ([LocalServerChannel]) listener, created lazily by [createLocalChannel];
 * - any number of TCP ([NioServerSocketChannel]) listeners, created by [createChannel].
 *
 * It also hands out connection IDs and tracks the player count that is broadcast to ready connections.
 *
 * Binding listeners, accepting connections and [close] are serialized on [lock].
 * Connection ID allocation is serialized on [connectionIDLock].
 */
class ServerChannels(val server: StarboundServer) : Closeable {
    /** Bind futures of the listeners created here; an entry is removed when its listener closes. */
    private val channels = CopyOnWriteArrayList<ChannelFuture>()

    /** Connections accepted on any listener. */
    val connections = CopyOnWriteArrayList<ServerConnection>()

    // Read without holding [lock] on the fast path of createLocalChannel() and cleared from a
    // Netty close listener. It must be volatile so those writes are visible across threads.
    @Volatile
    private var localChannel: Channel? = null

    private val lock = ReentrantLock()

    // Guarded by [lock].
    private var isClosed = false

    // Both guarded by [connectionIDLock].
    private var nextConnectionID = 0
    private val occupiedConnectionIDs = IntAVLTreeSet()
    private val connectionIDLock = Any()

    private val playerCount = AtomicInteger()

    /** Increments the player count and broadcasts the new value in a [ServerInfoPacket]. */
    fun incrementPlayerCount() {
        val new = playerCount.incrementAndGet()
        broadcast(ServerInfoPacket(new, server.settings.maxPlayers))
    }

    /**
     * Decrements the player count and broadcasts the new value in a [ServerInfoPacket].
     *
     * @throws IllegalStateException if the count drops below zero (unbalanced call).
     */
    fun decrementPlayerCount() {
        val new = playerCount.decrementAndGet()
        check(new >= 0) { "Player count turned negative" }
        broadcast(ServerInfoPacket(new, server.settings.maxPlayers))
    }

    /** Returns the connection whose `connectionID` equals [id], or `null` (linear scan). */
    fun connectionByID(id: Int): ServerConnection? {
        return connections.firstOrNull { it.connectionID == id }
    }

    /** Returns the connection whose `uuid` equals [id], or `null` (linear scan). */
    fun connectionByUUID(id: UUID): ServerConnection? {
        return connections.firstOrNull { it.uuid == id }
    }

    /**
     * Advances the rolling counter and returns the next candidate ID in `1..MAX_PLAYERS`.
     * 0 is never returned.
     *
     * Must be called while holding [connectionIDLock].
     */
    private fun cycleConnectionID(): Int {
        // MAX_PLAYERS is 2^15 - 1, so `and` wraps the counter into 0..MAX_PLAYERS.
        val v = ++nextConnectionID and MAX_PLAYERS

        if (v == 0) {
            nextConnectionID++
            return 1
        }

        return v
    }

    /**
     * Reserves and returns a free connection ID in `1..MAX_PLAYERS`.
     *
     * IDs are issued round-robin, continuing after the previously issued one.
     * Release an ID with [freeConnectionID].
     *
     * @throws IllegalStateException if every ID is occupied.
     */
    fun nextConnectionID(): Int {
        synchronized(connectionIDLock) {
            // MAX_PLAYERS + 1 attempts visit every candidate ID at least once.
            repeat(MAX_PLAYERS + 1) {
                val id = cycleConnectionID()

                // add() returns false when the ID is already occupied.
                if (occupiedConnectionIDs.add(id)) {
                    return id
                }
            }
        }

        throw IllegalStateException("No more free connection IDs, how did we end up here?")
    }

    /** Releases [id] for reuse. Returns `true` if it was occupied. */
    fun freeConnectionID(id: Int): Boolean {
        return synchronized(connectionIDLock) {
            occupiedConnectionIDs.remove(id)
        }
    }

    /** Sends [packet] to every connection whose `isReady` is set; [flush] is forwarded to `send`. */
    fun broadcast(packet: IPacket, flush: Boolean = true) {
        connections.forEach {
            if (it.isReady) {
                it.send(packet, flush)
            }
        }
    }

    /**
     * Wraps a newly accepted [ch] in a [ServerConnection] of the given [type] and registers it in
     * [connections].
     *
     * The accepted channel is closed instead if this instance is already closed or if setup throws.
     */
    private fun acceptConnection(ch: Channel, type: ConnectionType) {
        // NOTE(review): this runs from a Netty channel initializer and blocks on [lock], which
        // close() holds while waiting for disconnects. Confirm this cannot stall shutdown.
        lock.withLock {
            if (isClosed) {
                // Nothing else would ever own or close this channel.
                ch.close()
                return
            }

            LOGGER.info("Incoming connection from ${ch.remoteAddress()}")
            var connection: ServerConnection? = null

            try {
                connection = ServerConnection(server, type)
                connections.add(connection)
                connection.bind(ch)
            } catch (err: Throwable) {
                LOGGER.error("Error while accepting new connection from ${ch.remoteAddress()}", err)

                // Don't keep a connection whose channel is being closed; otherwise close() would
                // still try to disconnect it and wait for it.
                if (connection != null) {
                    connections.remove(connection)
                }

                ch.close()
            }
        }
    }

    /** Child handler that passes every accepted channel to [acceptConnection] with [type]. */
    private fun childInitializer(type: ConnectionType): ChannelInitializer<Channel> {
        return object : ChannelInitializer<Channel>() {
            override fun initChannel(ch: Channel) {
                acceptConnection(ch, type)
            }
        }
    }

    /**
     * Returns the in-memory listener, binding one to [LocalAddress.ANY] on first use.
     *
     * Accepted channels become [ConnectionType.MEMORY] connections.
     *
     * @throws IllegalStateException if a new listener would have to be bound after [close].
     */
    fun createLocalChannel(): Channel {
        // Fast path: an already bound local channel needs no locking.
        localChannel?.let { return it }

        lock.withLock {
            // Re-check under the lock: another thread may have bound it meanwhile.
            localChannel?.let { return it }
            check(!isClosed) { "ServerChannels is closed" }

            val future = ServerBootstrap()
                .channel(LocalServerChannel::class.java)
                .group(Connection.NIO_POOL)
                .childHandler(childInitializer(ConnectionType.MEMORY))
                .bind(LocalAddress.ANY)
                .syncUninterruptibly()

            val channel = future.channel()
            channels.add(future)
            localChannel = channel

            channel.closeFuture().addListener {
                channels.remove(future)

                // Only forget it if it is still the current local channel.
                if (localChannel === channel) {
                    localChannel = null
                }
            }

            return channel
        }
    }

    /**
     * Binds a new TCP listener on [localAddress] and returns it.
     *
     * Accepted channels become [ConnectionType.NETWORK] connections.
     *
     * @throws IllegalStateException if called after [close].
     */
    fun createChannel(localAddress: SocketAddress): Channel {
        lock.withLock {
            check(!isClosed) { "ServerChannels is closed" }

            val future = ServerBootstrap()
                .channel(NioServerSocketChannel::class.java)
                .group(Connection.NIO_POOL)
                .childHandler(childInitializer(ConnectionType.NETWORK))
                .bind(localAddress)
                .syncUninterruptibly()

            // A TCP listener must NOT be stored in [localChannel]. Doing so would make
            // createLocalChannel() hand out this socket instead of an in-memory channel.
            channels.add(future)
            future.channel().closeFuture().addListener { channels.remove(future) }

            return future.channel()
        }
    }

    /**
     * Shuts down all listeners and connections. Idempotent.
     *
     * Steps:
     * 1. Closes every listener.
     * 2. Asks every connection to disconnect.
     * 3. Waits up to 2 seconds per connection for it to finish.
     * 4. Clears both lists.
     *
     * Afterwards, binding a new listener fails and newly accepted channels are closed immediately.
     */
    override fun close() {
        lock.withLock {
            if (isClosed) return
            isClosed = true

            channels.forEach { it.channel().close() }
            connections.forEach { it.disconnect("Server shutting down") }

            // Every disconnect is requested before waiting on any, so connections shut down
            // concurrently rather than one after another.
            connections.forEach {
                if (!it.waitDisconnect(2L, TimeUnit.SECONDS)) {
                    LOGGER.warn("Giving up waiting for $it to disconnect")
                }
            }

            channels.clear()
            connections.clear()
        }
    }

    companion object {
        private val LOGGER = LogManager.getLogger()

        /** Highest connection ID. It equals 2^15 - 1, which lets cycleConnectionID use it as a bit mask. */
        const val MAX_PLAYERS = 32767
    }
}