1
0
Fork 1
mirror of https://gitlab.com/mangadex-pub/mangadex_at_home.git synced 2024-01-19 02:48:37 +00:00
mangadex_at_home/src/main/kotlin/mdnet/base/Netty.kt

127 lines
5.5 KiB
Kotlin
Raw Normal View History

2020-06-06 22:52:25 +00:00
package mdnet.base
import io.netty.bootstrap.ServerBootstrap
import io.netty.buffer.Unpooled
import io.netty.channel.ChannelFactory
import io.netty.channel.ChannelFuture
import io.netty.channel.ChannelFutureListener
import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInboundHandlerAdapter
import io.netty.channel.ChannelInitializer
import io.netty.channel.ChannelOption
import io.netty.channel.ServerChannel
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.http.DefaultFullHttpResponse
import io.netty.handler.codec.http.HttpHeaderNames
import io.netty.handler.codec.http.HttpObjectAggregator
import io.netty.handler.codec.http.HttpResponseStatus
import io.netty.handler.codec.http.HttpServerCodec
import io.netty.handler.codec.http.HttpUtil
import io.netty.handler.codec.http.HttpVersion
import io.netty.handler.ssl.OptionalSslHandler
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.SslHandler
import io.netty.handler.stream.ChunkedWriteHandler
import io.netty.handler.traffic.GlobalTrafficShapingHandler
import io.netty.handler.traffic.TrafficCounter
import org.http4k.core.HttpHandler
import org.http4k.server.Http4kChannelHandler
import org.http4k.server.Http4kServer
import org.http4k.server.ServerConfig
import java.net.InetSocketAddress
import java.nio.charset.StandardCharsets
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference
/**
 * Shared inbound handler that caps the number of simultaneously open connections.
 *
 * Every channel that becomes active increments a single counter; every channel that
 * becomes inactive decrements it. Channels that push the counter above
 * [Constants.MAX_CONCURRENT_CONNECTIONS] are answered with a `503 Service Unavailable`
 * carrying [Constants.OVERLOADED_MESSAGE] and are then closed instead of being handed
 * to the rest of the pipeline.
 *
 * The handler is [Sharable]: its only state is an [AtomicInteger], so one instance is
 * meant to be installed in every child channel's pipeline.
 */
@Sharable
class ConnectionCounter : ChannelInboundHandlerAdapter() {
    // Number of channels currently counted. Rejected channels are included until
    // their channelInactive event arrives.
    private val connections = AtomicInteger()

    /**
     * Defers the admission decision until the TLS handshake has completed when an
     * [SslHandler] is present in the pipeline; otherwise decides immediately.
     *
     * NOTE(review): a pipeline that installs `OptionalSslHandler` rather than
     * `SslHandler` will presumably have no `SslHandler` yet at activation time, in
     * which case the immediate branch is the one taken - confirm against the pipeline setup.
     */
    override fun channelActive(ctx: ChannelHandlerContext) {
        val sslHandler = ctx.pipeline()[SslHandler::class.java]
        if (sslHandler != null) {
            // The listener fires on both handshake success and failure, so the
            // increment here always pairs with the decrement in channelInactive.
            sslHandler.handshakeFuture().addListener {
                handleConnection(ctx)
            }
        } else {
            handleConnection(ctx)
        }
    }

    /**
     * Counts the connection and either forwards the activation event downstream or
     * rejects the connection with a 503 response.
     */
    private fun handleConnection(ctx: ChannelHandlerContext) {
        if (connections.incrementAndGet() <= Constants.MAX_CONCURRENT_CONNECTIONS) {
            super.channelActive(ctx)
        } else {
            val response = Unpooled.copiedBuffer(Constants.OVERLOADED_MESSAGE, StandardCharsets.UTF_8)
            val res =
                DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SERVICE_UNAVAILABLE, response)
            res.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/html; charset=UTF-8")
            HttpUtil.setContentLength(res, response.readableBytes().toLong())

            // Close only once the write has completed. Closing immediately after
            // writeAndFlush can fail a write that has not fully reached the socket,
            // truncating or dropping the 503 response.
            ctx.writeAndFlush(res).addListener(ChannelFutureListener.CLOSE)
        }
    }

    /** Propagates the event downstream, then releases this channel's slot in the counter. */
    override fun channelInactive(ctx: ChannelHandlerContext) {
        super.channelInactive(ctx)
        connections.decrementAndGet()
    }
}
/**
 * http4k [ServerConfig] backed by a Netty server with optional TLS, a global
 * connection cap, and global outbound traffic shaping.
 *
 * @param tls private key and certificate used to build the server-side SSL context
 * @param clientSettings supplies the listening port (`clientPort`) and the outbound
 *   burst limit (`maxBurstRateKibPerSecond`)
 * @param stats holder of the live [Statistics]; its `bytesSent` counter is advanced
 *   by the traffic shaper's periodic accounting callback
 */
class Netty(
    private val tls: ServerSettings.TlsCert,
    private val clientSettings: ClientSettings,
    private val stats: AtomicReference<Statistics>
) : ServerConfig {
    /**
     * Creates a not-yet-started [Http4kServer] that dispatches requests to [httpHandler].
     */
    override fun toServer(httpHandler: HttpHandler): Http4kServer = object : Http4kServer {
        private val masterGroup = NioEventLoopGroup()
        private val workerGroup = NioEventLoopGroup()

        // Both are assigned by start(); they stay uninitialized if start() was never
        // called or if binding failed.
        private lateinit var closeFuture: ChannelFuture
        private lateinit var address: InetSocketAddress

        // One shared instance for all child channels so the connection cap is global.
        private val counter = ConnectionCounter()

        // Arguments after the executor: write limit (KiB/s setting multiplied by 1024),
        // read limit 0, and a check interval of 50 (milliseconds in Netty's API).
        private val burstLimiter = object : GlobalTrafficShapingHandler(
            workerGroup, 1024 * clientSettings.maxBurstRateKibPerSecond, 0, 50
        ) {
            /**
             * Invoked by Netty once per check interval. Adds only the bytes written
             * during the interval that just ended; adding the never-reset cumulative
             * total on every tick would count earlier traffic again and again.
             */
            override fun doAccounting(counter: TrafficCounter) {
                stats.get().bytesSent.getAndAdd(counter.lastWrittenBytes())
            }
        }

        /**
         * Builds the SSL context and the channel pipeline, binds to the configured
         * port, and blocks until the bind has completed.
         */
        override fun start(): Http4kServer = apply {
            val sslContext = SslContextBuilder
                .forServer(getPrivateKey(tls.privateKey), getX509Cert(tls.certificate))
                .build()

            val bootstrap = ServerBootstrap()
            bootstrap.group(masterGroup, workerGroup)
                .channelFactory(ChannelFactory<ServerChannel> { NioServerSocketChannel() })
                .childHandler(object : ChannelInitializer<SocketChannel>() {
                    public override fun initChannel(ch: SocketChannel) {
                        // Handler order matters: TLS detection first, then HTTP codec,
                        // the connection cap, request aggregation, outbound shaping,
                        // chunked streaming, and finally the http4k adapter.
                        ch.pipeline().addLast("ssl", OptionalSslHandler(sslContext))
                        ch.pipeline().addLast("codec", HttpServerCodec())
                        ch.pipeline().addLast("counter", counter)
                        ch.pipeline().addLast("aggregator", HttpObjectAggregator(65536))
                        ch.pipeline().addLast("burstLimiter", burstLimiter)
                        ch.pipeline().addLast("streamer", ChunkedWriteHandler())
                        ch.pipeline().addLast("handler", Http4kChannelHandler(httpHandler))
                    }
                })
                .option(ChannelOption.SO_BACKLOG, 1000)
                .childOption(ChannelOption.SO_KEEPALIVE, true)

            val channel = bootstrap.bind(clientSettings.clientPort).sync().channel()
            address = channel.localAddress() as InetSocketAddress
            closeFuture = channel.closeFuture()
        }

        /**
         * Gracefully shuts down both event loop groups (5 s quiet period, 15 s
         * timeout each) and then waits for the server channel to finish closing.
         */
        override fun stop() = apply {
            masterGroup.shutdownGracefully(5, 15, TimeUnit.SECONDS).sync()
            workerGroup.shutdownGracefully(5, 15, TimeUnit.SECONDS).sync()
            // closeFuture only exists after a successful start(); without this guard,
            // stopping a server that never bound would throw after the groups are down.
            if (::closeFuture.isInitialized) {
                closeFuture.sync()
            }
        }

        /**
         * Returns the configured port when it is positive, otherwise the port actually
         * bound by start() (e.g. when the configured port is 0).
         */
        override fun port(): Int = if (clientSettings.clientPort > 0) clientSettings.clientPort else address.port
    }
}