1
0
Fork 1
mirror of https://gitlab.com/mangadex-pub/mangadex_at_home.git synced 2024-01-19 02:48:37 +00:00

Add support for proxy protocol IP forwarding

Allows collecting GeoIp metrics even when md@h is behind non-terminating reverse proxy.

- Ignores requests which don't use proxy protocol
- Copies forwarded IP into `Forwarded` to be used later by `GeoIpMetricsFilter`
This commit is contained in:
Erek Speed 2021-06-03 15:12:59 +09:00
parent 52f23b0f53
commit ecab240817
No known key found for this signature in database
GPG key ID: 12933BB2EAD9F705
2 changed files with 49 additions and 0 deletions

View file

@ -45,6 +45,7 @@ dependencies {
implementation group: "org.http4k", name: "http4k-client-okhttp"
implementation group: "org.http4k", name: "http4k-metrics-micrometer"
implementation group: "org.http4k", name: "http4k-server-netty"
implementation group: "io.netty", name: "netty-codec-haproxy"
implementation group: "io.netty", name: "netty-transport-native-epoll", classifier: "linux-x86_64"
implementation group: "io.netty.incubator", name: "netty-incubator-transport-native-io_uring", version: "0.0.3.Final", classifier: "linux-x86_64"
testImplementation group: "org.http4k", name: "http4k-testing-kotest"

View file

@ -19,6 +19,7 @@ along with this MangaDex@Home. If not, see <http://www.gnu.org/licenses/>.
package mdnet.netty
import io.netty.bootstrap.ServerBootstrap
import io.netty.buffer.ByteBuf
import io.netty.channel.*
import io.netty.channel.epoll.Epoll
import io.netty.channel.epoll.EpollEventLoopGroup
@ -27,6 +28,12 @@ import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.DecoderException
import io.netty.handler.codec.ProtocolDetectionResult
import io.netty.handler.codec.ProtocolDetectionState
import io.netty.handler.codec.haproxy.HAProxyMessage
import io.netty.handler.codec.haproxy.HAProxyMessageDecoder
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion
import io.netty.handler.codec.http.FullHttpRequest
import io.netty.handler.codec.http.HttpObjectAggregator
import io.netty.handler.codec.http.HttpServerCodec
import io.netty.handler.codec.http.HttpServerKeepAliveHandler
@ -43,7 +50,10 @@ import io.netty.handler.traffic.TrafficCounter
import io.netty.incubator.channel.uring.IOUring
import io.netty.incubator.channel.uring.IOUringEventLoopGroup
import io.netty.incubator.channel.uring.IOUringServerSocketChannel
import io.netty.util.AttributeKey
import io.netty.util.AttributeMap
import io.netty.util.DomainWildcardMappingBuilder
import io.netty.util.ReferenceCountUtil
import io.netty.util.concurrent.DefaultEventExecutorGroup
import io.netty.util.internal.SystemPropertyUtil
import mdnet.Constants
@ -172,7 +182,31 @@ class Netty(
bootstrap.group(transport.bossGroup, transport.workerGroup)
.channelFactory(transport.factory)
.childHandler(object : ChannelInitializer<SocketChannel>() {
val HAPROXY_SOURCE = AttributeKey.newInstance<String>("haproxy_source")
public override fun initChannel(ch: SocketChannel) {
ch.pipeline().addLast(
"proxyProtocol",
object : ChannelInboundHandlerAdapter() {
override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
if (msg is ByteBuf) {
val result: ProtocolDetectionResult<HAProxyProtocolVersion> = HAProxyMessageDecoder.detectProtocol(msg)
if (result.state() == ProtocolDetectionState.DETECTED) {
ctx.pipeline().addAfter("proxyProtocol", null, HAProxyMessageDecoder())
ctx.pipeline().remove(this)
}
}
super.channelRead(ctx, msg)
}
}
)
ch.pipeline().addLast(
"saveOriginalIp",
object : SimpleChannelInboundHandler<HAProxyMessage>() {
override fun channelRead0(ctx: ChannelHandlerContext, msg: HAProxyMessage) {
(ctx as AttributeMap).attr(HAPROXY_SOURCE).set(msg.sourceAddress())
}
}
)
ch.pipeline().addLast(
"ssl",
SniHandler(DomainWildcardMappingBuilder(sslContext).build())
@ -205,6 +239,20 @@ class Netty(
ch.pipeline().addLast("keepAlive", HttpServerKeepAliveHandler())
ch.pipeline().addLast("aggregator", HttpObjectAggregator(65536))
ch.pipeline().addLast(
"setForwardHeader",
object : SimpleChannelInboundHandler<FullHttpRequest>(false) {
override fun channelRead0(ctx: ChannelHandlerContext, request: FullHttpRequest) {
if ((ctx as AttributeMap).hasAttr(HAPROXY_SOURCE)) {
val addr = (ctx as AttributeMap).attr(HAPROXY_SOURCE).get()
request.headers().set("Forwarded", addr)
}
ReferenceCountUtil.retain(request)
ctx.fireChannelRead(request)
}
}
)
ch.pipeline().addLast("burstLimiter", burstLimiter)
ch.pipeline().addLast(