package ru.dbotthepony.kstarbound.server

import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.Channel
import io.netty.channel.ChannelFuture
import io.netty.channel.ChannelInitializer
import io.netty.channel.local.LocalAddress
import io.netty.channel.local.LocalServerChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet
import org.apache.logging.log4j.LogManager
import ru.dbotthepony.kstarbound.network.Connection
import ru.dbotthepony.kstarbound.network.ConnectionType
import ru.dbotthepony.kstarbound.network.IPacket
import ru.dbotthepony.kstarbound.network.packets.clientbound.ServerInfoPacket
import java.io.Closeable
import java.net.SocketAddress
import java.util.*
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock

/**
 * Owns the server-side listening channels (one in-memory [LocalServerChannel] and any number of
 * network [NioServerSocketChannel]s) together with the [ServerConnection]s accepted on them.
 *
 * Also hands out connection IDs in the range `1..MAX_PLAYERS` and tracks the player count,
 * broadcasting a [ServerInfoPacket] to every ready connection whenever the count changes.
 */
class ServerChannels(val server: StarboundServer) : Closeable {
	private val channels = CopyOnWriteArrayList<ChannelFuture>()

	/** Connections accepted by this instance's listeners. */
	val connections = CopyOnWriteArrayList<ServerConnection>()

	// The single in-memory listener, or null if none is currently open. Read without holding [lock]
	// on the fast path of createLocalChannel(), so it must be volatile for that read to observe
	// the value written under the lock by another thread.
	@Volatile
	private var localChannel: Channel? = null

	// Guards channel creation, acceptance of new connections, and shutdown (including [isClosed]).
	private val lock = ReentrantLock()
	private var isClosed = false

	// Rolling counter and set of IDs in use; both guarded by [connectionIDLock].
	private var nextConnectionID = 0
	private val occupiedConnectionIDs = IntAVLTreeSet()
	private val connectionIDLock = Any()

	private val playerCount = AtomicInteger()

	/** Increments the player count and broadcasts the new value to all ready connections. */
	fun incrementPlayerCount() {
		val new = playerCount.incrementAndGet()
		broadcast(ServerInfoPacket(new, server.settings.maxPlayers))
	}

	/**
	 * Decrements the player count and broadcasts the new value to all ready connections.
	 *
	 * @throws IllegalStateException if the count would become negative; the counter is restored
	 * before throwing so an unbalanced call does not permanently skew it.
	 */
	fun decrementPlayerCount() {
		val new = playerCount.decrementAndGet()

		if (new < 0) {
			// undo the decrement so later increments/broadcasts report a sane value
			playerCount.incrementAndGet()
			throw IllegalStateException("Player count turned negative")
		}

		broadcast(ServerInfoPacket(new, server.settings.maxPlayers))
	}

	/** Returns the connection whose `connectionID` equals [id], or null if there is none. */
	fun connectionByID(id: Int): ServerConnection? {
		return connections.firstOrNull { it.connectionID == id }
	}

	/** Returns the connection whose `uuid` equals [id], or null if there is none. */
	fun connectionByUUID(id: UUID): ServerConnection? {
		return connections.firstOrNull { it.uuid == id }
	}

	/**
	 * Advances the rolling counter and returns the next candidate ID in `1..MAX_PLAYERS`.
	 * 0 is never returned: when the masked value wraps to 0 the counter is bumped once more and 1 is returned.
	 * Only called while holding [connectionIDLock].
	 */
	private fun cycleConnectionID(): Int {
		val v = ++nextConnectionID and MAX_PLAYERS

		if (v == 0) {
			nextConnectionID++
			return 1
		}

		return v
	}

	/**
	 * Reserves and returns a connection ID in `1..MAX_PLAYERS` that is not currently in use.
	 * Release it with [freeConnectionID].
	 *
	 * @throws IllegalStateException if every ID is already occupied
	 */
	fun nextConnectionID(): Int {
		synchronized(connectionIDLock) {
			// MAX_PLAYERS + 1 attempts visits every candidate in 1..MAX_PLAYERS (32767 is the maximum)
			repeat(MAX_PLAYERS + 1) {
				val candidate = cycleConnectionID()

				// add() returns false when the ID is already taken
				if (occupiedConnectionIDs.add(candidate)) {
					return candidate
				}
			}
		}

		throw IllegalStateException("No more free connection IDs, how did we end up here?")
	}

