package se.lesc.stackoverflow.parallelize;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.nio.ByteBuffer;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.SynchronousQueue;
import java.util.zip.Adler32;
/**
 * Simulator of how to "Parallelize tasks but preserving input order in output".
 *
 * <p>A feeder thread writes fixed-size, numbered messages into a pipe; a decoder (single-threaded
 * or thread-pool based, see {@code USE_SERIAL_DECODER}) reads and decodes them and puts them on an
 * output queue; an output handler checks that messages arrive in strictly increasing number order
 * and periodically prints the throughput.
 */
public class MessageSteamDecoder {
    /** Selects the single-threaded {@link SerialDecoder} instead of the {@link ThreadPoolDecoder}. */
    private static final boolean USE_SERIAL_DECODER = false;

    /**
     * Wires up the feeder, output handler and decoder threads and starts them.
     *
     * @param args ignored
     * @throws IOException if the pipe between feeder and decoder cannot be connected
     */
    public static void main(String[] args) throws IOException {
        PipedInputStream in = new PipedInputStream(1024 * 1024);
        // Feeder thread to the system
        new Thread(new MessageFeeder(new PipedOutputStream(in)), "Message feeder").start();
        // Output queue of the system
        BlockingQueue<Message> outputQueue = new ArrayBlockingQueue<Message>(1000);
        // Handler of the output of the system
        new Thread(new OutputHandler(outputQueue), "Output Handler").start();
        // Choice of decoder
        if (USE_SERIAL_DECODER) {
            new Thread(new SerialDecoder(in, outputQueue), "Serial Decoder").start();
        } else {
            new Thread(new ThreadPoolDecoder(in, outputQueue), "Thread Pool Decoder").start();
        }
        System.out.println("Started.");
    }

    /**
     * Receiver of decoded messages. Terminates the JVM with status 1 if a message arrives out of
     * order, and prints the throughput whenever more than 10 seconds have passed since the last
     * report (checked on each received message).
     */
    private static class OutputHandler implements Runnable {
        private static final long REPORT_INTERVAL_MILLIS = 10000;

        private final BlockingQueue<Message> outputQueue;
        private int nextExpectedMessageNumber = 1;
        private int messagesInInterval = 0;
        private long intervalStartMillis = System.currentTimeMillis();

        public OutputHandler(BlockingQueue<Message> outputQueue) {
            this.outputQueue = outputQueue;
        }

        /** Consumes the output queue until this thread is interrupted. */
        @Override
        public void run() {
            try {
                while (true) {
                    handleMessage(outputQueue.take());
                }
            } catch (InterruptedException e) {
                // Asked to stop: keep the interrupt visible to whoever owns this thread.
                Thread.currentThread().interrupt();
            }
        }

        /**
         * Verifies that {@code message} is the next one in sequence and updates the statistics.
         *
         * @param message a decoded message
         */
        public void handleMessage(Message message) {
            long now = System.currentTimeMillis();
            long elapsedMillis = now - intervalStartMillis;
            if (elapsedMillis > REPORT_INTERVAL_MILLIS) {
                // Messages per millisecond == thousands of messages per second. Dividing by the
                // measured window (not a nominal 10 s) keeps the figure honest, and restarting the
                // window at "now" avoids a burst of catch-up reports after a stall.
                System.out.println(Math.round((double) messagesInInterval / elapsedMillis) + " K msg/s");
                messagesInInterval = 0;
                intervalStartMillis = now;
            }
            if (message.getMessageNumber() != nextExpectedMessageNumber) {
                System.err.println("Expected #" + nextExpectedMessageNumber + " but got #"
                        + message.getMessageNumber());
                System.exit(1); // Error in program
            }
            messagesInInterval++;
            nextExpectedMessageNumber++;
        }
    }

    /** Feeds numbered "messages" to an OutputStream as fast as possible, until a write fails. */
    private static class MessageFeeder implements Runnable {
        private final ByteBuffer messageBuffer = ByteBuffer.wrap(new byte[Message.MESSAGE_SIZE]);
        private final OutputStream out;

        MessageFeeder(OutputStream out) {
            this.out = out;
        }

        @Override
        public void run() {
            int messageNumber = 0;
            try {
                while (true) {
                    messageNumber++;
                    // The number occupies the first 4 bytes; the remaining bytes stay zero.
                    messageBuffer.putInt(0, messageNumber);
                    out.write(messageBuffer.array());
                }
            } catch (IOException e) {
                // Treat a failed write (e.g. the reading end of a pipe is gone) as the end of
                // feeding: retrying in a loop would only spin and flood stderr.
                e.printStackTrace();
            } finally {
                try {
                    out.close();
                } catch (IOException ignored) {
                    // Already shutting down; nothing more useful to do.
                }
            }
        }
    }

    /** A Message representation. To make it easy a message is 100 bytes large. */
    private static class Message {
        private static final int MESSAGE_SIZE = 100;
        private final byte[] messageBytes;
        /** Output of {@link #decode()}; not read by the code shown here, it only simulates work. */
        private String decodedValue;
        private int messageNumber;

        public Message(byte[] messageBytes) {
            this.messageBytes = messageBytes;
        }

        /** Decodes a message. Might take some time. Mostly dummy code. */
        public void decode() {
            ByteBuffer messageBuffer = ByteBuffer.wrap(messageBytes);
            messageNumber = messageBuffer.getInt(0);
            decodedValue = Integer.toHexString(messageNumber);
            Adler32 checkSumCalculator = new Adler32();
            checkSumCalculator.update(messageBytes);
            long checksum = checkSumCalculator.getValue();
            decodedValue = "Decoded value: " + decodedValue + ", checksum: " + checksum;
        }

        /** Returns the number extracted by {@link #decode()}; 0 before decoding. */
        public int getMessageNumber() {
            return messageNumber;
        }

        /**
         * Reads exactly {@link #MESSAGE_SIZE} bytes from {@code in} as one undecoded message.
         *
         * @param in the stream to read from
         * @return the raw message
         * @throws EOFException if the stream ends before a whole message was read
         * @throws IOException on any other read failure
         */
        public static Message readMessage(DataInputStream in) throws IOException {
            byte[] messageBytes = new byte[MESSAGE_SIZE];
            in.readFully(messageBytes);
            return new Message(messageBytes);
        }
    }

    /** The simplest form of Decoder that only uses one thread. */
    private static class SerialDecoder implements Runnable {
        private final DataInputStream in;
        private final BlockingQueue<Message> outputQueue;

        public SerialDecoder(InputStream in, BlockingQueue<Message> outputQueue) {
            this.in = new DataInputStream(in);
            this.outputQueue = outputQueue;
        }

        /**
         * Reads, decodes and forwards messages until the input ends, a read fails, or this thread
         * is interrupted. A trailing partial message is dropped.
         */
        @Override
        public void run() {
            try {
                while (true) {
                    Message message = Message.readMessage(in);
                    message.decode();
                    outputQueue.put(message);
                }
            } catch (EOFException ignored) {
                // Normal end of input: nothing more to decode.
            } catch (IOException e) {
                // The stream is unusable from here on; stop instead of retrying forever.
                e.printStackTrace();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Naive parallelizing implementation using a thread pool.
     *
     * <p>Messages are handed to the workers round-robin and results are collected from the workers
     * in the same round-robin order, which keeps the output in input order. Each worker exchanges
     * messages with the decoder thread through {@link SynchronousQueue}s, so a worker holds at most
     * one message at a time.
     */
    private static class ThreadPoolDecoder implements Runnable {
        private static final int NUMBER_OF_THREADS = 12;
        private final DataInputStream in;
        private final BlockingQueue<Message> outputQueue;
        private final ExecutorService threadPool;
        private final DecoderWorker[] workers;

        /** Creates the decoder and immediately starts its {@code NUMBER_OF_THREADS} workers. */
        public ThreadPoolDecoder(InputStream in, BlockingQueue<Message> outputQueue) {
            this.in = new DataInputStream(in);
            this.outputQueue = outputQueue;
            threadPool = Executors.newFixedThreadPool(NUMBER_OF_THREADS);
            workers = new DecoderWorker[NUMBER_OF_THREADS];
            for (int i = 0; i < NUMBER_OF_THREADS; i++) {
                workers[i] = new DecoderWorker();
                // execute() rather than submit(): submit() captures an unexpected worker exception
                // in a Future nobody reads, so the failure would go unreported.
                // NOTE(review): a worker whose decode() throws stops for good, and run() would then
                // block waiting on it.
                threadPool.execute(workers[i]);
            }
        }

        /**
         * Dispatches messages until the input ends, a read fails, or this thread is interrupted.
         * On end of input or a read failure, the messages still held by workers are forwarded in
         * order. In every case the worker pool is shut down before returning.
         */
        @Override
        public void run() {
            int inFlight = 0;   // messages handed to workers but not yet forwarded
            int nextWorker = 0; // worker that receives the next message
            try {
                while (true) {
                    Message message;
                    try {
                        message = Message.readMessage(in);
                    } catch (EOFException endOfInput) {
                        break;
                    } catch (IOException e) {
                        e.printStackTrace();
                        break;
                    }
                    workers[nextWorker].inQueue.put(message);
                    nextWorker = (nextWorker + 1) % NUMBER_OF_THREADS;
                    inFlight++;
                    if (inFlight == NUMBER_OF_THREADS) {
                        // Every worker now holds a message. The oldest is with nextWorker, which
                        // is also where the next message goes, so collecting it frees that worker.
                        outputQueue.put(workers[nextWorker].outQueue.take());
                        inFlight--;
                    }
                }
                // Forward what is still in flight, oldest first: it sits with the workers just
                // before nextWorker in round-robin order.
                for (int remaining = inFlight; remaining > 0; remaining--) {
                    int worker = (nextWorker - remaining + NUMBER_OF_THREADS) % NUMBER_OF_THREADS;
                    outputQueue.put(workers[worker].outQueue.take());
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                // Interrupts the workers, which are blocked on their queues, so the pool threads end.
                threadPool.shutdownNow();
            }
        }

        /** Worker used to decode messages, one at a time, until interrupted. */
        private static class DecoderWorker implements Runnable {
            final SynchronousQueue<Message> inQueue = new SynchronousQueue<Message>();
            final SynchronousQueue<Message> outQueue = new SynchronousQueue<Message>();

            @Override
            public void run() {
                try {
                    while (true) {
                        Message message = inQueue.take();
                        message.decode();
                        outQueue.put(message);
                    }
                } catch (InterruptedException e) {
                    // Pool is shutting down: restore the flag and let the pool thread finish.
                    Thread.currentThread().interrupt();
                }
            }
        }
    }
}