package se.lesc.stackoverflow.parallelize;

import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.nio.ByteBuffer;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.SynchronousQueue;
import java.util.zip.Adler32;

/**
 * Simulator of how to "Parallelize tasks but preserving input order in output".
 *
 * <p>A feeder thread writes fixed-size, sequentially numbered messages into a pipe. A decoder
 * (serial or thread-pool based) reads them back, decodes them, and puts them on an output queue.
 * An output handler verifies that the messages arrive in their original order and prints
 * throughput statistics.
 */
public class MessageSteamDecoder {

    /** {@code true} selects {@link SerialDecoder}; {@code false} selects {@link ThreadPoolDecoder}. */
    private static final boolean USE_SERIAL_DECODER = false;

    /**
     * Wires feeder, decoder and output handler together and starts each on its own thread.
     *
     * @param args ignored
     * @throws IOException if the pipe between feeder and decoder cannot be connected
     */
    public static void main(String[] args) throws IOException {
        PipedInputStream in = new PipedInputStream(1024 * 1024);

        // Feeder thread to the system
        new Thread(new MessageFeeder(new PipedOutputStream(in)), "Message feeder").start();

        // Output queue of the system
        BlockingQueue<Message> outputQueue = new ArrayBlockingQueue<Message>(1000);

        // Handler of the output of the system
        new Thread(new OutputHandler(outputQueue), "Output Handler").start();

        // Choice of decoder
        if (USE_SERIAL_DECODER) {
            new Thread(new SerialDecoder(in, outputQueue), "Serial Decoder").start();
        } else {
            new Thread(new ThreadPoolDecoder(in, outputQueue), "Thread Pool Decoder").start();
        }
        System.out.println("Started.");
    }

    /**
     * Receiver of decoded messages.
     *
     * <p>Checks that message numbers arrive as 1, 2, 3, ... and terminates the JVM with status 1
     * on the first out-of-order message. Prints the throughput roughly every 10 seconds.
     */
    private static class OutputHandler implements Runnable {
        private static final long REPORT_INTERVAL_MILLIS = 10000;

        private final BlockingQueue<Message> outputQueue;
        private int nextExpectedMessageNumber = 1;
        private int numberOfMessagesPer10Sec = 0;
        private long nextPointInTimeToPrintSpeed =
                System.currentTimeMillis() + REPORT_INTERVAL_MILLIS;

        public OutputHandler(BlockingQueue<Message> outputQueue) {
            this.outputQueue = outputQueue;
        }

        /** Consumes messages from the output queue until this thread is interrupted. */
        @Override
        public void run() {
            try {
                while (true) {
                    handleMessage(outputQueue.take());
                }
            } catch (InterruptedException e) {
                // Restore the flag so whoever interrupted us can observe it; then stop.
                Thread.currentThread().interrupt();
            }
        }

        /**
         * Verifies the ordering of one message and updates the throughput counter.
         *
         * @param message a decoded message; its number must equal the next expected number
         */
        public void handleMessage(Message message) {
            if (System.currentTimeMillis() > nextPointInTimeToPrintSpeed) {
                double intervalSeconds = REPORT_INTERVAL_MILLIS / 1000.0;
                System.out.println(
                        Math.round(numberOfMessagesPer10Sec / intervalSeconds / 1000) + " K msg/s");
                numberOfMessagesPer10Sec = 0;
                nextPointInTimeToPrintSpeed += REPORT_INTERVAL_MILLIS;
            }

            if (message.getMessageNumber() != nextExpectedMessageNumber) {
                System.err.println("Expected #" + nextExpectedMessageNumber
                        + " but got #" + message.getMessageNumber());
                System.exit(1); // Error in program
            }
            numberOfMessagesPer10Sec++;
            nextExpectedMessageNumber++;
        }
    }

    /**
     * Feeds "messages" to an OutputStream as fast as possible.
     *
     * <p>Each message is {@link Message#MESSAGE_SIZE} bytes long. Its first four bytes hold the
     * message number (1, 2, 3, ...) and the remaining bytes are zero.
     */
    private static class MessageFeeder implements Runnable {
        private final ByteBuffer messageBuffer = ByteBuffer.wrap(new byte[Message.MESSAGE_SIZE]);
        private final OutputStream out;

        MessageFeeder(OutputStream out) {
            this.out = out;
        }

        /** Writes messages until writing fails. */
        @Override
        public void run() {
            int messageNumber = 0;
            try {
                while (true) {
                    messageNumber++;
                    messageBuffer.putInt(0, messageNumber);
                    out.write(messageBuffer.array());
                }
            } catch (IOException e) {
                // The pipe is broken or closed; retrying would only spin, so stop feeding.
                e.printStackTrace();
            }
        }
    }

    /** A Message representation. To keep things simple, every message is exactly 100 bytes. */
    private static final class Message {
        private static final int MESSAGE_SIZE = 100;

        private final byte[] messageBytes;
        private String decodedValue;
        private int messageNumber;

        public Message(byte[] messageBytes) {
            this.messageBytes = messageBytes;
        }

        /**
         * Decodes the message. This is mostly dummy work that simulates a costly decode.
         *
         * <p>It reads the message number from the first four bytes (ByteBuffer's default
         * big-endian order) and computes an Adler-32 checksum over all bytes. The result string
         * is stored in {@code decodedValue}, which nothing in this file reads.
         */
        public void decode() {
            ByteBuffer messageBuffer = ByteBuffer.wrap(messageBytes);
            messageNumber = messageBuffer.getInt(0);
            decodedValue = Integer.toHexString(messageNumber);

            Adler32 checkSumCalculator = new Adler32();
            checkSumCalculator.update(messageBytes);
            long checksum = checkSumCalculator.getValue();

            decodedValue = "Decoded value: " + decodedValue + ", checksum: " + checksum;
        }

        /** Returns the message number; only meaningful after {@link #decode()} has run. */
        public int getMessageNumber() {
            return messageNumber;
        }

        /**
         * Reads exactly {@link #MESSAGE_SIZE} bytes from {@code in} as one undecoded message.
         *
         * @throws IOException if the stream ends before a full message is read, or reading fails
         */
        public static Message readMessage(DataInputStream in) throws IOException {
            byte[] messageBytes = new byte[MESSAGE_SIZE];
            in.readFully(messageBytes);
            return new Message(messageBytes);
        }
    }

    /** The simplest form of Decoder that only uses one thread. */
    private static class SerialDecoder implements Runnable {
        private final DataInputStream in;
        private final BlockingQueue<Message> outputQueue;

        public SerialDecoder(InputStream in, BlockingQueue<Message> outputQueue) {
            this.in = new DataInputStream(in);
            this.outputQueue = outputQueue;
        }

        /** Reads, decodes and forwards messages until the input ends or the thread is interrupted. */
        @Override
        public void run() {
            try {
                while (true) {
                    Message message = Message.readMessage(in);
                    message.decode();
                    outputQueue.put(message);
                }
            } catch (IOException e) {
                // End of stream or broken pipe: nothing more can be read, so stop.
                e.printStackTrace();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Naive parallelizing implementation using a thread pool.
     *
     * <p>Messages are handed out round-robin to {@link #NUMBER_OF_THREADS} workers. Results are
     * collected round-robin in the same order, so the output order equals the input order. Once
     * every worker holds a message, the oldest result is forwarded before the next message is
     * dispatched. That keeps at most {@code NUMBER_OF_THREADS} messages in flight and guarantees
     * the next worker is idle when it is handed work.
     */
    private static class ThreadPoolDecoder implements Runnable {
        private static final int NUMBER_OF_THREADS = 12;

        private final DataInputStream in;
        private final BlockingQueue<Message> outputQueue;
        private final ExecutorService threadPool;
        private final DecoderWorker[] workers;

        public ThreadPoolDecoder(InputStream in, BlockingQueue<Message> outputQueue) {
            this.in = new DataInputStream(in);
            this.outputQueue = outputQueue;
            threadPool = Executors.newFixedThreadPool(NUMBER_OF_THREADS);
            workers = new DecoderWorker[NUMBER_OF_THREADS];
            for (int i = 0; i < NUMBER_OF_THREADS; i++) {
                workers[i] = new DecoderWorker();
                threadPool.submit(workers[i]);
            }
        }

        /**
         * Dispatches messages to the workers and forwards their results in input order.
         *
         * <p>Stops when the input ends (after draining in-flight results) or when the thread is
         * interrupted. Always shuts down the worker pool on exit.
         */
        @Override
        public void run() {
            int dispatchIndex = 0; // worker that receives the next message
            int collectIndex = 0;  // worker holding the oldest result not yet forwarded
            int inFlight = 0;      // messages dispatched but not yet forwarded
            try {
                while (true) {
                    Message message;
                    try {
                        message = Message.readMessage(in);
                    } catch (IOException e) {
                        // End of stream or broken pipe: stop reading, but keep the
                        // results that are already being decoded.
                        e.printStackTrace();
                        break;
                    }
                    workers[dispatchIndex].inQueue.put(message);
                    dispatchIndex = nextIndex(dispatchIndex);
                    inFlight++;

                    if (inFlight == NUMBER_OF_THREADS) {
                        collectIndex = forwardResult(collectIndex);
                        inFlight--;
                    }
                }
                // Input exhausted: forward what the workers still hold, oldest first.
                while (inFlight > 0) {
                    collectIndex = forwardResult(collectIndex);
                    inFlight--;
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                // Workers only exit when interrupted. Pool threads from the default thread
                // factory are non-daemon, so an un-shut-down pool would keep the JVM alive.
                threadPool.shutdownNow();
            }
        }

        /** Moves the result of worker {@code workerIndex} to the output queue; returns the next index. */
        private int forwardResult(int workerIndex) throws InterruptedException {
            outputQueue.put(workers[workerIndex].outQueue.take());
            return nextIndex(workerIndex);
        }

        /** Round-robin successor of {@code index} among the workers. */
        private static int nextIndex(int index) {
            return (index + 1) % NUMBER_OF_THREADS;
        }

        /**
         * Worker used to decode messages.
         *
         * <p>It takes one message at a time from {@code inQueue}, decodes it, and hands it back
         * through {@code outQueue}. Both queues are synchronous, so each hand-off blocks until the
         * dispatcher is on the other side.
         */
        private static class DecoderWorker implements Runnable {
            private final SynchronousQueue<Message> inQueue = new SynchronousQueue<Message>();
            private final SynchronousQueue<Message> outQueue = new SynchronousQueue<Message>();

            /** Decodes messages until the thread is interrupted (e.g. by {@code shutdownNow}). */
            @Override
            public void run() {
                try {
                    while (true) {
                        Message message = inQueue.take();
                        message.decode();
                        outQueue.put(message);
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }
        }
    }
}