diff --git a/build.gradle b/build.gradle index 5175b98..19ea743 100644 --- a/build.gradle +++ b/build.gradle @@ -46,6 +46,7 @@ dependencies { implementation group: "org.http4k", name: "http4k-client-okhttp" implementation group: "org.http4k", name: "http4k-metrics-micrometer" implementation group: "org.http4k", name: "http4k-server-netty" + implementation group: "io.netty", name: "netty-codec-haproxy" implementation group: "io.netty", name: "netty-transport-native-epoll", classifier: "linux-x86_64" implementation group: "io.netty.incubator", name: "netty-incubator-transport-native-io_uring", version: "0.0.3.Final", classifier: "linux-x86_64" testImplementation group: "org.http4k", name: "http4k-testing-kotest" diff --git a/settings.sample.yaml b/settings.sample.yaml index 454d877..0325ea4 100644 --- a/settings.sample.yaml +++ b/settings.sample.yaml @@ -80,6 +80,13 @@ server_settings: # 0 defaults to (2 * your available processors) # https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/lang/Runtime.html#availableProcessors() threads: 0 + # Whether to enable support for HAProxy Proxy Protocol + # If using a reverse proxy to forward requests to MD@H via + # ssl passthrough, you can use Proxy Protocol to preserve + # original IP if your reverse proxy supports it. This + # will allow geo location metrics to work correctly. + # https://www.haproxy.com/blog/haproxy/proxy-protocol/ + enable_proxy_protocol: false # Settings intended for advanced use cases or tinkering diff --git a/src/main/kotlin/mdnet/netty/ApplicationNetty.kt b/src/main/kotlin/mdnet/netty/ApplicationNetty.kt index 7a4da2d..b208f1a 100644 --- a/src/main/kotlin/mdnet/netty/ApplicationNetty.kt +++ b/src/main/kotlin/mdnet/netty/ApplicationNetty.kt @@ -19,6 +19,7 @@ along with this MangaDex@Home. If not, see . package mdnet.netty import io.netty.bootstrap.ServerBootstrap +import io.netty.buffer.ByteBuf import io.netty.channel.* import io.netty.channel.epoll.Epoll import io.netty.channel.epoll.EpollEventLoopGroup @@ -27,6 +28,12 @@ import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.codec.DecoderException +import io.netty.handler.codec.ProtocolDetectionResult +import io.netty.handler.codec.ProtocolDetectionState +import io.netty.handler.codec.haproxy.HAProxyMessage +import io.netty.handler.codec.haproxy.HAProxyMessageDecoder +import io.netty.handler.codec.haproxy.HAProxyProtocolVersion +import io.netty.handler.codec.http.FullHttpRequest import io.netty.handler.codec.http.HttpObjectAggregator import io.netty.handler.codec.http.HttpServerCodec import io.netty.handler.codec.http.HttpServerKeepAliveHandler @@ -43,7 +50,10 @@ import io.netty.handler.traffic.TrafficCounter import io.netty.incubator.channel.uring.IOUring import io.netty.incubator.channel.uring.IOUringEventLoopGroup import io.netty.incubator.channel.uring.IOUringServerSocketChannel +import io.netty.util.AttributeKey +import io.netty.util.AttributeMap import io.netty.util.DomainWildcardMappingBuilder +import io.netty.util.ReferenceCountUtil import io.netty.util.concurrent.DefaultEventExecutorGroup import io.netty.util.internal.SystemPropertyUtil import mdnet.Constants @@ -173,6 +183,35 @@ class Netty( .channelFactory(transport.factory) .childHandler(object : ChannelInitializer() { public override fun initChannel(ch: SocketChannel) { + if (serverSettings.enableProxyProtocol) { + ch.pipeline().addLast( + "proxyProtocol", + object : ChannelInboundHandlerAdapter() { + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { + if (msg is ByteBuf) { + // Since the builtin `HAProxyMessageDecoder` will break non Proxy Protocol requests + // we need to use its detection capabilities to only add it when needed. + val result: ProtocolDetectionResult = HAProxyMessageDecoder.detectProtocol(msg) + if (result.state() == ProtocolDetectionState.DETECTED) { + ctx.pipeline().addAfter("proxyProtocol", null, HAProxyMessageDecoder()) + ctx.pipeline().remove(this) + } + } + super.channelRead(ctx, msg) + } + } + ) + ch.pipeline().addLast( + "saveOriginalIp", + object : SimpleChannelInboundHandler() { + override fun channelRead0(ctx: ChannelHandlerContext, msg: HAProxyMessage) { + // Store proxy IP in an attribute for later use after HTTP request is extracted. + // Using an attribute ensures the value is scoped to this channel. + (ctx as AttributeMap).attr(HAPROXY_SOURCE).set(msg.sourceAddress()) + } + } + ) + } ch.pipeline().addLast( "ssl", SniHandler(DomainWildcardMappingBuilder(sslContext).build()) @@ -206,6 +245,26 @@ class Netty( ch.pipeline().addLast("keepAlive", HttpServerKeepAliveHandler()) ch.pipeline().addLast("aggregator", HttpObjectAggregator(65536)) + if (serverSettings.enableProxyProtocol) { + ch.pipeline().addLast( + "setForwardHeader", + object : SimpleChannelInboundHandler(false) { + override fun channelRead0(ctx: ChannelHandlerContext, request: FullHttpRequest) { + // The geo location code already supports the `Forwarded header so setting + // it is the easiest way to introduce the original IP downstream. + if ((ctx as AttributeMap).hasAttr(HAPROXY_SOURCE)) { + val addr = (ctx as AttributeMap).attr(HAPROXY_SOURCE).get() + request.headers().set("Forwarded", addr) + } + // Since we're modifying the request without handling it, we must + // call retain to ensure it will still be available downstream. + ReferenceCountUtil.retain(request) + ctx.fireChannelRead(request) + } + } + ) + } + ch.pipeline().addLast("burstLimiter", burstLimiter) ch.pipeline().addLast( @@ -256,6 +315,7 @@ class Netty( companion object { private val LOGGER = LoggerFactory.getLogger(Netty::class.java) + private val HAPROXY_SOURCE = AttributeKey.newInstance("haproxy_source") } } diff --git a/src/main/kotlin/mdnet/settings/ClientSettings.kt b/src/main/kotlin/mdnet/settings/ClientSettings.kt index d3fbbab..3c37773 100644 --- a/src/main/kotlin/mdnet/settings/ClientSettings.kt +++ b/src/main/kotlin/mdnet/settings/ClientSettings.kt @@ -43,6 +43,7 @@ data class ServerSettings( val externalIp: String? = null, val port: Int = 443, val threads: Int = 0, + val enableProxyProtocol: Boolean = false, ) @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy::class)