	/** Releases a connection ID. Returns true if [id] was reserved. */
	fun freeConnectionID(id: Int): Boolean {
		return synchronized(connectionIDLock) {
			occupiedConnectionIDs.remove(id)
		}
	}

	/** Sends [packet] to every connection whose `isReady` is true; [flush] is passed through to `send`. */
	fun broadcast(packet: IPacket, flush: Boolean = true) {
		connections.forEach {
			if (it.isReady) {
				it.send(packet, flush)
			}
		}
	}

	/**
	 * Child handler body shared by all listeners: registers a new [ServerConnection] of [type] and binds it
	 * to the accepted channel [ch]. If this instance is already closed, or setup fails, [ch] is closed instead.
	 */
	private fun acceptConnection(ch: Channel, type: ConnectionType) {
		lock.withLock {
			if (isClosed) {
				// shutting down: don't leave the freshly accepted child channel open with no handler
				ch.close()
				return
			}

			LOGGER.info("Incoming connection from ${ch.remoteAddress()}")

			try {
				val connection = ServerConnection(server, type)
				connections.add(connection)
				connection.bind(ch)
			} catch (err: Throwable) {
				LOGGER.error("Error while accepting new connection from ${ch.remoteAddress()}", err)
				ch.close()
			}
		}
	}

	/** Builds a [ChannelInitializer] that hands each accepted child channel to [acceptConnection]. */
	private fun acceptorFor(type: ConnectionType): ChannelInitializer<Channel> {
		return object : ChannelInitializer<Channel>() {
			override fun initChannel(ch: Channel) {
				acceptConnection(ch, type)
			}
		}
	}

	/**
	 * Returns the in-memory listener, binding a new one on [LocalAddress.ANY] if none is open.
	 * Concurrent callers share the same channel; once it closes, the next call creates a new one.
	 */
	fun createLocalChannel(): Channel {
		// fast path without locking (localChannel is volatile)
		this.localChannel?.let { return it }

		lock.withLock {
			// re-check: another thread may have created it while we waited for the lock
			this.localChannel?.let { return it }

			val future = ServerBootstrap()
				.channel(LocalServerChannel::class.java)
				.group(Connection.NIO_POOL)
				.childHandler(acceptorFor(ConnectionType.MEMORY))
				.bind(LocalAddress.ANY)
				.syncUninterruptibly()

			val channel = future.channel()
			channels.add(future)
			this.localChannel = channel

			channel.closeFuture().addListener {
				channels.remove(future)

				// only forget it if it is still the current local channel
				if (this.localChannel === channel) {
					this.localChannel = null
				}
			}

			return channel
		}
	}

	/**
	 * Binds a new network listener on [localAddress] and returns its channel.
	 * Network listeners are tracked for [close] but never stored as the in-memory listener.
	 */
	fun createChannel(localAddress: SocketAddress): Channel {
		lock.withLock {
			val future = ServerBootstrap()
				.channel(NioServerSocketChannel::class.java)
				.group(Connection.NIO_POOL)
				.childHandler(acceptorFor(ConnectionType.NETWORK))
				.bind(localAddress)
				.syncUninterruptibly()

			channels.add(future)

			future.channel().closeFuture().addListener {
				channels.remove(future)
			}

			return future.channel()
		}
	}

	/**
	 * Closes every listener, asks every connection to disconnect, waits up to 2 seconds per connection,
	 * then clears the tracked channels and connections. Subsequent calls are no-ops.
	 * [lock] is held for the whole shutdown, so new child channels block until it finishes and are then closed.
	 */
	override fun close() {
		lock.withLock {
			if (isClosed) return
			isClosed = true

			channels.forEach { it.channel().close() }
			connections.forEach { it.disconnect("Server shutting down") }

			// best-effort: give each connection a bounded amount of time to finish disconnecting
			connections.forEach {
				if (!it.waitDisconnect(2L, TimeUnit.SECONDS)) {
					LOGGER.warn("Giving up waiting for $it to disconnect")
				}
			}

			channels.clear()
			connections.clear()
		}
	}

	companion object {
		private val LOGGER = LogManager.getLogger()

		/** Upper bound (inclusive) of connection IDs handed out by [nextConnectionID]; also used as a bit mask. */
		const val MAX_PLAYERS = 32767
	}
}