Thursday, February 25, 2010

NIO proxy for manipulating traffic

I wrote a small proxy for manipulating the traffic between a web service client and server, though it can be used for other traffic as well. I used these two fantastic sites to brush up on NIO: Example Depot and IBM developerWorks. Here's the code I came up with:


import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Iterator;

/**
 * Entry point of a small NIO-based TCP tunnel.
 *
 * <p>It accepts connections on {@link #port}. For each accepted connection it opens a
 * connection to {@code hostName:hostPort} and hands both channels to a
 * {@code ConnectionHandler} thread, which relays the traffic between them.
 */
public class NIOTunnel {
    /** Local port the tunnel listens on. */
    int port = 8080;
    /** Host that every accepted connection is forwarded to. */
    String hostName = "localhost";
    /** Port on {@link #hostName} that every accepted connection is forwarded to. */
    int hostPort = 8089;
    private ServerSocketChannel ssServerChannel;
    private Selector selector;

    /**
     * Opens a non-blocking channel to {@code hostName:hostPort} and initiates the connection.
     *
     * <p>Because the channel is non-blocking, the connection may still be pending when this
     * method returns. Whoever uses the channel must complete the connection with
     * {@link SocketChannel#finishConnect()}.
     *
     * @return the (possibly still connecting) channel to the forwarding target
     * @throws IOException if the channel cannot be opened, configured or connected
     */
    protected SocketChannel createClientSocket() throws IOException {
        SocketChannel sChannel = SocketChannel.open();
        boolean ok = false;
        try {
            sChannel.configureBlocking(false);
            // Non-blocking connect: usually returns before the handshake has finished.
            sChannel.connect(new InetSocketAddress(hostName, hostPort));
            ok = true;
            return sChannel;
        } finally {
            if (!ok) {
                // Don't leak the socket if configuring or connecting fails.
                closeQuietly(sChannel);
            }
        }
    }

    /**
     * Opens the listening channel in non-blocking mode and binds it to {@link #port}
     * on the wildcard address.
     *
     * @throws IOException if the channel cannot be opened or bound
     */
    protected void createServerChannel() throws IOException {
        ssServerChannel = ServerSocketChannel.open();
        ssServerChannel.configureBlocking(false);
        ssServerChannel.socket().bind(new InetSocketAddress(port));
    }

    /**
     * Creates the listening channel and runs the accept loop forever.
     *
     * <p>Each accepted connection is handed to its own {@code ConnectionHandler} thread.
     *
     * @throws IOException if the selector or listening channel fails
     */
    protected void setUpAndWait() throws IOException {
        selector = Selector.open();
        createServerChannel();
        // Only the listening channel is registered here; each connection's channels
        // are registered with the handler's own selector.
        ssServerChannel.register(selector, SelectionKey.OP_ACCEPT);

        while (true) {
            selector.select();
            Iterator<SelectionKey> it = selector.selectedKeys().iterator();
            while (it.hasNext()) {
                SelectionKey key = it.next();
                // Remove the key so it is not processed again after the next select().
                it.remove();
                if (key.isValid() && key.isAcceptable()) {
                    acceptConnection((ServerSocketChannel) key.channel());
                }
            }
        }
    }

    /**
     * Accepts one pending connection and starts a handler thread relaying it to the target.
     *
     * <p>If setup fails, the accepted socket is closed before the exception propagates.
     */
    private void acceptConnection(ServerSocketChannel ssc) throws IOException {
        System.out.println("Acceptable Key");
        SocketChannel socket = ssc.accept();
        if (socket == null) {
            // A non-blocking accept() returns null when no connection is actually pending.
            return;
        }
        boolean handedOff = false;
        try {
            socket.configureBlocking(false);
            ConnectionHandler h = new ConnectionHandler(socket, createClientSocket());
            // start(), not run(). Calling run() would execute the relay loop on this
            // thread and block the accept loop until that connection ended.
            h.start();
            handedOff = true;
        } finally {
            if (!handedOff) {
                closeQuietly(socket);
            }
        }
    }

    /** Closes {@code channel}, reporting but otherwise ignoring any failure. */
    private static void closeQuietly(SocketChannel channel) {
        try {
            channel.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) throws IOException {
        NIOTunnel tunnel = new NIOTunnel();
        tunnel.setUpAndWait();
    }
}

/**
 * Relays traffic between two socket channels on its own thread.
 *
 * <p>Each chunk read from one side is passed through {@code ModifyTraffic}, and the
 * result is forwarded to the other side.
 *
 * <p>When either side reaches end-of-stream, the handler first flushes any data still
 * queued for the opposite side. It then stops and closes both channels.
 */
public class ConnectionHandler extends Thread {
    /**
     * Charset used to convert between bytes and Strings.
     *
     * <p>ISO-8859-1 maps every byte to exactly one char and back. Arbitrary (binary or
     * non-ASCII) traffic therefore never fails to decode, and byte offsets equal char
     * offsets.
     */
    private static final Charset CHARSET = Charset.forName("ISO-8859-1");

    /** Maximum number of bytes consumed per read. */
    private static final int READ_BUFFER_SIZE = 1024;

    private final SocketChannel serverChannel;
    private final SocketChannel clientChannel;
    private Selector selector;
    private final ModifyTraffic modifyTraffic;

    /**
     * Creates a handler for one tunnelled connection.
     *
     * @param serverChannel channel whose incoming data is passed to {@code modifyTraffic.server}
     *     (as constructed by NIOTunnel, the accepted incoming socket)
     * @param clientChannel channel whose incoming data is passed to {@code modifyTraffic.client}
     *     (as constructed by NIOTunnel, the outgoing connection, which may still be pending)
     */
    public ConnectionHandler(SocketChannel serverChannel, SocketChannel clientChannel) {
        this.serverChannel = serverChannel;
        this.clientChannel = clientChannel;
        this.modifyTraffic = new ModifyTraffic();
    }

    /**
     * Runs the relay loop until one side closes (after flushing the data queued for the
     * other side) or an I/O error occurs.
     *
     * <p>Both channels and the selector are always closed on exit.
     */
    @Override
    public void run() {
        try {
            selector = Selector.open();
            completeConnect();

            // Register only for reads; OP_WRITE is added while data is queued. A socket
            // is almost always writable, so permanent OP_WRITE interest makes select()
            // return immediately in a busy loop.
            SelectionKey client = clientChannel.register(selector, SelectionKey.OP_READ);
            SelectionKey server = serverChannel.register(selector, SelectionKey.OP_READ);

            client.attach(new WriteStorage(false));
            server.attach(new WriteStorage(true));

            boolean clientEof = false;
            boolean serverEof = false;
            while (client.isValid() && server.isValid()) {
                // Once a side has closed, stop as soon as everything it sent
                // has been forwarded to the other side.
                if ((clientEof && !hasPending(server)) || (serverEof && !hasPending(client))) {
                    break;
                }
                selector.select();
                Iterator<SelectionKey> it = selector.selectedKeys().iterator();

                while (it.hasNext()) {
                    SelectionKey key = it.next();
                    it.remove();

                    if (key.isValid() && key.isReadable()) {
                        WriteStorage source = (WriteStorage) key.attachment();
                        String ret = readMessage(key);
                        if (ret == null) {
                            // End of stream: stop reading from this side.
                            key.interestOps(key.interestOps() & ~SelectionKey.OP_READ);
                            if (source.isServer()) {
                                serverEof = true;
                            } else {
                                clientEof = true;
                            }
                        } else {
                            SelectionKey peer = source.isServer() ? client : server;
                            String forwarded = source.isServer()
                                    ? modifyTraffic.server(ret)
                                    : modifyTraffic.client(ret);
                            ((WriteStorage) peer.attachment()).append(forwarded);
                            peer.interestOps(peer.interestOps() | SelectionKey.OP_WRITE);
                        }
                        continue;
                    }

                    if (key.isValid() && key.isWritable()) {
                        writeMessage(key, (WriteStorage) key.attachment());
                        if (!hasPending(key)) {
                            key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE);
                        }
                    }
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            closeAll();
        }
    }

    /**
     * Blocks until the outgoing connection has been established, or throws if it failed.
     *
     * <p>This replaces a CPU-burning busy-wait on a non-blocking {@code finishConnect()}.
     * The channel is switched back to non-blocking mode before it is registered with the
     * selector.
     */
    private void completeConnect() throws IOException {
        clientChannel.configureBlocking(true);
        clientChannel.finishConnect();
        clientChannel.configureBlocking(false);
    }

    /** Returns true if the storage attached to {@code key} still holds unsent data. */
    private static boolean hasPending(SelectionKey key) {
        return ((WriteStorage) key.attachment()).getString().length() > 0;
    }

    /**
     * Writes as much of {@code writeStorage}'s queued text to the key's channel as it will
     * currently accept.
     *
     * <p>Any unsent tail is kept in {@code writeStorage} for the next write. Previously a
     * partial write silently discarded it.
     *
     * @param key selection key of the destination channel
     * @param writeStorage text queued for that channel; left holding only the unsent part
     * @throws IOException if the write fails
     */
    public void writeMessage(SelectionKey key, WriteStorage writeStorage) throws IOException {
        String pending = writeStorage.getString();
        if (pending.length() == 0) {
            return;
        }
        SocketChannel socket = (SocketChannel) key.channel();
        ByteBuffer buffer = ByteBuffer.wrap(pending.getBytes(CHARSET));
        socket.write(buffer);
        writeStorage.clear();
        if (buffer.hasRemaining()) {
            // One byte per char in ISO-8859-1, so the byte position is also the char index.
            writeStorage.append(pending.substring(buffer.position()));
        }
    }

    /**
     * Reads up to {@value #READ_BUFFER_SIZE} bytes from the key's channel.
     *
     * <p>The bytes are decoded as ISO-8859-1 and echoed to standard output.
     *
     * @param key selection key of the channel to read from
     * @return the decoded text (possibly empty), or {@code null} at end of stream
     * @throws IOException if the read fails
     */
    public String readMessage(SelectionKey key) throws IOException {
        SocketChannel socket = (SocketChannel) key.channel();
        ByteBuffer buf = ByteBuffer.allocate(READ_BUFFER_SIZE);
        if (socket.read(buf) == -1) {
            return null;
        }
        buf.flip();
        String result = CHARSET.decode(buf).toString();
        System.out.println(result);
        return result;
    }

    /** Closes the selector and both channels, attempting every close even if one fails. */
    private void closeAll() {
        if (selector != null) {
            try {
                selector.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        try {
            clientChannel.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        try {
            serverChannel.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

/**
 * Accumulates text queued to be written to one channel, tagged with a flag telling
 * which side of the tunnel the channel is.
 *
 * <p>Backed by an unsynchronized {@link StringBuilder}, so an instance must be confined
 * to a single thread.
 */
public class WriteStorage {
    private final StringBuilder toWrite = new StringBuilder();
    private boolean server;

    /**
     * Creates an empty storage.
     *
     * @param server the initial value reported by {@link #isServer()}
     */
    public WriteStorage(boolean server) {
        // Assign directly instead of calling setServer()/clear(). Calling overridable
        // methods from a constructor can run subclass code on a half-built object.
        this.server = server;
    }

    /** Returns the side flag set at construction or by {@link #setServer(boolean)}. */
    public boolean isServer() {
        return server;
    }

    /** Replaces the side flag. */
    public void setServer(boolean server) {
        this.server = server;
    }

    /** Discards all queued text. */
    public void clear() {
        toWrite.setLength(0);
    }

    /** Appends {@code s} to the queued text ({@code null} is appended as "null"). */
    public void append(String s) {
        toWrite.append(s);
    }

    /** Returns a snapshot of the queued text; later changes do not affect it. */
    public String getString() {
        return toWrite.toString();
    }
}