Advertisement
Guest User

Goroutines in Kotlin

a guest
Sep 23rd, 2012
346
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 10.16 KB | None | 0 0
  1. import java.net.*
  2. import java.util.*
  3. import java.util.concurrent.*
  4. import java.util.concurrent.atomic.AtomicLong
  5. import kotlin.test.*
  6. import kotlin.util.measureTimeMillis
  7.  
  8. /**
  9.  * Google Go's goroutines implemented in Kotlin.
  10.  * @author Ioannis Tsakpinis
  11.  */
  12.  
  13. public trait Channel<T>
  14.  
  15. public trait ChannelIn<T>: Channel<T>, Iterable<T> {
  16.     public val take: T
  17.     public val poll: T?
  18.     public fun poll(timeout: Long, unit: TimeUnit): T?
  19.  
  20.     public val peek: T?
  21. }
  22.  
  23. public trait ChannelOut<T>: Channel<T> {
  24.     public fun offer(value: T, timeout: Long, unit: TimeUnit): Boolean
  25.     public fun offer(value: T): Boolean
  26.     public fun send(value: T): Unit
  27. }
  28.  
  29. /*
  30.  * Unbuffered when bufferSize == 0, buffered when bufferSize > 0.
  31.  */
  32. public class ChannelInOut<T> (bufferSize: Int, fair: Boolean): ChannelIn<T>, ChannelOut<T> {
  33.  
  34.     private val buffer: BlockingQueue<T> = if ( 0 < bufferSize )
  35.         ArrayBlockingQueue<T>(bufferSize, fair)
  36.     else
  37.         SynchronousQueue<T>(fair)
  38.  
  39.     public val input: ChannelIn<T> get() = object : ChannelIn<T> by this {
  40.     }
  41.     public val output: ChannelOut<T> get() = object : ChannelOut<T> by this {
  42.     }
  43.  
  44.     // Blocking Queue interface
  45.  
  46.     public override val take: T get() = buffer.take()
  47.     public override val poll: T? get() = buffer.poll()
  48.     public override fun poll(timeout: Long, unit: TimeUnit): T? = buffer.poll(timeout, unit)
  49.  
  50.     public override val peek: T? get() = buffer.peek()
  51.  
  52.     public override fun offer(value: T, timeout: Long, unit: TimeUnit): Boolean = buffer.offer(value, timeout, unit)
  53.     public override fun offer(value: T): Boolean = buffer.offer(value)
  54.     public override fun send(value: T): Unit = buffer.put(value)
  55.  
  56.     // Iterable interface
  57.  
  58.     // always empty for synchronous queue
  59.     public override fun iterator(): Iterator<T> = object : Iterator<T> {
  60.         public override fun next(): T = take
  61.         public override fun hasNext(): Boolean = 0 < buffer.size()
  62.     }
  63.  
  64. }
  65.  
  66. public fun <T> channel(bufferSize: Int = 0, fair: Boolean = false): ChannelInOut<T> = ChannelInOut<T>(bufferSize, fair)
  67.  
  68. val GOROUTINE_THREADS = Runtime.getRuntime().availableProcessors() * 4
  69. val THREAD_POOL = Executors.newFixedThreadPool(GOROUTINE_THREADS)
  70.  
  71. public fun go(exec: () -> Unit): Future<*> {
  72.     return THREAD_POOL.submit(object: Runnable {
  73.         public override fun run(): Unit = exec()
  74.     })
  75. }
  76.  
  77. public fun select(vararg cases: Case<*>): Unit = select(default = null, cases = *cases)
  78. public fun select(default: (()->Unit)?, vararg cases: Case<*>) {
  79.     if ( cases.size == 0 ) {
  80.         if ( default == null )
  81.             throw IllegalArgumentException("No case or default specified.")
  82.  
  83.         default()
  84.         return
  85.     }
  86.  
  87.     val pollOrder = IntArray(cases.size)
  88.     for ( i in cases.indices ) pollOrder[i] = i
  89.  
  90.     // Randomize poll order
  91.     val rand = ThreadLocalRandom.current()
  92.     for ( i in 1..cases.lastIndex ) {
  93.         val o = pollOrder[i]
  94.         val j = rand.nextInt(i + 1)
  95.         pollOrder[i] = pollOrder[j]
  96.         pollOrder[j] = o
  97.     }
  98.  
  99.     while ( true ) {
  100.         for ( i in pollOrder ) {
  101.             if ( cases[i].tryAction() )
  102.                 return
  103.         }
  104.  
  105.         if ( default != null ) {
  106.             default()
  107.             return
  108.         }
  109.  
  110.         Thread.yield()
  111.     }
  112. }
  113.  
  114. public trait Case<T> {
  115.     fun tryAction(): Boolean
  116. }
  117.  
  118. public class CaseOut<T>(val channel: ChannelIn<T>, val action: (T)->Unit): Case<T> {
  119.     override fun tryAction(): Boolean {
  120.         val value = channel.poll
  121.         if ( value == null )
  122.             return false
  123.  
  124.         action(value)
  125.         return true
  126.     }
  127. }
  128.  
  129. public class CaseIn<T>(val channel: ChannelOut<T>, val value: T, val action: (()->Unit)?): Case<T> {
  130.     override fun tryAction(): Boolean {
  131.         val result = channel.offer(value)
  132.         if ( result && action != null )
  133.             action()
  134.  
  135.         return result
  136.     }
  137. }
  138.  
  139. public fun <T> ChannelIn<T>.case(action: (T)->Unit): CaseOut<T> = CaseOut<T>(this, action)
  140. public fun <T> ChannelOut<T>.case(value: T, action: (()->Unit)? = null): CaseIn<T> = CaseIn<T>(this, value, action)
  141.  
  142. /*
  143.  * ----------------------------------------------------------------------
  144.  * ------------------------------[ EXAMPLES ]----------------------------
  145.  * ----------------------------------------------------------------------
  146.  */
  147.  
  148. fun example1Sort() {
  149.     val SIZE = 1024 * 1024
  150.     val list = ArrayList<Int>(SIZE)
  151.  
  152.     // generate random list
  153.     val rand = ThreadLocalRandom.current()
  154.     for ( i in SIZE.indices )
  155.         list add rand.nextInt()
  156.  
  157.     val total = measureTimeMillis {
  158.         val c = channel<Int>()
  159.  
  160.         // Spawn goroutine to sort
  161.         go {
  162.             val time = measureTimeMillis {
  163.                 Collections.sort(list)
  164.             }
  165.             println(" sort took: $time ms")
  166.             c send 1
  167.         }
  168.  
  169.         // Do something else while sorting
  170.         Thread.sleep(300)
  171.  
  172.         // Wait for sort to complete
  173.         c.take
  174.     }
  175.  
  176.     println("total took: $total ms")
  177. }
  178.  
  179. fun example2LimitThroughput() {
  180.     class Request(val id: Int, val done: ChannelOut<Int>) {
  181.         fun done(): Unit = done send id
  182.     }
  183.  
  184.     val sem = channel<Int>(5) // Throttle request handling to 5 at the same time
  185.  
  186.     fun handle(val r: Request) {
  187.         println("Pre-handle: ${r.id}")
  188.         sem send 1
  189.         println("\tHandling: ${r.id}")
  190.         Thread.sleep(200)
  191.         r.done()
  192.         sem.take
  193.         println("Post-handle: ${r.id}")
  194.     }
  195.  
  196.     val quitServer = channel<Int>() // A msg will be send here when it's time to stop the server
  197.  
  198.     fun serve(queue: ChannelIn<Request>) {
  199.         var run = true
  200.         while ( run ) {
  201.             // Select between queue and quit channels
  202.             select (
  203.                 queue case {
  204.                     // Spawn handle goroutine
  205.                     go { handle(it) }
  206.                 },
  207.                 quitServer case { run = false }
  208.             )
  209.         }
  210.     }
  211.  
  212.     val queue = channel<Request>()
  213.  
  214.     // Spawn server goroutine
  215.     go {
  216.         serve(queue)
  217.     }
  218.  
  219.     val done = channel<Int>()
  220.  
  221.     val time = measureTimeMillis {
  222.         // Send 100 simultaneous requests
  223.         var id = 0
  224.         while ( id < 100 ) {
  225.             queue send Request(id++, done)
  226.         }
  227.  
  228.         // Wait for request completion
  229.         for ( i in 1..100 )
  230.             done.take
  231.  
  232.         // Stop the server
  233.         quitServer send 1
  234.     }
  235.  
  236.     // 100 requests / 5 = 20 * 200 ms = 4000ms
  237.     assertTrue(4000 <= time)
  238.     println("Took: $time ms")
  239. }
  240.  
  241. fun example3Compute() {
  242.     // Serial
  243.     fun vectorOp(vector: FloatArray, from: Int, to: Int, op: (Float)->Float) {
  244.         for ( i in from..to - 1 )
  245.             vector[i] = op(vector[i])
  246.     }
  247.  
  248.     // Channel-aware
  249.     fun vectorOp(vector: FloatArray, from: Int, to: Int, c: ChannelOut<Int>, op: (Float)->Float) {
  250.         vectorOp(vector, from, to, op)
  251.         c send 1
  252.     }
  253.  
  254.     val vector = FloatArray(1024 * 1024 * 2)
  255.     for ( i in vector.indices )
  256.         vector[i] = i.toFloat()
  257.  
  258.     // Create a copy to avoid speedup due to caching
  259.     val vector2 = vector.copyOf()
  260.  
  261.     val op = {(it: Float)-> Math.sin(it.toDouble()).toFloat() }
  262.  
  263.     val single = measureTimeMillis {
  264.         vectorOp(vector, 0, vector.size, op)
  265.     }
  266.  
  267.     val multi = measureTimeMillis {
  268.         val BATCHES = GOROUTINE_THREADS * 2
  269.         val c = channel<Int>(BATCHES)
  270.         val batchSize = vector2.size / BATCHES
  271.         for ( i in 1..BATCHES ) {
  272.             // Spawn a goroutine for each batch
  273.             go {
  274.                 vectorOp(vector2, (i - 1) * batchSize, i * batchSize, c, op)
  275.             }
  276.         }
  277.  
  278.         // Wait for all batches to complete
  279.         for ( i in 1..BATCHES )
  280.             c.take
  281.     }
  282.  
  283.     println("Single: $single ms")
  284.     println(" Multi: $multi ms")
  285.     println("  Practical speedup: ${single.toFloat() / multi.toFloat()}")
  286.     println("Theoritical speedup: ${Runtime.getRuntime().availableProcessors().toFloat()}")
  287. }
  288.  
  289. fun example4LeakyBuffer() {
  290.     val FREE_LIST_SIZE = 80
  291.     val freeList = channel<Array<Int>>(FREE_LIST_SIZE)
  292.     for ( i in 1..FREE_LIST_SIZE )
  293.         freeList send Array<Int>(128) { i }
  294.  
  295.     val serverChan = channel<Array<Int>>(100)
  296.     val quitServer = channel<Int>() // A msg will be send here when it's time to stop the server
  297.  
  298.     val processed = AtomicLong()
  299.  
  300.     // server
  301.     go {
  302.         var run = true
  303.         while ( run ) {
  304.             // Select between queue and quit channels
  305.             select (
  306.                 serverChan case { buffer ->
  307.                     // Process
  308.                     Thread.sleep(1)
  309.  
  310.                     //freeList offer it
  311.                     select ({
  312.                                 // Do nothing, buffer is released to GC
  313.                             },
  314.                             freeList case buffer
  315.                     )
  316.  
  317.                     processed.incrementAndGet()
  318.                 },
  319.                 quitServer case { run = false }
  320.             )
  321.         }
  322.     }
  323.  
  324.     // client
  325.     var newCount = 0
  326.     var reuseCount = 0
  327.     for ( i in 1..1000 ) {
  328.         var buffer: Array<Int>? = null
  329.  
  330.         select (
  331.             {
  332.                 // Allocate new buffer
  333.                 newCount++
  334.                 buffer = Array<Int>(32) { i }
  335.             },
  336.             freeList case {
  337.                 // Reuse buffer
  338.                 reuseCount++
  339.                 buffer = it
  340.             }
  341.         )
  342.  
  343.         // ...
  344.  
  345.         serverChan send buffer!!
  346.     }
  347.  
  348.     // Wait for processing to complete
  349.     while ( processed.intValue() < 1000 ) {
  350.     }
  351.  
  352.     quitServer send 1
  353.  
  354.     println("REUSED: $reuseCount")
  355.     println("   NEW: $newCount")
  356.     println(" RATIO: ${reuseCount.toFloat() / newCount.toFloat()}")
  357. }
  358.  
  359. fun example5URLs() {
  360.     fun count(name: String, url: URL): Pair<String, Long> {
  361.         var contentSize = 0
  362.         val time = measureTimeMillis {
  363.             val connection = url.openConnection()
  364.             connection.connect()
  365.  
  366.             val reader = connection.getInputStream().reader()
  367.             val html = connection.getInputStream().reader().readText()
  368.             contentSize = html.size
  369.             reader.close()
  370.         }
  371.  
  372.         return Pair("$name: $contentSize characters - $time ms", time)
  373.     }
  374.  
  375.     fun doSerial(languages: Map<String, String>) {
  376.         var serialTime = 0.toLong()
  377.         for ( (name, url) in languages ) {
  378.             val (msg, time) = count(name, URL(url))
  379.             println(msg)
  380.             serialTime += time
  381.         }
  382.  
  383.         println("TOTAL SERIAL: $serialTime")
  384.  
  385.     }
  386.     fun doParallel(languages: Map<String, String>) {
  387.         val c = channel<Pair<String, Long>>()
  388.         for ( (name, url) in languages )
  389.             go {
  390.                 c send count(name, URL(url))
  391.             }
  392.  
  393.         var serialTime = AtomicLong()
  394.  
  395.         languages.size.indices.forEach {
  396.             val (msg, time) = c.take
  397.             println(msg)
  398.             serialTime += time
  399.         }
  400.  
  401.         println("TOTAL SERIAL: $serialTime")
  402.     }
  403.  
  404.     val total = measureTimeMillis {
  405.         inline fun <K, V> K._(value: V): Pair<K, V> = Pair<K, V>(this, value)
  406.  
  407.         val languages = hashMap(
  408.             "Google Go" _ "http://golang.org",
  409.             "Java" _ "http://www.java.com/en/download/index.jsp",
  410.             "Scala" _ "http://www.scala-lang.org/",
  411.             "C++" _ "http://en.wikipedia.org/wiki/C%2B%2B"
  412.         )
  413.  
  414.         //doSerial(languages)
  415.         doParallel(languages)
  416.     }
  417.  
  418.     println(" TOTAL TIME: $total")
  419. }
  420.  
  421. fun main(args: Array<String>) {
  422.     //example1Sort()
  423.     //example2LimitThroughput()
  424.     //example3Compute()
  425.     //example4LeakyBuffer()
  426.     example5URLs()
  427.  
  428.     THREAD_POOL.shutdown()
  429.     THREAD_POOL.awaitTermination(5, TimeUnit.SECONDS)
  430. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement