Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package org.aaron.scala.netty.proxy
- import java.net.InetSocketAddress
- import java.util.concurrent.Executors
- import org.jboss.netty.bootstrap.ClientBootstrap
- import org.jboss.netty.bootstrap.ServerBootstrap
- import org.jboss.netty.buffer.ChannelBuffers
- import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
- import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory
- import org.jboss.netty.channel.socket.ClientSocketChannelFactory
- import org.jboss.netty.channel.Channel
- import org.jboss.netty.channel.ChannelFuture
- import org.jboss.netty.channel.ChannelFutureListener
- import org.jboss.netty.channel.ChannelHandlerContext
- import org.jboss.netty.channel.ChannelPipeline
- import org.jboss.netty.channel.ChannelPipelineFactory
- import org.jboss.netty.channel.ChannelStateEvent
- import org.jboss.netty.channel.Channels
- import org.jboss.netty.channel.ExceptionEvent
- import org.jboss.netty.channel.MessageEvent
- import org.jboss.netty.channel.SimpleChannelUpstreamHandler
- import org.jboss.netty.logging.InternalLoggerFactory
- import org.jboss.netty.logging.Slf4JLoggerFactory
- import org.jboss.netty.util.Version
- import com.weiglewilczek.slf4s.Logger
/**
 * A minimal TCP proxy built on Netty 3.
 *
 * Every connection accepted on the local address is paired with a freshly
 * opened connection to the remote address. Bytes are relayed in both
 * directions, and when either side closes, the other side is closed after its
 * already-queued writes have been flushed.
 *
 * @param localAddressPortString  "host:port" to listen on
 * @param remoteAddressPortString "host:port" each accepted connection is forwarded to
 */
