import org.apache.thrift.protocol.TCompactProtocol; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.transport.TFramedTransport; import org.apache.thrift.transport.TSocket; import org.apache.thrift.transport.TTransport; import org.glassfish.grizzly.Connection; import org.glassfish.grizzly.filterchain.FilterChainBuilder; import org.glassfish.grizzly.filterchain.TransportFilter; import org.glassfish.grizzly.nio.transport.TCPNIOTransport; import org.glassfish.grizzly.nio.transport.TCPNIOTransportBuilder; import org.glassfish.grizzly.strategies.SameThreadIOStrategy; import org.glassfish.grizzly.thrift.TGrizzlyClientTransport; import org.glassfish.grizzly.thrift.ThriftClientFilter; import org.glassfish.grizzly.thrift.ThriftFrameFilter; import org.jboss.netty.bootstrap.ClientBootstrap; import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory; import se.cgbystrom.netty.thrift.TNettyTransport; import se.cgbystrom.netty.thrift.ThriftClientHandler; import se.cgbystrom.netty.thrift.ThriftPipelineFactory; import shared.SharedStruct; import tutorial.Calculator; import tutorial.InvalidOperation; import tutorial.Operation; import tutorial.Work; import java.net.InetSocketAddress; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; /** * @author Bongjae Chang */ public class ThriftClientBenchmark { private final String targetHost; private final int targetPort; private final String serverType; private final long warmingupSecond; private final long runningSecond; private final int clientThreadCount; private final CountDownLatch allConnected; private final CountDownLatch allClientFinished; private final CountDownLatch startSignal; private final AtomicLong result = new AtomicLong(); private final AtomicInteger exception = new AtomicInteger(); private final RunnerThread[] threads; // for 3M //private final static int size = 4 * 1024 * 1024; // for 3K //private final static int size = 4 * 1024; // for 300Bytes private final static int size = 4 * 102; private volatile boolean benchmark; private volatile boolean connectFailure; private enum ServerType { TSERVER, TTHREADPOOLSERVER, TNONBLOCKINGSERVER, NETTY, GRIZZLY } private ThriftClientBenchmark(Builder builder) { targetHost = builder.targetHost; targetPort = builder.targetPort; serverType = builder.serverType; warmingupSecond = builder.warmingupSecond; runningSecond = builder.runningSecond; clientThreadCount = builder.clientThreadCount; allConnected = new CountDownLatch(clientThreadCount); allClientFinished = new CountDownLatch(clientThreadCount); startSignal = new CountDownLatch(1); threads = new RunnerThread[clientThreadCount]; } public void startBenchmark() throws Exception { for (int i = 0; i < clientThreadCount; i++) { final ThriftClient thriftClient = getThriftClient(serverType); threads[i] = new RunnerThread(thriftClient); threads[i].start(); } allConnected.await(); if (connectFailure) { throw new Exception("failed to have all connections"); } System.out.println("all client are connected."); startSignal.countDown(); System.out.println("wait for warming-up."); Thread.sleep(warmingupSecond * 1000); benchmark = true; System.out.println("start benchmarking."); Thread.sleep(runningSecond * 1000); } public void stopBenchmark() throws Exception { benchmark = false; System.out.println("stop benchmarking."); for (int i = 0; i < clientThreadCount; i++) { threads[i].stopThread(); } allClientFinished.await(); System.out.println("** result = " + result.get() + ", exceptions = " + exception.get() + " in " + runningSecond + "sec"); } private ThriftClient getThriftClient(String serverType) { ThriftClient thriftClient; if (serverType == null) { throw new IllegalArgumentException(); } final String serverTypeUpperCase = serverType.toUpperCase(); if (serverTypeUpperCase.equals(ServerType.TSERVER.toString())) { thriftClient = new TThriftClient(); } else if (serverTypeUpperCase.equals(ServerType.TTHREADPOOLSERVER.toString())) { thriftClient = new TThriftClient(); } else if (serverTypeUpperCase.equals(ServerType.TNONBLOCKINGSERVER.toString())) { thriftClient = new TThriftClient(); } else if (serverTypeUpperCase.equals(ServerType.NETTY.toString())) { thriftClient = new NettyThriftClient(); } else if (serverTypeUpperCase.equals(ServerType.GRIZZLY.toString())) { thriftClient = new GrizzlyThriftClient(); } else { throw new IllegalArgumentException("unknown server type"); } return thriftClient; } private class RunnerThread extends Thread { private final ThriftClient thriftClient; private volatile boolean running; private RunnerThread(ThriftClient thriftClient) { this.thriftClient = thriftClient; } @Override public void run() { try { thriftClient.connect(); } catch (Exception e) { connectFailure = true; allConnected.countDown(); allClientFinished.countDown(); return; } allConnected.countDown(); try { startSignal.await(); } catch (InterruptedException e) { Thread.interrupted(); try { thriftClient.close(); } catch (Exception ignore) { } allClientFinished.countDown(); return; } running = true; while (running) { try { thriftClient.perform(); } catch (Exception e) { if (benchmark) { exception.incrementAndGet(); } continue; } if (benchmark) { result.incrementAndGet(); } } try { thriftClient.close(); } catch (Exception ignore) { } allClientFinished.countDown(); } private void stopThread() { running = false; } } private interface ThriftClient { public void connect() throws Exception; public void perform() throws Exception; public void close() throws Exception; } private abstract class AbstractThriftClient implements ThriftClient { protected TTransport ttransport; @Override public void perform() throws Exception { // for BinaryProtocol //Calculator.Client client = new Calculator.Client(new TBinaryProtocol(ttransport)); // for CompactProtocol Calculator.Client client = new Calculator.Client(new TCompactProtocol(ttransport)); // default operations client.ping(); int sum = client.add(1, 1); if (sum != 2) { throw new InvalidOperation(); } Work work = new Work(); work.op = Operation.DIVIDE; work.num1 = 1; work.num2 = 0; try { int quotient = client.calculate(1, work); throw new InvalidOperation(); } catch (InvalidOperation ignored) { } work.op = Operation.SUBTRACT; work.num1 = 15; work.num2 = 10; int diff = client.calculate(1, work); if (diff != 5) { throw new InvalidOperation(); } // get large packets SharedStruct log = client.getStruct(1); if (log.value == null) { throw new InvalidOperation(); } } @Override public void close() throws Exception { if (ttransport != null) { ttransport.close(); ttransport = null; } } } private class GrizzlyThriftClient extends AbstractThriftClient { private TCPNIOTransport transport; @Override public void connect() throws Exception { if (transport == null) { final FilterChainBuilder clientFilterChainBuilder = FilterChainBuilder.stateless(); final ThriftClientFilter clientFilter = new ThriftClientFilter(); clientFilterChainBuilder.add(new TransportFilter()).add(new ThriftFrameFilter(size)).add(clientFilter); transport = TCPNIOTransportBuilder.newInstance().build(); transport.setProcessor(clientFilterChainBuilder.build()); transport.setIOStrategy(SameThreadIOStrategy.getInstance()); transport.start(); final Future future = transport.connect(targetHost, targetPort); final Connection connection = future.get(10, TimeUnit.SECONDS); ttransport = new TGrizzlyClientTransport(connection, clientFilter); } } @Override public void close() throws Exception { super.close(); if (transport != null) { transport.stop(); transport = null; } } } private class NettyThriftClient extends AbstractThriftClient { private Channel channel; @Override public void connect() throws Exception { if (channel == null) { final ClientBootstrap clientBootstrap = new ClientBootstrap(new NioClientSocketChannelFactory(Executors.newCachedThreadPool(), Executors.newCachedThreadPool())); final ThriftClientHandler clientHandler = new ThriftClientHandler(); clientBootstrap.setPipelineFactory(new ThriftPipelineFactory(clientHandler, size)); channel = clientBootstrap.connect(new InetSocketAddress(targetHost, targetPort)).awaitUninterruptibly().getChannel(); ttransport = new TNettyTransport(channel, clientHandler); } } @Override public void close() throws Exception { super.close(); if (channel != null) { channel.close().awaitUninterruptibly(); channel = null; } } } private class TThriftClient extends AbstractThriftClient { private TSocket tsocket; @Override public void connect() throws Exception { if (tsocket == null) { tsocket = new TSocket(targetHost, targetPort); ttransport = new TFramedTransport(tsocket); ttransport.open(); } } @Override public void close() throws Exception { super.close(); if (tsocket != null) { tsocket.close(); tsocket = null; } } } public static class Builder { private String targetHost = "localhost"; private int targetPort = 9090; private String serverType = "grizzly"; private long warmingupSecond = 60; private long runningSecond = 60 * 2; private int clientThreadCount = 10; public Builder targetHost(String targetHost) { this.targetHost = targetHost; return this; } public Builder targetPort(int targetPort) { this.targetPort = targetPort; return this; } public Builder serverType(String serverType) { this.serverType = serverType; return this; } public Builder warmingupSecond(long warmingupSecond) { this.warmingupSecond = warmingupSecond; return this; } public Builder runningSecond(long runningSecond) { this.runningSecond = runningSecond; return this; } public Builder clientThreadCount(int clientThreadCount) { this.clientThreadCount = clientThreadCount; return this; } public ThriftClientBenchmark build() { return new ThriftClientBenchmark(this); } } public static void main(String[] args) { final String targetHost; final int targetPort; final String serverType; final long warmingupSecond; final long runningSecond; final int clientThreadCount; if (args.length != 6) { printUsages(); return; } targetHost = args[0]; try { targetPort = Integer.parseInt(args[1]); } catch (NumberFormatException e) { printUsages(); return; } serverType = args[2]; try { warmingupSecond = Long.parseLong(args[3]); } catch (NumberFormatException e) { printUsages(); return; } try { runningSecond = Long.parseLong(args[4]); } catch (NumberFormatException e) { printUsages(); return; } try { clientThreadCount = Integer.parseInt(args[5]); } catch (NumberFormatException e) { printUsages(); return; } ThriftClientBenchmark test = new ThriftClientBenchmark.Builder() .targetHost(targetHost) .targetPort(targetPort) .serverType(serverType) .warmingupSecond(warmingupSecond) .runningSecond(runningSecond) .clientThreadCount(clientThreadCount) .build(); try { test.startBenchmark(); } catch (Exception e) { e.printStackTrace(); } try { test.stopBenchmark(); } catch (Exception ignore) { } } private static void printUsages() { System.out.println("please enter target host, target port, server type, warmingupSecond, runningSecond and clientThreadCount.\n"); System.out.println("server type: \"tserver\" or \"tthreadpoolserver\" or \"tnonblockingserver\" or \"netty\" or \"grizzly\""); System.out.println("ex) java -cp . ThriftClientBenchmark localhost 9090 grizzly 60 120 10"); } }