Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import java.net.*
- import java.util.*
- import java.util.concurrent.*
- import java.util.concurrent.atomic.AtomicLong
- import kotlin.test.*
- import kotlin.util.measureTimeMillis
- /**
- * Google Go's goroutines implemented in Kotlin.
- * @author Ioannis Tsakpinis
- */
/**
 * Common supertype of both channel endpoints, [ChannelIn] (receive side) and
 * [ChannelOut] (send side). It declares no members of its own.
 *
 * @param T type of the values carried by the channel.
 */
public interface Channel<T>
/**
 * Receiving end of a channel.
 *
 * The members mirror [java.util.concurrent.BlockingQueue] operations; in [ChannelInOut]
 * they delegate to `take()`, `poll()`, `poll(timeout, unit)` and `peek()` respectively.
 * Iteration is implementation-defined.
 */
public interface ChannelIn<T> : Channel<T>, Iterable<T> {
    /** Removes and returns the next value (in [ChannelInOut]: `BlockingQueue.take()`). */
    public val take: T

    /** Removes and returns the next value if one is ready (in [ChannelInOut]: `BlockingQueue.poll()`). */
    public val poll: T?

    /** Like [poll], but may wait up to [timeout] [unit]s (in [ChannelInOut]: `BlockingQueue.poll(timeout, unit)`). */
    public fun poll(timeout: Long, unit: TimeUnit): T?

    /** Returns the next value without removing it (in [ChannelInOut]: `BlockingQueue.peek()`). */
    public val peek: T?
}
/**
 * Sending end of a channel.
 *
 * In [ChannelInOut] these delegate to `BlockingQueue.offer(value, timeout, unit)`,
 * `offer(value)` and `put(value)` respectively.
 */
public interface ChannelOut<T> : Channel<T> {
    /** Tries to send [value], waiting at most [timeout] [unit]s; returns whether it was accepted. */
    public fun offer(value: T, timeout: Long, unit: TimeUnit): Boolean

    /** Tries to send [value] without waiting; returns whether it was accepted. */
    public fun offer(value: T): Boolean

    /** Sends [value]; in [ChannelInOut] this is `BlockingQueue.put`, which waits for space. */
    public fun send(value: T)
}
/**
 * Channel backed by a [BlockingQueue].
 *
 * - `bufferSize > 0`: buffered, backed by an [ArrayBlockingQueue] of that capacity.
 * - `bufferSize <= 0`: unbuffered, backed by a [SynchronousQueue], so a send completes
 *   only when a receiver takes the value.
 *
 * @param bufferSize buffer capacity; `0` (or less) selects the unbuffered variant.
 * @param fair passed to the queue constructor; `true` asks for FIFO ordering of blocked threads.
 */
public class ChannelInOut<T>(bufferSize: Int, fair: Boolean) : ChannelIn<T>, ChannelOut<T> {
    private val buffer: BlockingQueue<T> =
        if (0 < bufferSize) ArrayBlockingQueue<T>(bufferSize, fair) else SynchronousQueue<T>(fair)

    /** Receive-only view of this channel (a fresh delegating wrapper on every access). */
    public val input: ChannelIn<T> get() = object : ChannelIn<T> by this {}

    /** Send-only view of this channel (a fresh delegating wrapper on every access). */
    public val output: ChannelOut<T> get() = object : ChannelOut<T> by this {}

    // ---- BlockingQueue-backed operations ----
    public override val take: T get() = buffer.take()
    public override val poll: T? get() = buffer.poll()
    public override fun poll(timeout: Long, unit: TimeUnit): T? = buffer.poll(timeout, unit)
    public override val peek: T? get() = buffer.peek()
    public override fun offer(value: T, timeout: Long, unit: TimeUnit): Boolean = buffer.offer(value, timeout, unit)
    public override fun offer(value: T): Boolean = buffer.offer(value)
    public override fun send(value: T): Unit = buffer.put(value)

