/* Mangadex@Home Copyright (c) 2020, MangaDex Network This file is part of MangaDex@Home. MangaDex@Home is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. MangaDex@Home is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this MangaDex@Home. If not, see . */ package mdnet.netty import io.netty.bootstrap.ServerBootstrap import io.netty.buffer.ByteBuf import io.netty.channel.* import io.netty.channel.epoll.Epoll import io.netty.channel.epoll.EpollEventLoopGroup import io.netty.channel.epoll.EpollServerSocketChannel import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.codec.DecoderException import io.netty.handler.codec.ProtocolDetectionResult import io.netty.handler.codec.ProtocolDetectionState import io.netty.handler.codec.haproxy.HAProxyMessage import io.netty.handler.codec.haproxy.HAProxyMessageDecoder import io.netty.handler.codec.haproxy.HAProxyProtocolVersion import io.netty.handler.codec.http.FullHttpRequest import io.netty.handler.codec.http.HttpObjectAggregator import io.netty.handler.codec.http.HttpServerCodec import io.netty.handler.codec.http.HttpServerKeepAliveHandler import io.netty.handler.ssl.SniCompletionEvent import io.netty.handler.ssl.SniHandler import io.netty.handler.ssl.SslContextBuilder import io.netty.handler.stream.ChunkedWriteHandler import io.netty.handler.timeout.ReadTimeoutException import io.netty.handler.timeout.ReadTimeoutHandler import io.netty.handler.timeout.WriteTimeoutException import 
io.netty.handler.timeout.WriteTimeoutHandler
import io.netty.handler.traffic.GlobalTrafficShapingHandler
import io.netty.handler.traffic.TrafficCounter
import io.netty.incubator.channel.uring.IOUring
import io.netty.incubator.channel.uring.IOUringEventLoopGroup
import io.netty.incubator.channel.uring.IOUringServerSocketChannel
import io.netty.util.AttributeKey
import io.netty.util.AttributeMap
import io.netty.util.DomainWildcardMappingBuilder
import io.netty.util.ReferenceCountUtil
import io.netty.util.concurrent.DefaultEventExecutorGroup
import io.netty.util.internal.SystemPropertyUtil
import mdnet.Constants
import mdnet.data.Statistics
import mdnet.logging.info
import mdnet.logging.trace
import mdnet.logging.warn
import mdnet.settings.DevSettings
import mdnet.settings.RemoteSettings
import mdnet.settings.ServerSettings
import org.http4k.core.HttpHandler
import org.http4k.server.Http4kChannelHandler
import org.http4k.server.Http4kServer
import org.http4k.server.ServerConfig
import org.slf4j.LoggerFactory
import java.io.ByteArrayInputStream
import java.io.IOException
import java.io.InputStream
import java.net.InetSocketAddress
import java.net.SocketException
import java.security.PrivateKey
import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate
import java.util.Locale
import java.util.concurrent.TimeUnit
import javax.net.ssl.SSLException

/**
 * The Netty event-loop groups, server-channel factory and handler executor for one I/O transport.
 *
 * @param threads size of [executor]; must be greater than zero.
 */
sealed class NettyTransport(threads: Int) {
    abstract val bossGroup: EventLoopGroup
    abstract val workerGroup: EventLoopGroup
    abstract val factory: ChannelFactory<ServerChannel>

    // Executor the application handler is run on (see `Netty`), sized by the configured thread count.
    val executor = DefaultEventExecutorGroup(
        threads.also { require(it > 0) { "Threads must be greater than zero" } }
    )

    private class NioTransport(threads: Int) : NettyTransport(threads) {
        override val bossGroup = NioEventLoopGroup(BOSS_THREADS)
        override val workerGroup = NioEventLoopGroup(WORKER_THREADS)
        override val factory = ChannelFactory<ServerChannel> { NioServerSocketChannel() }
    }

