Reinstall limiter, add buffering, fix TLS setup

This commit is contained in:
carbotaniuman 2020-06-09 14:21:18 -05:00
parent d1629e9f5c
commit 35ea86c6ac
5 changed files with 85 additions and 17 deletions

View file

@ -7,7 +7,7 @@ plugins {
}
group = "com.mangadex"
version = "1.0.0-rc8"
version = "1.0.0-rc9"
mainClassName = "mdnet.base.MangaDexClient"
repositories {

View file

@ -6,4 +6,7 @@ public class Constants {
public static final int CLIENT_BUILD = 2;
public static final String CLIENT_VERSION = "1.0";
public static final Duration MAX_AGE_CACHE = Duration.ofDays(14);
public static final int MAX_CONCURRENT_CONNECTIONS = 2;
public static final String OVERLOADED_MESSAGE = "This server is experiencing a surge in connections. Please try again later.";
}

View file

@ -23,6 +23,8 @@ import org.http4k.routing.routes
import org.http4k.server.Http4kServer
import org.http4k.server.asServer
import org.slf4j.LoggerFactory
import java.io.BufferedInputStream
import java.io.BufferedOutputStream
import java.io.InputStream
import java.security.MessageDigest
import java.time.ZoneOffset
@ -37,10 +39,15 @@ import javax.crypto.CipherOutputStream
import javax.crypto.spec.SecretKeySpec
private val LOGGER = LoggerFactory.getLogger("Application")
private val THREADS_TO_ALLOCATE = Runtime.getRuntime().availableProcessors() * 30 / 2 ;
fun getServer(cache: DiskLruCache, serverSettings: ServerSettings, clientSettings: ClientSettings, statistics: AtomicReference<Statistics>): Http4kServer {
val executor = Executors.newCachedThreadPool()
if (LOGGER.isInfoEnabled) {
LOGGER.info("Starting ApacheClient with {} threads", THREADS_TO_ALLOCATE)
}
val client = ApacheClient(responseBodyMode = BodyMode.Stream, client = HttpClients.custom()
.setDefaultRequestConfig(RequestConfig.custom()
.setCookieSpec(CookieSpecs.IGNORE_COOKIES)
@ -48,8 +55,8 @@ fun getServer(cache: DiskLruCache, serverSettings: ServerSettings, clientSetting
.setSocketTimeout(3000)
.setConnectionRequestTimeout(3000)
.build())
.setMaxConnTotal(75)
.setMaxConnPerRoute(75)
.setMaxConnTotal(THREADS_TO_ALLOCATE)
.setMaxConnPerRoute(THREADS_TO_ALLOCATE)
.build())
val app = { dataSaver: Boolean ->
@ -122,7 +129,7 @@ fun getServer(cache: DiskLruCache, serverSettings: ServerSettings, clientSetting
}
respondWithImage(
CipherInputStream(snapshot.getInputStream(0), getRc4(rc4Bytes)),
CipherInputStream(BufferedInputStream(snapshot.getInputStream(0)), getRc4(rc4Bytes)),
snapshot.getLength(0).toString(), snapshot.getString(1), snapshot.getString(2)
)
}
@ -161,19 +168,19 @@ fun getServer(cache: DiskLruCache, serverSettings: ServerSettings, clientSetting
val tee = CachingInputStream(
mdResponse.body.stream,
executor, CipherOutputStream(editor.newOutputStream(0), getRc4(rc4Bytes))
executor, CipherOutputStream(BufferedOutputStream(editor.newOutputStream(0)), getRc4(rc4Bytes))
) {
// Note: if neither of the options get called/are in the log
// check that tee gets closed and for exceptions in this lambda
if (editor.getLength(0) == contentLength.toLong()) {
if (LOGGER.isInfoEnabled) {
LOGGER.info("Cache download $sanitizedUri committed")
LOGGER.info("Cache download for $sanitizedUri committed")
}
editor.commit()
} else {
if (LOGGER.isInfoEnabled) {
LOGGER.info("Cache download $sanitizedUri aborted")
LOGGER.info("Cache download for $sanitizedUri aborted")
}
editor.abort()

View file

@ -36,15 +36,6 @@ private const val PKCS_1_PEM_FOOTER = "-----END RSA PRIVATE KEY-----"
private const val PKCS_8_PEM_HEADER = "-----BEGIN PRIVATE KEY-----"
private const val PKCS_8_PEM_FOOTER = "-----END PRIVATE KEY-----"
fun getX509Cert(certificate: String): X509Certificate {
val targetStream: InputStream = ByteArrayInputStream(certificate.toByteArray())
return CertificateFactory.getInstance("X509").generateCertificate(targetStream) as X509Certificate
}
fun getPrivateKey(privateKey: String): PrivateKey {
return loadKey(privateKey)!!
}
fun loadKey(keyDataString: String): PrivateKey? {
if (keyDataString.contains(PKCS_1_PEM_HEADER)) {
// OpenSSL / PKCS#1 Base64 PEM encoded file

View file

@ -1,8 +1,10 @@
package mdnet.base
import io.netty.bootstrap.ServerBootstrap
import io.netty.buffer.Unpooled
import io.netty.channel.ChannelFactory
import io.netty.channel.ChannelFuture
import io.netty.channel.ChannelHandler
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInboundHandlerAdapter
import io.netty.channel.ChannelInitializer
@ -12,10 +14,16 @@ import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.DecoderException
import io.netty.handler.codec.http.DefaultFullHttpResponse
import io.netty.handler.codec.http.HttpHeaderNames
import io.netty.handler.codec.http.HttpObjectAggregator
import io.netty.handler.codec.http.HttpResponseStatus
import io.netty.handler.codec.http.HttpServerCodec
import io.netty.handler.codec.http.HttpUtil
import io.netty.handler.codec.http.HttpVersion
import io.netty.handler.ssl.OptionalSslHandler
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.SslHandler
import io.netty.handler.stream.ChunkedWriteHandler
import io.netty.handler.traffic.GlobalTrafficShapingHandler
import io.netty.handler.traffic.TrafficCounter
@ -24,14 +32,58 @@ import org.http4k.server.Http4kChannelHandler
import org.http4k.server.Http4kServer
import org.http4k.server.ServerConfig
import org.slf4j.LoggerFactory
import java.io.ByteArrayInputStream
import java.io.IOException
import java.io.InputStream
import java.net.InetSocketAddress
import java.nio.charset.StandardCharsets
import java.security.PrivateKey
import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference
import javax.net.ssl.SSLException
private val LOGGER = LoggerFactory.getLogger("Application")
@ChannelHandler.Sharable
class ConnectionCounter : ChannelInboundHandlerAdapter() {
private val connections = AtomicInteger()
override fun channelActive(ctx: ChannelHandlerContext) {
val sslHandler = ctx.pipeline()[SslHandler::class.java]
if (sslHandler != null) {
sslHandler.handshakeFuture().addListener {
handleConnection(ctx)
}
} else {
handleConnection(ctx)
}
}
private fun handleConnection(ctx: ChannelHandlerContext) {
if (connections.incrementAndGet() <= Constants.MAX_CONCURRENT_CONNECTIONS) {
super.channelActive(ctx)
} else {
val response = Unpooled.copiedBuffer(Constants.OVERLOADED_MESSAGE, StandardCharsets.UTF_8)
val res =
DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SERVICE_UNAVAILABLE, response)
res.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/html; charset=UTF-8")
HttpUtil.setContentLength(res, response.readableBytes().toLong())
ctx.writeAndFlush(res)
ctx.close()
}
}
override fun channelInactive(ctx: ChannelHandlerContext) {
super.channelInactive(ctx)
connections.decrementAndGet()
}
}
class Netty(private val tls: ServerSettings.TlsCert, private val clientSettings: ClientSettings, private val stats: AtomicReference<Statistics>) : ServerConfig {
override fun toServer(httpHandler: HttpHandler): Http4kServer = object : Http4kServer {
private val masterGroup = NioEventLoopGroup()
@ -46,9 +98,14 @@ class Netty(private val tls: ServerSettings.TlsCert, private val clientSettings:
counter.resetCumulativeTime()
}
}
private val limiter = ConnectionCounter();
override fun start(): Http4kServer = apply {
val sslContext = SslContextBuilder.forServer(getPrivateKey(tls.privateKey), getX509Cert(tls.certificate)).build()
val (mainCert, chainCert) = getX509Certs(tls.certificate);
val sslContext = SslContextBuilder
.forServer(getPrivateKey(tls.privateKey), mainCert, chainCert)
.protocols("TLSv1.3", "TLSv.1.2", "TLSv.1.1", "TLSv.1.0")
.build()
val bootstrap = ServerBootstrap()
bootstrap.group(masterGroup, workerGroup)
@ -57,6 +114,7 @@ class Netty(private val tls: ServerSettings.TlsCert, private val clientSettings:
public override fun initChannel(ch: SocketChannel) {
ch.pipeline().addLast("ssl", OptionalSslHandler(sslContext))
ch.pipeline().addLast("limiter", limiter)
ch.pipeline().addLast("codec", HttpServerCodec())
ch.pipeline().addLast("aggregator", HttpObjectAggregator(65536))
@ -98,3 +156,12 @@ class Netty(private val tls: ServerSettings.TlsCert, private val clientSettings:
override fun port(): Int = if (clientSettings.clientPort > 0) clientSettings.clientPort else address.port
}
}
fun getX509Certs(certificates: String): Pair<X509Certificate, X509Certificate> {
val targetStream: InputStream = ByteArrayInputStream(certificates.toByteArray())
return (CertificateFactory.getInstance("X509").generateCertificate(targetStream) as X509Certificate) to (CertificateFactory.getInstance("X509").generateCertificate(targetStream) as X509Certificate)
}
fun getPrivateKey(privateKey: String): PrivateKey {
return loadKey(privateKey)!!
}