    /**
     * Iterates over the values currently buffered.
     *
     * Always empty for an unbuffered channel, because a [SynchronousQueue] reports size 0.
     *
     * `hasNext()` claims the next value with a non-blocking `poll()` and holds it until
     * `next()` returns it. This means a concurrent receiver can no longer take the element
     * between the two calls. Previously that race left `next()` blocked in `take()`
     * indefinitely. A bare `next()` with no preceding `hasNext()` still blocks in `take()`,
     * as before.
     */
    public override fun iterator(): Iterator<T> = object : Iterator<T> {
        // Value claimed by hasNext() but not yet handed out; BlockingQueue never stores null.
        private var lookahead: T? = null

        override fun hasNext(): Boolean {
            if (lookahead == null && 0 < buffer.size) lookahead = buffer.poll()
            return lookahead != null
        }

        override fun next(): T {
            val claimed = lookahead ?: return buffer.take()
            lookahead = null
            return claimed
        }
    }
}
/**
 * Creates a new [ChannelInOut].
 *
 * @param bufferSize `0` (the default) for an unbuffered channel, or a positive buffer capacity.
 * @param fair whether the backing queue should serve blocked threads in FIFO order.
 */
public fun <T> channel(bufferSize: Int = 0, fair: Boolean = false): ChannelInOut<T> {
    return ChannelInOut(bufferSize, fair)
}
/** Number of worker threads shared by all goroutines: four per available processor. */
val GOROUTINE_THREADS: Int = 4 * Runtime.getRuntime().availableProcessors()

/**
 * Fixed-size pool that `go` submits goroutines to.
 *
 * It uses the default `Executors` thread factory, whose worker threads are non-daemon,
 * so the pool has to be shut down explicitly (as `main` does).
 */
val THREAD_POOL: ExecutorService = Executors.newFixedThreadPool(GOROUTINE_THREADS)
/**
 * Runs [exec] asynchronously on [THREAD_POOL], in the spirit of Go's `go` statement.
 *
 * If [exec] throws, the throwable is first reported to the worker thread's
 * uncaught-exception handler and then rethrown.
 * - The handler prints "Exception in thread ..." plus the stack trace to stderr unless a
 *   custom handler is installed.
 * - The rethrow means the returned [Future] still completes exceptionally.
 *
 * Previously the failure was captured only in a Future nobody inspects. It disappeared
 * silently, which could leave a receiver waiting forever with no trace of the cause.
 *
 * @return the [Future] of the submitted task.
 */
public fun go(exec: () -> Unit): Future<*> =
    THREAD_POOL.submit(Runnable {
        try {
            exec()
        } catch (failure: Throwable) {
            val worker = Thread.currentThread()
            worker.uncaughtExceptionHandler?.uncaughtException(worker, failure)
            throw failure
        }
    })
/**
 * Blocks until one of [cases] succeeds.
 *
 * This is the two-argument `select` with no default action.
 */
public fun select(vararg cases: Case<*>) {
    // `null` cannot be a Case<*>, so this resolves to the (default, vararg cases) overload.
    select(null, *cases)
}
/**
 * Go-style `select`: runs exactly one ready case, or [default] if none is ready.
 *
 * The case order is shuffled once per call. Each round then tries every case in that
 * order via [Case.tryAction] and stops at the first that succeeds.
 * - If a round finds nothing ready and [default] is given, [default] runs and the call returns.
 * - Otherwise the thread yields and the next round starts (busy-waiting).
 *
 * @param default fallback when no case is immediately ready; `null` means keep waiting.
 * @throws IllegalArgumentException if [cases] is empty and [default] is `null`.
 */
public fun select(default: (() -> Unit)?, vararg cases: Case<*>) {
    if (cases.isEmpty()) {
        val fallback = requireNotNull(default) { "No case or default specified." }
        fallback()
        return
    }

    // Random permutation of the case indices (Fisher–Yates).
    val order = IntArray(cases.size) { it }
    val random = ThreadLocalRandom.current()
    for (k in 1 until order.size) {
        val saved = order[k]
        val pick = random.nextInt(k + 1)
        order[k] = order[pick]
        order[pick] = saved
    }

    while (true) {
        // any() stops at the first case whose action fired.
        if (order.any { cases[it].tryAction() }) return
        if (default != null) {
            default()
            return
        }
        Thread.yield()
    }
}
/** One arm of a `select`. */
public interface Case<T> {
    /**
     * Attempts this arm's channel operation once.
     *
     * @return `true` if the operation happened (and its action ran), otherwise `false`.
     */
    fun tryAction(): Boolean
}
/**
 * Receive arm of a `select`: reads from a [ChannelIn]; the "Out" refers to values leaving the channel.
 *
 * [tryAction] reads [channel]'s `poll`. If a value comes back, it passes the value to
 * [action] and reports success. A `null` means nothing was received, and [action] is
 * not run.
 */
public class CaseOut<T>(val channel: ChannelIn<T>, val action: (T) -> Unit) : Case<T> {
    override fun tryAction(): Boolean {
        val received = channel.poll ?: return false
        action(received)
        return true
    }
}
/**
 * Send arm of a `select`: writes to a [ChannelOut]; the "In" refers to values entering the channel.
 *
 * [tryAction] calls `offer(value)` on [channel]. If the value was accepted and an
 * [action] is present, the action runs. The offer result is returned.
 */
public class CaseIn<T>(val channel: ChannelOut<T>, val value: T, val action: (() -> Unit)?) : Case<T> {
    override fun tryAction(): Boolean {
        val sent = channel.offer(value)
        if (sent) action?.invoke()
        return sent
    }
}
/** Builds a receive arm ([CaseOut]) that hands each received value to [action]. */
public fun <T> ChannelIn<T>.case(action: (T) -> Unit): CaseOut<T> {
    return CaseOut(this, action)
}
/** Builds a send arm ([CaseIn]) that offers [value] and runs [action], if any, once it is accepted. */
public fun <T> ChannelOut<T>.case(value: T, action: (() -> Unit)? = null): CaseIn<T> {
    return CaseIn(this, value, action)
}
- /*
- * ----------------------------------------------------------------------
- * ------------------------------[ EXAMPLES ]----------------------------
- * ----------------------------------------------------------------------
- */
/**
 * Example 1: sort 1M random ints in a goroutine while the caller does other work.
 *
 * The caller then waits for the goroutine's completion message on an unbuffered channel.
 */
fun example1Sort() {
    val elementCount = 1024 * 1024
    val rand = ThreadLocalRandom.current()
    val list = ArrayList<Int>(elementCount)
    repeat(elementCount) { list.add(rand.nextInt()) }

    val total = measureTimeMillis {
        val finished = channel<Int>()
        // Background sort; reports its own duration, then signals completion.
        go {
            val sortTime = measureTimeMillis { Collections.sort(list) }
            println(" sort took: $sortTime ms")
            finished.send(1)
        }
        Thread.sleep(300) // stands in for unrelated work overlapping the sort
        finished.take     // blocks until the goroutine's signal arrives
    }
    println("total took: $total ms")
}
/**
 * Example 2: limit concurrency with a buffered channel used as a counting semaphore.
 *
 * A server goroutine receives requests and spawns one handler goroutine per request.
 * Before working, each handler must put a token into the 5-slot `sem` channel, so at
 * most 5 requests are handled at once. With 100 requests of 200 ms each, the run takes
 * at least (100 / 5) * 200 ms = 4000 ms, which the example asserts.
 */