    private class EpollTransport(threads: Int) : NettyTransport(threads) {
        override val bossGroup = EpollEventLoopGroup(BOSS_THREADS)
        override val workerGroup = EpollEventLoopGroup(WORKER_THREADS)
        override val factory = ChannelFactory<ServerChannel> { EpollServerSocketChannel() }
    }

    private class IOUringTransport(threads: Int) : NettyTransport(threads) {
        override val bossGroup = IOUringEventLoopGroup(BOSS_THREADS)
        override val workerGroup = IOUringEventLoopGroup(WORKER_THREADS)
        override val factory = ChannelFactory<ServerChannel> { IOUringServerSocketChannel() }
    }

    companion object {
        private val LOGGER = LoggerFactory.getLogger(NettyTransport::class.java)

        // The I/O event-loop sizes are fixed; only the handler executor follows the configured thread count.
        private const val BOSS_THREADS = 1
        private const val WORKER_THREADS = 8

        private fun defaultNumThreads() = Runtime.getRuntime().availableProcessors() * 2

        /**
         * Chooses the transport for the current platform.
         *
         * On Linux, IOUring is tried first, then Epoll. Either can be skipped by setting the `no-iouring`
         * or `no-epoll` system property to `true`. Everything else falls back to NIO.
         *
         * @param threads executor size; `0` selects twice the number of available processors.
         * @throws IllegalArgumentException if the resulting thread count is not positive.
         */
        fun bestForPlatform(threads: Int): NettyTransport {
            val osName = SystemPropertyUtil.get("os.name").lowercase(Locale.US).trim { it <= ' ' }
            val threadsToUse = if (threads == 0) defaultNumThreads() else threads
            LOGGER.info { "Choosing a transport with $threadsToUse threads" }

            if (osName.startsWith("linux")) {
                if (!SystemPropertyUtil.get("no-iouring").toBoolean()) {
                    if (IOUring.isAvailable()) {
                        LOGGER.info { "Using IOUring transport" }
                        return IOUringTransport(threadsToUse)
                    } else {
                        LOGGER.info { "IOUring transport not available (this may be normal)" }
                    }
                }

                if (!SystemPropertyUtil.get("no-epoll").toBoolean()) {
                    if (Epoll.isAvailable()) {
                        LOGGER.info { "Using Epoll transport" }
                        return EpollTransport(threadsToUse)
                    } else {
                        LOGGER.info { "Epoll transport not available (this may be normal)" }
                    }
                }
            }

            LOGGER.info { "Using Nio transport" }
            return NioTransport(threadsToUse)
        }
    }
}

/**
 * http4k [ServerConfig] that serves the given handler over TLS with Netty.
 *
 * The per-connection pipeline:
 * - optionally understands the HAProxy PROXY protocol ([ServerSettings.enableProxyProtocol]);
 * - rejects SNI hostnames outside the configured domain, unless [DevSettings.disableSniCheck] is set;
 * - applies a global write limit and adds the bytes written to [Statistics.bytesSent].
 */
class Netty(
    private val remoteSettings: RemoteSettings,
    private val serverSettings: ServerSettings,
    private val devSettings: DevSettings,
    private val statistics: Statistics
) : ServerConfig {
    override fun toServer(http: HttpHandler): Http4kServer = object : Http4kServer {
        private val transport = NettyTransport.bestForPlatform(serverSettings.threads)
        private lateinit var channel: Channel

        // A single limiter shared by every connection. The limit is bytes/s (Netty treats 0 as "no limit").
        // Every 100 ms the bytes written since the previous check are added to the statistics.
        private val burstLimiter = object : GlobalTrafficShapingHandler(
            transport.workerGroup,
            serverSettings.maxKilobitsPerSecond * 1000L / 8L,
            0,
            100
        ) {
            override fun doAccounting(counter: TrafficCounter) {
                statistics.bytesSent.getAndAccumulate(counter.cumulativeWrittenBytes()) { a, b -> a + b }
                counter.resetCumulativeTime()
            }
        }

        override fun start(): Http4kServer = apply {
            LOGGER.info { "Starting Netty!" }

            val tls = checkNotNull(remoteSettings.tls) { "Remote settings do not contain TLS material" }
            val certs = getX509Certs(tls.certificate)
            val sslContext = SslContextBuilder
                .forServer(getPrivateKey(tls.privateKey), certs)
                .protocols("TLSv1.3", "TLSv1.2")
                .build()

            // SNI names must end with this host. It is the URL authority without its ":port" suffix,
            // and the whole authority when there is no port.
            val hostToTest = remoteSettings.url.authority.substringBeforeLast(":")

            val bootstrap = ServerBootstrap()
            bootstrap.group(transport.bossGroup, transport.workerGroup)
                .channelFactory(transport.factory)
                .childHandler(object : ChannelInitializer<SocketChannel>() {
                    public override fun initChannel(ch: SocketChannel) {
                        val pipeline = ch.pipeline()

                        if (serverSettings.enableProxyProtocol) {
                            pipeline.addLast(
                                "proxyProtocol",
                                object : ChannelInboundHandlerAdapter() {
                                    override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
                                        if (msg is ByteBuf) {
                                            // The builtin `HAProxyMessageDecoder` breaks connections that do not
                                            // speak the PROXY protocol. Use its detection to install it only when needed.
                                            val result: ProtocolDetectionResult<HAProxyProtocolVersion> =
                                                HAProxyMessageDecoder.detectProtocol(msg)
                                            when (result.state()) {
                                                ProtocolDetectionState.DETECTED -> {
                                                    ctx.pipeline().addAfter("proxyProtocol", null, HAProxyMessageDecoder())
                                                    ctx.pipeline().remove(this)
                                                }
                                                // A PROXY header can only start the stream. Once the data is known
                                                // not to be one, stop inspecting every later read.
                                                ProtocolDetectionState.INVALID -> ctx.pipeline().remove(this)
                                                // Too few bytes to decide yet; keep looking.
                                                ProtocolDetectionState.NEEDS_MORE_DATA -> Unit
                                            }
                                        }
                                        super.channelRead(ctx, msg)
                                    }
                                }
                            )
                            pipeline.addLast(
                                "saveOriginalIp",
                                object : SimpleChannelInboundHandler<HAProxyMessage>() {
                                    override fun channelRead0(ctx: ChannelHandlerContext, msg: HAProxyMessage) {
                                        // Remember the client address on the channel until an HTTP request is decoded.
                                        // Per Netty's docs, `sourceAddress()` may be null (e.g. proxy health checks).
                                        ctx.channel().attr(HAPROXY_SOURCE).set(msg.sourceAddress())
                                    }
                                }
                            )
                        }

                        pipeline.addLast("ssl", SniHandler(DomainWildcardMappingBuilder(sslContext).build()))
                        pipeline.addLast(
                            "dropHostname",
                            object : ChannelInboundHandlerAdapter() {
                                override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) {
                                    if (evt !is SniCompletionEvent) {
                                        ctx.fireUserEventTriggered(evt)
                                        return
                                    }
                                    // The SNI completion event is consumed here and not propagated further.
                                    val hostname: String? = evt.hostname()
                                    if (!devSettings.disableSniCheck &&
                                        hostname != null &&
                                        !hostname.endsWith(hostToTest) &&
                                        !hostname.endsWith("localhost")
                                    ) {
                                        ctx.close()
                                    }
                                }
                            }
                        )

                        pipeline.addLast("codec", HttpServerCodec())
                        pipeline.addLast("keepAlive", HttpServerKeepAliveHandler())
                        pipeline.addLast("aggregator", HttpObjectAggregator(65536))

                        if (serverSettings.enableProxyProtocol) {
                            pipeline.addLast(
                                "setForwardHeader",
                                object : SimpleChannelInboundHandler<FullHttpRequest>() {
                                    override fun channelRead0(ctx: ChannelHandlerContext, request: FullHttpRequest) {
                                        // The geo-location code already supports the `Forwarded` header, so setting
                                        // it is the easiest way to pass the original IP downstream. It is skipped when
                                        // the PROXY header carried no address, because `set` rejects null values.
                                        ctx.channel().attr(HAPROXY_SOURCE).get()?.let { addr ->
                                            request.headers().set("Forwarded", addr)
                                        }
                                        // This handler auto-releases `request` when channelRead0 returns, so take one
                                        // extra reference. That reference is owned by (and released by) the next handler.
                                        ctx.fireChannelRead(ReferenceCountUtil.retain(request))
                                    }
                                }
                            )
                        }

                        pipeline.addLast("burstLimiter", burstLimiter)
                        pipeline.addLast("readTimeoutHandler", ReadTimeoutHandler(Constants.MAX_READ_TIME_SECONDS))
                        pipeline.addLast("writeTimeoutHandler", WriteTimeoutHandler(Constants.MAX_WRITE_TIME_SECONDS))
                        pipeline.addLast("streamer", ChunkedWriteHandler())
                        pipeline.addLast(transport.executor, "handler", Http4kChannelHandler(http))
                        pipeline.addLast(
                            "exceptions",
                            object : ChannelInboundHandlerAdapter() {
                                override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
                                    when {
                                        cause is SSLException || (cause is DecoderException && cause.cause is SSLException) ->
                                            LOGGER.trace(cause) { "Ignored invalid SSL connection" }
                                        // SocketException is an IOException, so one check covers both.
                                        cause is IOException ->
                                            LOGGER.trace(cause) { "User (downloader) abruptly closed the connection" }
                                        cause !is ReadTimeoutException && cause !is WriteTimeoutException ->
                                            LOGGER.warn(cause) { "Exception in pipeline" }
                                    }
                                    // This is the last handler, so nothing can recover the connection. Close it rather
                                    // than leaving it idle until the read timeout. Closing an already-closed channel is harmless.
                                    ctx.close()
                                }
                            }
                        )
                    }
                })
                .option(ChannelOption.SO_BACKLOG, 1000)
                .childOption(ChannelOption.SO_KEEPALIVE, true)

            channel = bootstrap.bind(InetSocketAddress(serverSettings.hostname, serverSettings.port)).sync().channel()
        }

        override fun stop() = apply {
            channel.close().sync()
            // Stops the limiter's traffic counter. The executors it ran on are shut down below.
            burstLimiter.release()
            transport.run {
                bossGroup.shutdownGracefully(0, 500, TimeUnit.MILLISECONDS).sync()
                workerGroup.shutdownGracefully(0, 500, TimeUnit.MILLISECONDS).sync()
                executor.shutdownGracefully(0, 500, TimeUnit.MILLISECONDS).sync()
            }
        }

        override fun port(): Int = (channel.localAddress() as InetSocketAddress).port
    }

    companion object {
        private val LOGGER = LoggerFactory.getLogger(Netty::class.java)

        // Client address taken from the PROXY header; value may be null (see "saveOriginalIp").
        private val HAPROXY_SOURCE = AttributeKey.newInstance<String>("haproxy_source")
    }
}

/**
 * Parses every X.509 certificate contained in [certificates].
 *
 * The input is presumably a PEM certificate chain; verify this against the remote settings producer.
 *
 * @return the certificates in the order the factory produced them.
 */
fun getX509Certs(certificates: String): Collection<X509Certificate> {
    val targetStream: InputStream = ByteArrayInputStream(certificates.toByteArray())
    return targetStream.use { stream ->
        CertificateFactory.getInstance("X509")
            .generateCertificates(stream)
            .filterIsInstance<X509Certificate>()
    }
}

/**
 * Loads the TLS private key from its string form via `loadKey`.
 *
 * @throws IllegalArgumentException if `loadKey` could not produce a key.
 */
fun getPrivateKey(privateKey: String): PrivateKey =
    requireNotNull(loadKey(privateKey)) { "Unable to load the TLS private key" }