/*
Mangadex@Home
Copyright (c) 2020, MangaDex Network
This file is part of MangaDex@Home.

MangaDex@Home is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

MangaDex@Home is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this MangaDex@Home. If not, see <https://www.gnu.org/licenses/>.
*/
/* ktlint-disable no-wildcard-imports */
package mdnet.netty

import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.*
import io.netty.channel.epoll.Epoll
import io.netty.channel.epoll.EpollEventLoopGroup
import io.netty.channel.epoll.EpollServerSocketChannel
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.DecoderException
import io.netty.handler.codec.http.*
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.stream.ChunkedWriteHandler
import io.netty.handler.timeout.ReadTimeoutException
import io.netty.handler.timeout.ReadTimeoutHandler
import io.netty.handler.timeout.WriteTimeoutException
import io.netty.handler.timeout.WriteTimeoutHandler
import io.netty.handler.traffic.GlobalTrafficShapingHandler
import io.netty.handler.traffic.TrafficCounter
import io.netty.incubator.channel.uring.IOUring
import io.netty.incubator.channel.uring.IOUringEventLoopGroup
import io.netty.incubator.channel.uring.IOUringServerSocketChannel
import io.netty.util.concurrent.DefaultEventExecutorGroup
import mdnet.Constants
import mdnet.data.Statistics
import mdnet.logging.info
import mdnet.logging.trace
import mdnet.settings.ServerSettings
import mdnet.settings.TlsCert
import org.http4k.core.HttpHandler
import org.http4k.server.Http4kChannelHandler
import org.http4k.server.Http4kServer
import org.http4k.server.ServerConfig
import org.slf4j.LoggerFactory
import java.io.ByteArrayInputStream
import java.io.IOException
import java.io.InputStream
import java.net.InetSocketAddress
import java.net.SocketException
import java.security.PrivateKey
import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate
import java.util.concurrent.atomic.AtomicReference
import javax.net.ssl.SSLException

/**
 * A Netty transport: the event loop groups plus the server channel type they work with.
 *
 * [masterGroup] is the parent (acceptor) group and [workerGroup] the child group handed to
 * [ServerBootstrap.group]; [factory] creates the server socket channel matching the groups.
 */
interface NettyTransport {
    val masterGroup: EventLoopGroup
    val workerGroup: EventLoopGroup
    val factory: ChannelFactory<ServerChannel>

    /** Starts a graceful shutdown of both groups; the returned futures are not awaited. */
    fun shutdownGracefully() {
        masterGroup.shutdownGracefully()
        workerGroup.shutdownGracefully()
    }

    // The master group only ever hosts the single bound server channel, which registers with
    // exactly one event loop, so one loop is enough there.

    private class NioTransport : NettyTransport {
        override val masterGroup = NioEventLoopGroup(1)
        override val workerGroup = NioEventLoopGroup()
        override val factory = ChannelFactory<ServerChannel> { NioServerSocketChannel() }
    }

    private class EpollTransport : NettyTransport {
        override val masterGroup = EpollEventLoopGroup(1)
        override val workerGroup = EpollEventLoopGroup()
        override val factory = ChannelFactory<ServerChannel> { EpollServerSocketChannel() }
    }

    private class IOUringTransport : NettyTransport {
        override val masterGroup = IOUringEventLoopGroup(1)
        override val workerGroup = IOUringEventLoopGroup()
        override val factory = ChannelFactory<ServerChannel> { IOUringServerSocketChannel() }
    }

    companion object {
        private val LOGGER = LoggerFactory.getLogger(NettyTransport::class.java)

        /**
         * Returns the first available transport in order of preference: io_uring, epoll, NIO.
         * When a native transport is unavailable, the reason reported by Netty is logged.
         */
        fun bestForPlatform(): NettyTransport {
            if (IOUring.isAvailable()) {
                LOGGER.info("Using IOUring transport")
                return IOUringTransport()
            }
            LOGGER.info(IOUring.unavailabilityCause()) { "IOUring transport not available" }

            if (Epoll.isAvailable()) {
                LOGGER.info("Using Epoll transport")
                return EpollTransport()
            }
            LOGGER.info(Epoll.unavailabilityCause()) { "Epoll transport not available" }

            LOGGER.info("Using Nio transport")
            return NioTransport()
        }
    }
}

/**
 * http4k [ServerConfig] that serves an [HttpHandler] over TLS using Netty.
 *
 * @param tls certificate chain and private key, parsed by [getX509Certs] and [getPrivateKey]
 * @param serverSettings bind hostname/port, handler thread count and the egress cap in kilobits/s
 * @param statistics shared statistics whose `bytesSent` is increased by the bytes written to clients
 */
class Netty(
    private val tls: TlsCert,
    private val serverSettings: ServerSettings,
    private val statistics: AtomicReference<Statistics>
) : ServerConfig {
    override fun toServer(httpHandler: HttpHandler): Http4kServer = object : Http4kServer {
        private val transport = NettyTransport.bestForPlatform()

        // http4k handlers run on this group rather than on the I/O event loops.
        private val executor = DefaultEventExecutorGroup(serverSettings.threads)

        // Set by start(); null until the server has been bound.
        private var serverChannel: Channel? = null
        private lateinit var address: InetSocketAddress

        // One limiter shared by every connection. The cap is configured in kilobits/s and converted
        // to bytes/s; reads are unlimited (0) and the traffic counter is checked every 50 ms.
        private val burstLimiter = object : GlobalTrafficShapingHandler(
            transport.workerGroup,
            serverSettings.maxKilobitsPerSecond * 1000L / 8L,
            0,
            50
        ) {
            override fun doAccounting(counter: TrafficCounter) {
                // Read once so that a retried CAS inside getAndUpdate adds the same amount.
                val written = counter.cumulativeWrittenBytes()
                statistics.getAndUpdate { it.copy(bytesSent = it.bytesSent + written) }
                counter.resetCumulativeTime()
            }
        }

        override fun start(): Http4kServer = apply {
            LOGGER.info { "Starting Netty with ${serverSettings.threads} threads" }

            val certs = getX509Certs(tls.certificate)
            val sslContext = SslContextBuilder
                .forServer(getPrivateKey(tls.privateKey), certs)
                // NOTE(review): TLSv1 and TLSv1.1 are deprecated (RFC 8996); kept for older clients.
                .protocols("TLSv1.3", "TLSv1.2", "TLSv1.1", "TLSv1")
                .build()

            val bootstrap = ServerBootstrap()
            bootstrap.group(transport.masterGroup, transport.workerGroup)
                .channelFactory(transport.factory)
                .childHandler(object : ChannelInitializer<SocketChannel>() {
                    public override fun initChannel(ch: SocketChannel) {
                        val pipeline = ch.pipeline()
                        pipeline.addLast("ssl", sslContext.newHandler(ch.alloc()))
                        pipeline.addLast("codec", HttpServerCodec())
                        pipeline.addLast("keepAlive", HttpServerKeepAliveHandler())
                        pipeline.addLast("aggregator", HttpObjectAggregator(65536))
                        pipeline.addLast("burstLimiter", burstLimiter)
                        pipeline.addLast("readTimeoutHandler", ReadTimeoutHandler(Constants.MAX_READ_TIME_SECONDS))
                        pipeline.addLast("writeTimeoutHandler", WriteTimeoutHandler(Constants.MAX_WRITE_TIME_SECONDS))
                        pipeline.addLast("streamer", ChunkedWriteHandler())
                        pipeline.addLast(executor, "handler", Http4kChannelHandler(httpHandler))
                        pipeline.addLast(
                            "exceptions",
                            object : ChannelInboundHandlerAdapter() {
                                // Swallows expected client-side failures (bad TLS, dropped connections,
                                // timeouts) and forwards anything else down the pipeline.
                                override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
                                    when {
                                        cause is SSLException ||
                                            (cause is DecoderException && cause.cause is SSLException) -> {
                                            LOGGER.trace { "Ignored invalid SSL connection" }
                                            LOGGER.trace(cause) { "Exception in pipeline" }
                                        }
                                        cause is IOException || cause is SocketException -> {
                                            LOGGER.info { "User (downloader) abruptly closed the connection" }
                                            LOGGER.trace(cause) { "Exception in pipeline" }
                                        }
                                        cause !is ReadTimeoutException && cause !is WriteTimeoutException ->
                                            ctx.fireExceptionCaught(cause)
                                    }
                                }
                            }
                        )
                    }
                })
                .option(ChannelOption.SO_BACKLOG, 1000)
                .childOption(ChannelOption.SO_KEEPALIVE, true)

            val channel = bootstrap
                .bind(InetSocketAddress(serverSettings.hostname, serverSettings.port))
                .sync()
                .channel()
            address = channel.localAddress() as InetSocketAddress
            serverChannel = channel
        }

        override fun stop() = apply {
            // Close the listening channel itself; cancelling its close future would only affect
            // the future object, not the socket.
            serverChannel?.close()
            // Stop the limiter's scheduled TrafficCounter task.
            burstLimiter.release()
            transport.shutdownGracefully()
            executor.shutdownGracefully()
        }

        override fun port(): Int = if (serverSettings.port > 0) serverSettings.port else address.port
    }

    companion object {
        private val LOGGER = LoggerFactory.getLogger(Netty::class.java)
    }
}

/**
 * Parses every certificate contained in [certificates] with an X.509 [CertificateFactory].
 *
 * Presumably [certificates] is a PEM-encoded chain — confirm against the settings source.
 *
 * @return the parsed certificates, in the order the factory yields them
 * @throws java.security.cert.CertificateException if the input cannot be parsed
 */
fun getX509Certs(certificates: String): Collection<X509Certificate> {
    val targetStream: InputStream = ByteArrayInputStream(certificates.toByteArray())
    return targetStream.use { stream ->
        CertificateFactory.getInstance("X509")
            .generateCertificates(stream)
            .filterIsInstance<X509Certificate>()
    }
}

/**
 * Loads a [PrivateKey] from its textual form via `loadKey`.
 *
 * @throws IllegalArgumentException if `loadKey` returns null (assumed to mean the key could not be parsed)
 */
fun getPrivateKey(privateKey: String): PrivateKey =
    requireNotNull(loadKey(privateKey)) { "Unable to load the TLS private key" }