fun example2LimitThroughput() {
    class Request(val id: Int, val done: ChannelOut<Int>) {
        fun done(): Unit = done.send(id)
    }

    val sem = channel<Int>(5) // throttle: at most 5 requests handled at the same time

    fun handle(r: Request) {
        println("Pre-handle: ${r.id}")
        sem.send(1) // acquire a slot; blocks while 5 handlers are active
        try {
            println("\tHandling: ${r.id}")
            Thread.sleep(200)
            r.done()
        } finally {
            // Release the slot even if handling threw. Otherwise the slot is lost for good
            // and throughput drops permanently.
            sem.take
        }
        println("Post-handle: ${r.id}")
    }

    val quitServer = channel<Int>() // a message arrives here when it's time to stop the server

    fun serve(queue: ChannelIn<Request>) {
        var run = true
        while (run) {
            // Wait for either a new request or the quit signal.
            select(
                queue.case { request -> go { handle(request) } }, // one handler goroutine per request
                quitServer.case { run = false }
            )
        }
    }

    val queue = channel<Request>()
    go { serve(queue) } // server goroutine

    val done = channel<Int>()
    val time = measureTimeMillis {
        // Send 100 requests; each send waits until the server has received the request.
        for (id in 0 until 100) queue.send(Request(id, done))
        // Wait for every request to complete.
        repeat(100) { done.take }
        // Stop the server.
        quitServer.send(1)
    }
    // 100 requests / 5 at a time = 20 rounds * 200 ms = at least 4000 ms
    assertTrue(4000 <= time)
    println("Took: $time ms")
}
/**
 * Example 3: apply `sin` to every element of a 2M-element float vector.
 *
 * The work is done twice: once serially, and once split into `GOROUTINE_THREADS * 2`
 * batches run as goroutines. The example then prints the observed and theoretical speedups.
 */
fun example3Compute() {
    // Serial: applies op in place to vector[from until to].
    fun vectorOp(vector: FloatArray, from: Int, to: Int, op: (Float) -> Float) {
        for (i in from until to)
            vector[i] = op(vector[i])
    }

    // Channel-aware: same as above, then signals completion on c.
    fun vectorOp(vector: FloatArray, from: Int, to: Int, c: ChannelOut<Int>, op: (Float) -> Float) {
        vectorOp(vector, from, to, op)
        c.send(1)
    }

    val vector = FloatArray(1024 * 1024 * 2)
    for (i in vector.indices)
        vector[i] = i.toFloat()
    // Separate copy so the parallel run doesn't benefit from data the serial run left in cache.
    val vector2 = vector.copyOf()
    val op = { x: Float -> Math.sin(x.toDouble()).toFloat() }

    val single = measureTimeMillis {
        vectorOp(vector, 0, vector.size, op)
    }
    val multi = measureTimeMillis {
        val batches = GOROUTINE_THREADS * 2
        val c = channel<Int>(batches) // room for every completion signal, so no batch blocks on send
        val batchSize = vector2.size / batches
        for (i in 1..batches) {
            val from = (i - 1) * batchSize
            // The last batch also covers the remainder. Without this, the trailing
            // size % batches elements were never processed.
            val to = if (i == batches) vector2.size else i * batchSize
            go { vectorOp(vector2, from, to, c, op) } // one goroutine per batch
        }
        // Wait for all batches to complete.
        repeat(batches) { c.take }
    }
    println("Single: $single ms")
    println(" Multi: $multi ms")
    println(" Practical speedup: ${single.toFloat() / multi.toFloat()}")
    println("Theoritical speedup: ${Runtime.getRuntime().availableProcessors().toFloat()}")
}
/**
 * Example 4: Go's "leaky buffer" pattern, which recycles buffers through a bounded free list.
 *
 * - The client takes a buffer from `freeList` when one is available, otherwise allocates
 *   a new one, and sends it to the server.
 * - The server processes the buffer and offers it back to `freeList`. If the list is
 *   full, the buffer is simply dropped and left to the GC.
 * - At the end the example prints how many buffers were reused versus newly allocated.
 */
fun example4LeakyBuffer() {
    val FREE_LIST_SIZE = 80
    val REQUESTS = 1000
    val freeList = channel<Array<Int>>(FREE_LIST_SIZE)
    for (i in 1..FREE_LIST_SIZE)
        freeList.send(Array<Int>(128) { i })
    val serverChan = channel<Array<Int>>(100)
    val quitServer = channel<Int>() // a message arrives here when it's time to stop the server
    // Counted down once per processed buffer. The client awaits it instead of busy-spinning;
    // the old empty `while` loop burned a full core.
    val processed = CountDownLatch(REQUESTS)

    // server
    go {
        var run = true
        while (run) {
            // Wait for either a buffer to process or the quit signal.
            select(
                serverChan.case { buffer ->
                    Thread.sleep(1) // simulate processing
                    // Return the buffer to the free list. If the list is full, the default
                    // branch does nothing and the buffer is left to the GC.
                    select({ }, freeList.case(buffer))
                    processed.countDown()
                },
                quitServer.case { run = false }
            )
        }
    }

    // client
    var newCount = 0
    var reuseCount = 0
    for (i in 1..REQUESTS) {
        var buffer: Array<Int>? = null
        select(
            {
                // Free list empty: allocate a new buffer.
                newCount++
                buffer = Array<Int>(32) { i }
            },
            freeList.case {
                // Reuse a buffer from the free list.
                reuseCount++
                buffer = it
            }
        )
        // ... fill the buffer ...
        // select always runs exactly one of the two branches, and both assign buffer.
        serverChan.send(checkNotNull(buffer))
    }

    // Wait for processing to complete.
    processed.await()
    quitServer.send(1)
    println("REUSED: $reuseCount")
    println(" NEW: $newCount")
    println(" RATIO: ${reuseCount.toFloat() / newCount.toFloat()}")
}
/**
 * Example 5: fetch several web pages and report each page's size and fetch time.
 *
 * Pages can be fetched serially (`doSerial`, currently disabled) or with one goroutine
 * per URL (`doParallel`). The example also prints the total wall-clock time.
 */
fun example5URLs() {
    // Downloads url; returns a report line and the elapsed milliseconds.
    fun count(name: String, url: URL): Pair<String, Long> {
        var contentSize = 0
        val time = measureTimeMillis {
            val connection = url.openConnection()
            connection.connect()
            // Read through one stream and always close it. The previous code called
            // getInputStream() twice and never closed the reader it actually read from.
            val html = connection.getInputStream().reader().use { it.readText() }
            contentSize = html.length
        }
        return Pair("$name: $contentSize characters - $time ms", time)
    }

    fun doSerial(languages: Map<String, String>) {
        var serialTime = 0L
        for ((name, url) in languages) {
            val (msg, time) = count(name, URL(url))
            println(msg)
            serialTime += time
        }
        println("TOTAL SERIAL: $serialTime")
    }

    fun doParallel(languages: Map<String, String>) {
        val c = channel<Pair<String, Long>>()
        for ((name, url) in languages)
            go {
                // Always send exactly one result per URL. Otherwise a failed fetch would
                // leave the receive loop below blocked forever.
                val result = try {
                    count(name, URL(url))
                } catch (e: Exception) {
                    Pair("$name: failed - $e", 0L)
                }
                c.send(result)
            }
        // Only this thread accumulates, so a plain Long suffices (AtomicLong has no `+=`).
        var serialTime = 0L
        repeat(languages.size) {
            val (msg, time) = c.take
            println(msg)
            serialTime += time
        }
        println("TOTAL SERIAL: $serialTime")
    }

    val total = measureTimeMillis {
        val languages = hashMapOf(
            "Google Go" to "http://golang.org",
            "Java" to "http://www.java.com/en/download/index.jsp",
            "Scala" to "http://www.scala-lang.org/",
            "C++" to "http://en.wikipedia.org/wiki/C%2B%2B"
        )
        //doSerial(languages)
        doParallel(languages)
    }
    println(" TOTAL TIME: $total")
}
/**
 * Entry point: runs the enabled example, then shuts the goroutine pool down.
 *
 * The shutdown sits in `finally` so it also happens when an example throws (for
 * example, a failed `assertTrue`). Otherwise the pool's non-daemon worker threads
 * would keep the JVM alive.
 *
 * Tasks still running get up to 5 seconds to finish.
 */
fun main(args: Array<String>) {
    try {
        //example1Sort()
        //example2LimitThroughput()
        //example3Compute()
        //example4LeakyBuffer()
        example5URLs()
    } finally {
        THREAD_POOL.shutdown()
        THREAD_POOL.awaitTermination(5, TimeUnit.SECONDS)
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement