package mdnet.base

import io.netty.bootstrap.ServerBootstrap
import io.netty.buffer.Unpooled
import io.netty.channel.ChannelFactory
import io.netty.channel.ChannelFuture
import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInboundHandlerAdapter
import io.netty.channel.ChannelInitializer
import io.netty.channel.ChannelOption
import io.netty.channel.ServerChannel
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.http.DefaultFullHttpResponse
import io.netty.handler.codec.http.HttpHeaderNames
import io.netty.handler.codec.http.HttpObjectAggregator
import io.netty.handler.codec.http.HttpResponseStatus
import io.netty.handler.codec.http.HttpServerCodec
import io.netty.handler.codec.http.HttpUtil
import io.netty.handler.codec.http.HttpVersion
import io.netty.handler.ssl.OptionalSslHandler
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.SslHandler
import io.netty.handler.stream.ChunkedWriteHandler
import io.netty.handler.traffic.GlobalTrafficShapingHandler
import io.netty.handler.traffic.TrafficCounter
import org.http4k.core.HttpHandler
import org.http4k.server.Http4kChannelHandler
import org.http4k.server.Http4kServer
import org.http4k.server.ServerConfig
import java.net.InetSocketAddress
import java.nio.charset.StandardCharsets
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference

/**
 * Inbound handler that counts open connections and turns away any connection
 * beyond [Constants.MAX_CONCURRENT_CONNECTIONS] with a `503 Service Unavailable`
 * response carrying [Constants.OVERLOADED_MESSAGE].
 *
 * One instance is installed in every child pipeline (hence [Sharable]); the
 * shared count is held in an [AtomicInteger].
 */
@Sharable
class ConnectionCounter : ChannelInboundHandlerAdapter() {
    private val connections = AtomicInteger()

    /**
     * Counts the connection once the channel becomes active. If an [SslHandler]
     * is already present in the pipeline, counting is deferred until its
     * handshake future completes.
     */
    override fun channelActive(ctx: ChannelHandlerContext) {
        val sslHandler = ctx.pipeline()[SslHandler::class.java]
        if (sslHandler != null) {
            sslHandler.handshakeFuture().addListener { handleConnection(ctx) }
        } else {
            handleConnection(ctx)
        }
    }

    /**
     * Increments the connection count and either lets the channel proceed or
     * answers with a 503 and closes it.
     */
    private fun handleConnection(ctx: ChannelHandlerContext) {
        if (connections.incrementAndGet() <= Constants.MAX_CONCURRENT_CONNECTIONS) {
            super.channelActive(ctx)
            return
        }

        val body = Unpooled.copiedBuffer(Constants.OVERLOADED_MESSAGE, StandardCharsets.UTF_8)
        val response = DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SERVICE_UNAVAILABLE, body)
        response.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/html; charset=UTF-8")
        HttpUtil.setContentLength(response, body.readableBytes().toLong())

        // Close only after the write has completed (or failed): closing right
        // away can fail a write that has not been fully flushed yet, so the
        // client would never see the 503.
        ctx.writeAndFlush(response).addListener { ctx.close() }
    }

    /**
     * Propagates the event and releases the slot taken in [handleConnection].
     * Rejected connections are decremented here as well, since they are closed
     * after being counted.
     */
    override fun channelInactive(ctx: ChannelHandlerContext) {
        super.channelInactive(ctx)
        connections.decrementAndGet()
    }
}

/**
 * http4k [ServerConfig] backed by a Netty server with optional TLS, a global
 * connection limit ([ConnectionCounter]) and global write-rate shaping.
 *
 * @param tls private key and certificate used to build the server SSL context
 * @param clientSettings supplies the listen port and the burst-rate limit (KiB/s)
 * @param stats holder of the statistics object whose `bytesSent` counter is
 *   advanced as data is written
 */
class Netty(
    private val tls: ServerSettings.TlsCert,
    private val clientSettings: ClientSettings,
    // NOTE(review): the type argument of this AtomicReference is not visible in
    // this copy of the file (generic arguments appear to have been stripped);
    // restore it from the project's statistics type before compiling.
    private val stats: AtomicReference
) : ServerConfig {
    override fun toServer(httpHandler: HttpHandler): Http4kServer = object : Http4kServer {
        private val masterGroup = NioEventLoopGroup()
        private val workerGroup = NioEventLoopGroup()
        private lateinit var closeFuture: ChannelFuture
        private lateinit var address: InetSocketAddress

        private val counter = ConnectionCounter()

        // Shared across all child channels: write limit in bytes/s, no read
        // limit (0), accounting every 50 ms.
        private val burstLimiter = object : GlobalTrafficShapingHandler(
            workerGroup, 1024 * clientSettings.maxBurstRateKibPerSecond, 0, 50
        ) {
            override fun doAccounting(counter: TrafficCounter) {
                // Add only the bytes written during the interval that just
                // ended. The cumulative counter is never reset, so adding it on
                // every tick would re-count all earlier traffic each time.
                stats.get().bytesSent.getAndAdd(counter.lastWrittenBytes())
            }
        }

        /**
         * Builds the SSL context, binds the server socket on
         * `clientSettings.clientPort` and records the bound address and the
         * channel's close future. Blocks until the bind completes.
         */
        override fun start(): Http4kServer = apply {
            val sslContext = SslContextBuilder
                .forServer(getPrivateKey(tls.privateKey), getX509Cert(tls.certificate))
                .build()

            val bootstrap = ServerBootstrap()
            bootstrap.group(masterGroup, workerGroup)
                .channelFactory(ChannelFactory<ServerChannel> { NioServerSocketChannel() })
                .childHandler(object : ChannelInitializer<SocketChannel>() {
                    public override fun initChannel(ch: SocketChannel) {
                        val pipeline = ch.pipeline()
                        pipeline.addLast("ssl", OptionalSslHandler(sslContext))
                        pipeline.addLast("codec", HttpServerCodec())
                        // Placed after the codec so the 503 it writes is HTTP-encoded.
                        pipeline.addLast("counter", counter)
                        pipeline.addLast("aggregator", HttpObjectAggregator(65536))
                        pipeline.addLast("burstLimiter", burstLimiter)
                        pipeline.addLast("streamer", ChunkedWriteHandler())
                        pipeline.addLast("handler", Http4kChannelHandler(httpHandler))
                    }
                })
                .option(ChannelOption.SO_BACKLOG, 1000)
                .childOption(ChannelOption.SO_KEEPALIVE, true)

            val channel = bootstrap.bind(clientSettings.clientPort).sync().channel()
            address = channel.localAddress() as InetSocketAddress
            closeFuture = channel.closeFuture()
        }

        /**
         * Shuts down both event loop groups (5 s quiet period, 15 s timeout
         * each) and waits for the server channel to close. Safe to call even
         * if [start] was never invoked.
         */
        override fun stop() = apply {
            masterGroup.shutdownGracefully(5, 15, TimeUnit.SECONDS).sync()
            workerGroup.shutdownGracefully(5, 15, TimeUnit.SECONDS).sync()
            // closeFuture is only assigned by start(); skip the wait if the
            // server was never started.
            if (::closeFuture.isInitialized) {
                closeFuture.sync()
            }
        }

        /**
         * Returns the configured port when it is positive, otherwise the port
         * of the address recorded by [start].
         */
        override fun port(): Int =
            if (clientSettings.clientPort > 0) clientSettings.clientPort else address.port
    }
}