class ScalaNettyProxy(
  val localAddressPortString: String,
  val remoteAddressPortString: String) {

  val log = Logger("ScalaNettyProxy")

  /**
   * Upstream handler for the proxy -> remote channel.
   *
   * It relays whatever the remote sends to the client, and closes the client
   * when the remote closes.
   *
   * @param clientChannel the accepted client channel this remote connection serves
   * @param trafficLock   lock shared with the paired ClientChannelHandler so that
   *                      readability toggles on the two channels do not race;
   *                      the default is a private, unshared lock
   */
  class RemoteChannelHandler(
    val clientChannel: Channel,
    trafficLock: AnyRef = new AnyRef)
    extends SimpleChannelUpstreamHandler {

    override def channelOpen(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      log.info("remote channel open " + e.getChannel)
    }

    override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
      log.warn("remote channel exception caught " + e.getChannel, e.getCause)
      e.getChannel.close
    }

    override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      log.info("remote channel closed " + e.getChannel)
      closeOnFlush(clientChannel)
    }

    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
      trafficLock.synchronized {
        clientChannel.write(e.getMessage)
        // Back-pressure: the client is not keeping up, so stop reading from the
        // remote until the client's write buffer drains (reading is resumed
        // from ClientChannelHandler.channelInterestChanged).
        if (!clientChannel.isWritable) {
          e.getChannel.setReadable(false)
        }
      }
    }

    override def channelInterestChanged(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      trafficLock.synchronized {
        val remote = e.getChannel
        // The remote can accept writes again, so resume reading from the client.
        // The isConnected check keeps the client paused while the connect
        // attempt started in ClientChannelHandler.channelOpen is still pending.
        if (remote.isConnected && remote.isWritable) {
          clientChannel.setReadable(true)
        }
      }
    }
  }

  /**
   * Upstream handler for an accepted client channel.
   *
   * On open it pauses reads from the client and starts connecting to the
   * remote address. It resumes the client when the connection succeeds and
   * closes it when the connection fails. After that it relays client data to
   * the remote.
   *
   * @param remoteAddress              address each client connection is forwarded to
   * @param clientSocketChannelFactory factory used to open the outbound remote channel
   */
  class ClientChannelHandler(
    val remoteAddress: InetSocketAddress,
    clientSocketChannelFactory: ClientSocketChannelFactory)
    extends SimpleChannelUpstreamHandler {

    // Shared with this connection's RemoteChannelHandler; serializes the
    // write-then-toggle-readability sequences of the two directions.
    private val trafficLock = new AnyRef

    // Outbound channel to the remote; null until channelOpen starts the connect.
    @volatile
    var remoteChannel: Channel = null

    override def channelOpen(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      val clientChannel = e.getChannel
      log.info("client channel open " + clientChannel)
      // Don't read any client data until there is somewhere to send it.
      clientChannel.setReadable(false)

      val clientBootstrap = new ClientBootstrap(clientSocketChannelFactory)
      clientBootstrap.setOption("connectTimeoutMillis", 1000)
      clientBootstrap.getPipeline.addLast("handler",
        new RemoteChannelHandler(clientChannel, trafficLock))

      val connectFuture = clientBootstrap.connect(remoteAddress)
      remoteChannel = connectFuture.getChannel
      connectFuture.addListener(new ChannelFutureListener {
        override def operationComplete(future: ChannelFuture): Unit = {
          if (future.isSuccess) {
            log.info("remote channel connect success "
              + remoteChannel)
            clientChannel.setReadable(true)
          } else {
            log.warn("remote channel connect failure "
              + remoteChannel, future.getCause)
            clientChannel.close
          }
        }
      })
    }

    override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
      log.warn("client channel exception caught " + e.getChannel,
        e.getCause)
      e.getChannel.close
    }

    override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      log.info("client channel closed " + e.getChannel)
      val remote = remoteChannel
      if (remote != null) {
        closeOnFlush(remote)
      }
    }

    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
      val remote = remoteChannel
      if (remote != null) {
        trafficLock.synchronized {
          remote.write(e.getMessage)
          // Back-pressure: the remote is not keeping up, so stop reading from
          // the client until the remote's write buffer drains (reading is
          // resumed from RemoteChannelHandler.channelInterestChanged).
          if (!remote.isWritable) {
            e.getChannel.setReadable(false)
          }
        }
      }
    }

    override def channelInterestChanged(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
      trafficLock.synchronized {
        // remoteChannel is still null when channelOpen's setReadable(false)
        // triggers this callback, hence the null check.
        val remote = remoteChannel
        // The client can accept writes again, so resume reading from the remote.
        if (remote != null && e.getChannel.isWritable) {
          remote.setReadable(true)
        }
      }
    }
  }

  /**
   * Builds a pipeline for each accepted connection.
   *
   * Each pipeline holds a single, new ClientChannelHandler that forwards to
   * remoteAddress.
   */
  class ProxyPipelineFactory(
    val remoteAddress: InetSocketAddress,
    val clientSocketChannelFactory: ClientSocketChannelFactory)
    extends ChannelPipelineFactory {

    override def getPipeline: ChannelPipeline =
      Channels.pipeline(new ClientChannelHandler(remoteAddress,
        clientSocketChannelFactory))
  }

  /**
   * Binds the listening socket on the local address and starts proxying
   * accepted connections to the remote address.
   *
   * @throws IllegalArgumentException if either address string is not of the
   *                                  form "host:port" with a port in 0-65535
   */
  def start: Unit = {
    InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory)
    val localAddress = requireAddressPort(localAddressPortString)
    val remoteAddress = requireAddressPort(remoteAddressPortString)

    // One cached pool serves as boss and worker executor for both factories.
    val executor = Executors.newCachedThreadPool
    val serverBootstrap = new ServerBootstrap(
      new NioServerSocketChannelFactory(executor, executor))
    val clientSocketChannelFactory =
      new NioClientSocketChannelFactory(executor, executor)

    serverBootstrap.setPipelineFactory(
      new ProxyPipelineFactory(remoteAddress, clientSocketChannelFactory))
    serverBootstrap.setOption("reuseAddress", true)
    val serverChannel = serverBootstrap.bind(localAddress)

    log.info("Netty version " + Version.ID)
    log.info("listening on " + serverChannel.getLocalAddress)
    log.info("remote address " + remoteAddress)
  }

  // "host:port". The greedy host group takes everything before the LAST ':'.
  // The port is limited to 1-5 digits so that toInt cannot overflow.
  private val addressPortRE = """^(.*):(\d{1,5})$""".r

  /**
   * Parses a "host:port" string.
   *
   * @return the resolved address, or None if there is no numeric ":port"
   *         suffix or the port is above 65535
   */
  private def parseAddressPortString(
    addressPortString: String): Option[InetSocketAddress] =
    addressPortString match {
      case addressPortRE(addressString, portString) if portString.toInt <= 65535 =>
        Some(new InetSocketAddress(addressString, portString.toInt))
      case _ => None
    }

  /**
   * Like parseAddressPortString, but throws IllegalArgumentException naming
   * the offending string instead of returning None.
   */
  private def requireAddressPort(addressPortString: String): InetSocketAddress =
    parseAddressPortString(addressPortString).getOrElse(
      throw new IllegalArgumentException(
        "expected <host>:<port> but got '" + addressPortString + "'"))

  /**
   * Closes the channel after writes already queued on it have been flushed.
   * A channel that is not connected is closed immediately.
   */
  private def closeOnFlush(channel: Channel): Unit = {
    if (channel.isConnected) {
      // Writes complete in order, so closing when this empty write completes
      // happens only after the data queued before it has been written.
      channel.write(ChannelBuffers.EMPTY_BUFFER).addListener(
        ChannelFutureListener.CLOSE)
    } else {
      channel.close
    }
  }
}
/** Command-line entry point that starts a ScalaNettyProxy. */
object ScalaNettyProxyMain {

  val log = Logger("ScalaNettyProxyMain")

  /**
   * Expects exactly two arguments, each a "host:port" string.
   *
   * The first is the local address to listen on and the second is the remote
   * address to forward connections to. With any other argument count, it logs
   * the usage and exits with status 1.
   */
  def main(args: Array[String]): Unit = {
    args match {
      case Array(localAddressPortString, remoteAddressPortString) =>
        new ScalaNettyProxy(localAddressPortString, remoteAddressPortString).start
      case _ =>
        log.error("Usage: <local host:port> <remote host:port>")
        // sys.exit replaces the deprecated (and later removed) Predef.exit.
        sys.exit(1)
    }
  }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement