Advertisement
Guest User

Untitled

a guest
Oct 23rd, 2011
132
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Scala 6.21 KB | None | 0 0
  1. package org.aaron.scala.netty.proxy
  2.  
  3. import java.net.InetSocketAddress
  4. import java.util.concurrent.Executors
  5.  
  6. import org.jboss.netty.bootstrap.ClientBootstrap
  7. import org.jboss.netty.bootstrap.ServerBootstrap
  8. import org.jboss.netty.buffer.ChannelBuffers
  9. import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
  10. import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory
  11. import org.jboss.netty.channel.socket.ClientSocketChannelFactory
  12. import org.jboss.netty.channel.Channel
  13. import org.jboss.netty.channel.ChannelFuture
  14. import org.jboss.netty.channel.ChannelFutureListener
  15. import org.jboss.netty.channel.ChannelHandlerContext
  16. import org.jboss.netty.channel.ChannelPipeline
  17. import org.jboss.netty.channel.ChannelPipelineFactory
  18. import org.jboss.netty.channel.ChannelStateEvent
  19. import org.jboss.netty.channel.Channels
  20. import org.jboss.netty.channel.ExceptionEvent
  21. import org.jboss.netty.channel.MessageEvent
  22. import org.jboss.netty.channel.SimpleChannelUpstreamHandler
  23. import org.jboss.netty.logging.InternalLoggerFactory
  24. import org.jboss.netty.logging.Slf4JLoggerFactory
  25. import org.jboss.netty.util.Version
  26.  
  27. import com.weiglewilczek.slf4s.Logger
  28.  
  /**
   * A TCP proxy built on Netty 3: every connection accepted on the local
   * address gets its own outbound connection to the remote address, and bytes
   * are relayed in both directions until either side closes.
   *
   * Flow control: each side stops reading while the opposite channel is not
   * writable and resumes when that channel signals (via channelInterestChanged)
   * that it can accept more data, so a slow peer cannot make the proxy buffer
   * an unbounded amount of data in memory.
   *
   * @param localAddressPortString  address to listen on, as "host:port"
   * @param remoteAddressPortString address to forward connections to, as "host:port"
   */
  class ScalaNettyProxy(
    val localAddressPortString: String,
    val remoteAddressPortString: String) {

    val log = Logger("ScalaNettyProxy")

    /**
     * Handler for the outbound (proxy -> remote server) channel of one proxied
     * connection; relays everything read from the remote server to `clientChannel`.
     *
     * @param clientChannel the accepted client channel this remote connection serves
     * @param trafficLock   lock shared with the matching ClientChannelHandler so the
     *                      write / readability decisions on the two channels cannot
     *                      interleave; defaults to a private lock
     */
    class RemoteChannelHandler(
      val clientChannel: Channel,
      trafficLock: AnyRef = new Object)
      extends SimpleChannelUpstreamHandler {

      override def channelOpen(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
        log.info("remote channel open " + e.getChannel)
      }

      override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
        log.warn("remote channel exception caught " + e.getChannel, e.getCause)
        e.getChannel.close()
      }

      override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
        log.info("remote channel closed " + e.getChannel)
        closeOnFlush(clientChannel)
      }

      override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
        trafficLock.synchronized {
          clientChannel.write(e.getMessage)
          // Back-pressure: stop reading from the remote server while the client
          // cannot keep up; ClientChannelHandler.channelInterestChanged resumes it.
          if (!clientChannel.isWritable) {
            e.getChannel.setReadable(false)
          }
        }
      }

      override def channelInterestChanged(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
        trafficLock.synchronized {
          // The remote channel can accept data again: resume reading from the
          // client. Skipped while not connected so the client is never read
          // before the connect listener in ClientChannelHandler enables it.
          val remote = e.getChannel
          if (remote.isConnected && remote.isWritable) {
            clientChannel.setReadable(true)
          }
        }
      }

    }

    /**
     * Handler for an accepted client channel. On open it connects to
     * `remoteAddress`, keeping the client unreadable until that connect
     * succeeds, and afterwards relays client data to the remote channel.
     *
     * @param remoteAddress              where each client connection is forwarded
     * @param clientSocketChannelFactory factory used for the outbound connection
     */
    class ClientChannelHandler(
      val remoteAddress: InetSocketAddress,
      clientSocketChannelFactory: ClientSocketChannelFactory)
      extends SimpleChannelUpstreamHandler {

      // Shared with this connection's RemoteChannelHandler to serialize the
      // flow-control decisions made on both channels.
      private val trafficLock = new Object

      // Outbound channel for this connection; null until channelOpen runs.
      @volatile
      var remoteChannel: Channel = null

      override def channelOpen(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
        val clientChannel = e.getChannel
        log.info("client channel open " + clientChannel)
        // Don't read client data until there is somewhere to send it.
        clientChannel.setReadable(false)

        val clientBootstrap = new ClientBootstrap(clientSocketChannelFactory)
        clientBootstrap.setOption("connectTimeoutMillis", 1000)
        clientBootstrap.getPipeline.addLast("handler",
          new RemoteChannelHandler(clientChannel, trafficLock))

        val connectFuture = clientBootstrap.connect(remoteAddress)
        remoteChannel = connectFuture.getChannel
        connectFuture.addListener(new ChannelFutureListener {
          override def operationComplete(future: ChannelFuture): Unit = {
            if (future.isSuccess) {
              log.info("remote channel connect success "
                + remoteChannel)
              clientChannel.setReadable(true)
            } else {
              // Include the cause so connect timeouts / refusals are diagnosable.
              log.warn("remote channel connect failure "
                + remoteChannel, future.getCause)
              clientChannel.close()
            }
          }
        })
      }

      override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
        log.warn("client channel exception caught " + e.getChannel,
          e.getCause)
        e.getChannel.close()
      }

      override def channelClosed(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
        log.info("client channel closed " + e.getChannel)
        val remoteChannelCopy = remoteChannel
        if (remoteChannelCopy != null) {
          closeOnFlush(remoteChannelCopy)
        }
      }

      override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = {
        val remoteChannelCopy = remoteChannel
        if (remoteChannelCopy != null) {
          trafficLock.synchronized {
            remoteChannelCopy.write(e.getMessage)
            // Back-pressure: stop reading from the client while the remote server
            // cannot keep up; RemoteChannelHandler.channelInterestChanged resumes it.
            if (!remoteChannelCopy.isWritable) {
              e.getChannel.setReadable(false)
            }
          }
        }
      }

      override def channelInterestChanged(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit = {
        val remoteChannelCopy = remoteChannel
        // remoteChannel is still null if this fires from the setReadable(false)
        // call at the start of channelOpen.
        if (remoteChannelCopy != null) {
          trafficLock.synchronized {
            // The client channel can accept data again: resume reading from the remote.
            if (e.getChannel.isWritable) {
              remoteChannelCopy.setReadable(true)
            }
          }
        }
      }

    }

    /** Builds a fresh pipeline (one ClientChannelHandler) for each accepted client. */
    class ProxyPipelineFactory(
      val remoteAddress: InetSocketAddress,
      val clientSocketChannelFactory: ClientSocketChannelFactory)
      extends ChannelPipelineFactory {

      override def getPipeline: ChannelPipeline =
        Channels.pipeline(new ClientChannelHandler(remoteAddress,
          clientSocketChannelFactory))

    }

    /**
     * Binds the listening socket and starts proxying.
     *
     * @throws IllegalArgumentException if either address string is not a valid
     *                                  "host:port" (port in 0..65535)
     */
    def start: Unit = {
      InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory)

      val localAddress = requireAddress("local", localAddressPortString)
      val remoteAddress = requireAddress("remote", remoteAddressPortString)

      val executor = Executors.newCachedThreadPool

      val serverBootstrap = new ServerBootstrap(
        new NioServerSocketChannelFactory(
          executor, executor))

      val clientSocketChannelFactory = new NioClientSocketChannelFactory(
        executor, executor)

      serverBootstrap.setPipelineFactory(new ProxyPipelineFactory(
        remoteAddress, clientSocketChannelFactory))

      serverBootstrap.setOption("reuseAddress", true)

      val serverChannel = serverBootstrap.bind(localAddress)

      log.info("Netty version " + Version.ID)
      log.info("listening on " + serverChannel.getLocalAddress)
      log.info("remote address " + remoteAddress)
    }

    /**
     * Parses `addressPortString`, failing with a descriptive message instead of
     * the bare NoSuchElementException that `Option.get` would raise.
     */
    private def requireAddress(
      description: String,
      addressPortString: String): InetSocketAddress =
      parseAddressPortString(addressPortString).getOrElse(
        throw new IllegalArgumentException("invalid " + description
          + " address '" + addressPortString + "', expected <host>:<port>"))

    /**
     * Splits "host:port" at the last ':' (the host group is greedy) and returns
     * the socket address, or None when the port is missing, non-numeric, or
     * outside 0..65535.
     */
    private def parseAddressPortString(
      addressPortString: String): Option[InetSocketAddress] = {
      // At most five digits, so toInt below cannot overflow.
      val addressPortRE = """^(.*):(\d{1,5})$""".r
      addressPortString match {
        case addressPortRE(addressString, portString) if portString.toInt <= 65535 =>
          Some(new InetSocketAddress(addressString, portString.toInt))

        case _ => None
      }
    }

    /**
     * Closes `channel` once everything already queued on it has been written;
     * closes immediately if it is not connected.
     */
    private def closeOnFlush(channel: Channel): Unit = {
      if (channel.isConnected) {
        channel.write(ChannelBuffers.EMPTY_BUFFER).addListener(
          ChannelFutureListener.CLOSE)
      } else {
        channel.close()
      }
    }

  }
  178.  
  /** Command-line entry point: `ScalaNettyProxyMain <local host:port> <remote host:port>`. */
  object ScalaNettyProxyMain {

    val log = Logger("ScalaNettyProxyMain")

    /**
     * Starts a proxy from exactly two arguments (local and remote "host:port"
     * strings). With any other argument count, logs usage and exits with status 1.
     */
    def main(args: Array[String]): Unit = {
      args match {
        case Array(localAddressString, remoteAddressString) =>
          new ScalaNettyProxy(localAddressString, remoteAddressString).start

        case _ =>
          log.error("Usage: <local address> <remote address>")
          // Predef.exit is deprecated and was removed in Scala 2.11;
          // System.exit behaves the same on every Scala version.
          System.exit(1)
      }
    }

  }
  195.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement