1 /*
2  * JBoss, Home of Professional Open Source.
3  * Copyright 2012 Red Hat, Inc., and individual contributors
4  * as indicated by the @author tags.
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */

18
19 package org.xnio.nio;
20
21 import static org.xnio.IoUtils.safeClose;
22 import static org.xnio.nio.Log.log;
23 import static org.xnio.nio.Log.tcpServerLog;
24
25 import java.io.IOException;
26 import java.net.InetSocketAddress;
27 import java.net.ServerSocket;
28 import java.net.Socket;
29 import java.net.SocketAddress;
30 import java.nio.channels.ClosedChannelException;
31 import java.nio.channels.SelectionKey;
32 import java.nio.channels.ServerSocketChannel;
33 import java.nio.channels.SocketChannel;
34 import java.util.Set;
35 import java.util.concurrent.CountDownLatch;
36 import java.util.concurrent.ThreadLocalRandom;
37 import java.util.concurrent.TimeUnit;
38 import java.util.concurrent.atomic.AtomicInteger;
39 import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
40 import java.util.concurrent.atomic.AtomicLongFieldUpdater;
41
42 import org.jboss.logging.Logger;
43 import org.xnio.ChannelListener;
44 import org.xnio.ManagementRegistration;
45 import org.xnio.IoUtils;
46 import org.xnio.LocalSocketAddress;
47 import org.xnio.Option;
48 import org.xnio.OptionMap;
49 import org.xnio.Options;
50 import org.xnio.StreamConnection;
51 import org.xnio.XnioExecutor;
52 import org.xnio.channels.AcceptListenerSettable;
53 import org.xnio.channels.AcceptingChannel;
54 import org.xnio.channels.UnsupportedOptionException;
55 import org.xnio.management.XnioServerMXBean;
56
57 final class NioTcpServer extends AbstractNioChannel<NioTcpServer> implements AcceptingChannel<StreamConnection>, AcceptListenerSettable<NioTcpServer> {
58     private static final String FQCN = NioTcpServer.class.getName();
59
60     private volatile ChannelListener<? super NioTcpServer> acceptListener;
61
62     private final NioTcpServerHandle[] handles;
63
64     private final ServerSocketChannel channel;
65     private final ServerSocket socket;
66     private final ManagementRegistration mbeanHandle;
67
68     private static final Set<Option<?>> options = Option.setBuilder()
69             .add(Options.REUSE_ADDRESSES)
70             .add(Options.RECEIVE_BUFFER)
71             .add(Options.SEND_BUFFER)
72             .add(Options.KEEP_ALIVE)
73             .add(Options.TCP_OOB_INLINE)
74             .add(Options.TCP_NODELAY)
75             .add(Options.CONNECTION_HIGH_WATER)
76             .add(Options.CONNECTION_LOW_WATER)
77             .add(Options.READ_TIMEOUT)
78             .add(Options.WRITE_TIMEOUT)
79             .create();
80
81     @SuppressWarnings("unused")
82     private volatile int keepAlive;
83     @SuppressWarnings("unused")
84     private volatile int oobInline;
85     @SuppressWarnings("unused")
86     private volatile int tcpNoDelay;
87     @SuppressWarnings("unused")
88     private volatile int sendBuffer = -1;
89     @SuppressWarnings("unused")
90     private volatile long connectionStatus = CONN_LOW_MASK | CONN_HIGH_MASK;
91     @SuppressWarnings("unused")
92     private volatile int readTimeout;
93     @SuppressWarnings("unused")
94     private volatile int writeTimeout;
95     private volatile int tokenConnectionCount;
96     volatile boolean resumed;
97
98     private static final long CONN_LOW_MASK     = 0x000000007FFFFFFFL;
99     private static final long CONN_LOW_BIT      = 0L;
100     @SuppressWarnings("unused")
101     private static final long CONN_LOW_ONE      = 1L;
102     private static final long CONN_HIGH_MASK    = 0x3FFFFFFF80000000L;
103     private static final long CONN_HIGH_BIT     = 31L;
104     @SuppressWarnings("unused")
105     private static final long CONN_HIGH_ONE     = 1L << CONN_HIGH_BIT;
106
107     private static final AtomicIntegerFieldUpdater<NioTcpServer> keepAliveUpdater = AtomicIntegerFieldUpdater.newUpdater(NioTcpServer.class"keepAlive");
108     private static final AtomicIntegerFieldUpdater<NioTcpServer> oobInlineUpdater = AtomicIntegerFieldUpdater.newUpdater(NioTcpServer.class"oobInline");
109     private static final AtomicIntegerFieldUpdater<NioTcpServer> tcpNoDelayUpdater = AtomicIntegerFieldUpdater.newUpdater(NioTcpServer.class"tcpNoDelay");
110     private static final AtomicIntegerFieldUpdater<NioTcpServer> sendBufferUpdater = AtomicIntegerFieldUpdater.newUpdater(NioTcpServer.class"sendBuffer");
111     private static final AtomicIntegerFieldUpdater<NioTcpServer> readTimeoutUpdater = AtomicIntegerFieldUpdater.newUpdater(NioTcpServer.class"readTimeout");
112     private static final AtomicIntegerFieldUpdater<NioTcpServer> writeTimeoutUpdater = AtomicIntegerFieldUpdater.newUpdater(NioTcpServer.class"writeTimeout");
113
114     private static final AtomicLongFieldUpdater<NioTcpServer> connectionStatusUpdater = AtomicLongFieldUpdater.newUpdater(NioTcpServer.class"connectionStatus");
115
116     NioTcpServer(final NioXnioWorker worker, final ServerSocketChannel channel, final OptionMap optionMap, final boolean useAcceptThreadOnly) throws IOException {
117         super(worker);
118         this.channel = channel;
119         final WorkerThread[] threads;
120         final int threadCount;
121         final int tokens;
122         final int connections;
123         if (useAcceptThreadOnly) {
124             threads = new WorkerThread[] { worker.getAcceptThread() };
125             threadCount = 1;
126             tokens = 0;
127             connections = 0;
128         } else {
129             threads = worker.getAll();
130             threadCount = threads.length;
131             if (threadCount == 0) {
132                 throw log.noThreads();
133             }
134             tokens = optionMap.get(Options.BALANCING_TOKENS, -1);
135             connections = optionMap.get(Options.BALANCING_CONNECTIONS, 16);
136             if (tokens != -1) {
137                 if (tokens < 1 || tokens >= threadCount) {
138                     throw log.balancingTokens();
139                 }
140                 if (connections < 1) {
141                     throw log.balancingConnectionCount();
142                 }
143                 tokenConnectionCount = connections;
144             }
145         }
146         socket = channel.socket();
147         if (optionMap.contains(Options.SEND_BUFFER)) {
148             final int sendBufferSize = optionMap.get(Options.SEND_BUFFER, DEFAULT_BUFFER_SIZE);
149             if (sendBufferSize < 1) {
150                 throw log.parameterOutOfRange("sendBufferSize");
151             }
152             sendBufferUpdater.set(this, sendBufferSize);
153         }
154         if (optionMap.contains(Options.KEEP_ALIVE)) {
155             keepAliveUpdater.lazySet(this, optionMap.get(Options.KEEP_ALIVE, false) ? 1 : 0);
156         }
157         if (optionMap.contains(Options.TCP_OOB_INLINE)) {
158             oobInlineUpdater.lazySet(this, optionMap.get(Options.TCP_OOB_INLINE, false) ? 1 : 0);
159         }
160         if (optionMap.contains(Options.TCP_NODELAY)) {
161             tcpNoDelayUpdater.lazySet(this, optionMap.get(Options.TCP_NODELAY, false) ? 1 : 0);
162         }
163         if (optionMap.contains(Options.READ_TIMEOUT)) {
164             readTimeoutUpdater.lazySet(this, optionMap.get(Options.READ_TIMEOUT, 0));
165         }
166         if (optionMap.contains(Options.WRITE_TIMEOUT)) {
167             writeTimeoutUpdater.lazySet(this, optionMap.get(Options.WRITE_TIMEOUT, 0));
168         }
169         int perThreadLow, perThreadLowRem;
170         int perThreadHigh, perThreadHighRem;
171         if (optionMap.contains(Options.CONNECTION_HIGH_WATER) || optionMap.contains(Options.CONNECTION_LOW_WATER)) {
172             final int highWater = optionMap.get(Options.CONNECTION_HIGH_WATER, Integer.MAX_VALUE);
173             final int lowWater = optionMap.get(Options.CONNECTION_LOW_WATER, highWater);
174             if (highWater <= 0) {
175                 throw badHighWater();
176             }
177             if (lowWater <= 0 || lowWater > highWater) {
178                 throw badLowWater(highWater);
179             }
180             final long highLowWater = (long) highWater << CONN_HIGH_BIT | (long) lowWater << CONN_LOW_BIT;
181             connectionStatusUpdater.lazySet(this, highLowWater);
182             perThreadLow = lowWater / threadCount;
183             perThreadLowRem = lowWater % threadCount;
184             perThreadHigh = highWater / threadCount;
185             perThreadHighRem = highWater % threadCount;
186         } else {
187             perThreadLow = Integer.MAX_VALUE;
188             perThreadLowRem = 0;
189             perThreadHigh = Integer.MAX_VALUE;
190             perThreadHighRem = 0;
191             connectionStatusUpdater.lazySet(this, CONN_LOW_MASK | CONN_HIGH_MASK);
192         }
193         final NioTcpServerHandle[] handles = new NioTcpServerHandle[threadCount];
194         for (int i = 0, length = threadCount; i < length; i++) {
195             final SelectionKey key = threads[i].registerChannel(channel);
196             handles[i] = new NioTcpServerHandle(this, key, threads[i], i < perThreadHighRem ? perThreadHigh + 1 : perThreadHigh, i < perThreadLowRem ? perThreadLow + 1 : perThreadLow);
197             key.attach(handles[i]);
198         }
199         this.handles = handles;
200         if (tokens > 0) {
201             for (int i = 0; i < threadCount; i ++) {
202                 handles[i].initializeTokenCount(i < tokens ? connections : 0);
203             }
204         }
205         mbeanHandle = worker.registerServerMXBean(
206                 new XnioServerMXBean() {
207                     public String getProviderName() {
208                         return "nio";
209                     }
210
211                     public String getWorkerName() {
212                         return worker.getName();
213                     }
214
215                     public String getBindAddress() {
216                         return String.valueOf(getLocalAddress());
217                     }
218
219                     public int getConnectionCount() {
220                         final AtomicInteger counter = new AtomicInteger();
221                         final CountDownLatch latch = new CountDownLatch(handles.length);
222                         for (final NioTcpServerHandle handle : handles) {
223                             handle.getWorkerThread().execute(() -> {
224                                 counter.getAndAdd(handle.getConnectionCount());
225                                 latch.countDown();
226                             });
227                         }
228                         try {
229                             latch.await();
230                         } catch (InterruptedException e) {
231                             Thread.currentThread().interrupt();
232                         }
233                         return counter.get();
234                     }
235
236                     public int getConnectionLimitHighWater() {
237                         return getHighWater(connectionStatus);
238                     }
239
240                     public int getConnectionLimitLowWater() {
241                         return getLowWater(connectionStatus);
242                     }
243                 }
244         );
245
246     }
247
248     private static IllegalArgumentException badLowWater(final int highWater) {
249         return new IllegalArgumentException("Low water must be greater than 0 and less than or equal to high water (" + highWater + ")");
250     }
251
252     private static IllegalArgumentException badHighWater() {
253         return new IllegalArgumentException("High water must be greater than 0");
254     }
255
256     public void close() throws IOException {
257         try {
258             channel.close();
259         } finally {
260             for (NioTcpServerHandle handle : handles) {
261                 handle.cancelKey(false);
262             }
263             safeClose(mbeanHandle);
264         }
265     }
266
267     public boolean supportsOption(final Option<?> option) {
268         return options.contains(option);
269     }
270
271     public <T> T getOption(final Option<T> option) throws UnsupportedOptionException, IOException {
272         if (option == Options.REUSE_ADDRESSES) {
273             return option.cast(Boolean.valueOf(socket.getReuseAddress()));
274         } else if (option == Options.RECEIVE_BUFFER) {
275             return option.cast(Integer.valueOf(socket.getReceiveBufferSize()));
276         } else if (option == Options.SEND_BUFFER) {
277             final int value = sendBuffer;
278             return value == -1 ? null : option.cast(Integer.valueOf(value));
279         } else if (option == Options.KEEP_ALIVE) {
280             return option.cast(Boolean.valueOf(keepAlive != 0));
281         } else if (option == Options.TCP_OOB_INLINE) {
282             return option.cast(Boolean.valueOf(oobInline != 0));
283         } else if (option == Options.TCP_NODELAY) {
284             return option.cast(Boolean.valueOf(tcpNoDelay != 0));
285         } else if (option == Options.READ_TIMEOUT) {
286             return option.cast(Integer.valueOf(readTimeout));
287         } else if (option == Options.WRITE_TIMEOUT) {
288             return option.cast(Integer.valueOf(writeTimeout));
289         } else if (option == Options.CONNECTION_HIGH_WATER) {
290             return option.cast(Integer.valueOf(getHighWater(connectionStatus)));
291         } else if (option == Options.CONNECTION_LOW_WATER) {
292             return option.cast(Integer.valueOf(getLowWater(connectionStatus)));
293         } else {
294             return null;
295         }
296     }
297
298     public <T> T setOption(final Option<T> option, final T value) throws IllegalArgumentException, IOException {
299         final Object old;
300         if (option == Options.REUSE_ADDRESSES) {
301             old = Boolean.valueOf(socket.getReuseAddress());
302             socket.setReuseAddress(Options.REUSE_ADDRESSES.cast(value, Boolean.FALSE).booleanValue());
303         } else if (option == Options.RECEIVE_BUFFER) { 
304             old = Integer.valueOf(socket.getReceiveBufferSize());
305             final int newValue = Options.RECEIVE_BUFFER.cast(value, Integer.valueOf(DEFAULT_BUFFER_SIZE)).intValue();
306             if (newValue < 1) {
307                 throw log.optionOutOfRange("RECEIVE_BUFFER");
308             }
309             socket.setReceiveBufferSize(newValue);
310         } else if (option == Options.SEND_BUFFER) {
311             final int newValue = Options.SEND_BUFFER.cast(value, Integer.valueOf(DEFAULT_BUFFER_SIZE)).intValue();
312             if (newValue < 1) {
313                 throw log.optionOutOfRange("SEND_BUFFER");
314             }
315             final int oldValue = sendBufferUpdater.getAndSet(this, newValue);
316             old = oldValue == -1 ? null : Integer.valueOf(oldValue);
317         } else if (option == Options.KEEP_ALIVE) {
318             old = Boolean.valueOf(keepAliveUpdater.getAndSet(this, Options.KEEP_ALIVE.cast(value, Boolean.FALSE).booleanValue() ? 1 : 0) != 0);
319         } else if (option == Options.TCP_OOB_INLINE) {
320             old = Boolean.valueOf(oobInlineUpdater.getAndSet(this, Options.TCP_OOB_INLINE.cast(value, Boolean.FALSE).booleanValue() ? 1 : 0) != 0);
321         } else if (option == Options.TCP_NODELAY) {
322             old = Boolean.valueOf(tcpNoDelayUpdater.getAndSet(this, Options.TCP_NODELAY.cast(value, Boolean.FALSE).booleanValue() ? 1 : 0) != 0);
323         } else if (option == Options.READ_TIMEOUT) {
324             old = Integer.valueOf(readTimeoutUpdater.getAndSet(this, Options.READ_TIMEOUT.cast(value, Integer.valueOf(0)).intValue()));
325         } else if (option == Options.WRITE_TIMEOUT) {
326             old = Integer.valueOf(writeTimeoutUpdater.getAndSet(this, Options.WRITE_TIMEOUT.cast(value, Integer.valueOf(0)).intValue()));
327         } else if (option == Options.CONNECTION_HIGH_WATER) {
328             old = Integer.valueOf(getHighWater(updateWaterMark(-1, Options.CONNECTION_HIGH_WATER.cast(value, Integer.valueOf(Integer.MAX_VALUE)).intValue())));
329         } else if (option == Options.CONNECTION_LOW_WATER) {
330             old = Integer.valueOf(getLowWater(updateWaterMark(Options.CONNECTION_LOW_WATER.cast(value, Integer.valueOf(Integer.MAX_VALUE)).intValue(), -1)));
331         } else {
332             return null;
333         }
334         return option.cast(old);
335     }
336
337     private long updateWaterMark(int reqNewLowWater, int reqNewHighWater) {
338         // at least one must be specified
339         assert reqNewLowWater != -1 || reqNewHighWater != -1;
340         // if both given, low must be less than high
341         assert reqNewLowWater == -1 || reqNewHighWater == -1 || reqNewLowWater <= reqNewHighWater;
342
343         long oldVal, newVal;
344         int oldHighWater, oldLowWater;
345         int newLowWater, newHighWater;
346
347         do {
348             oldVal = connectionStatus;
349             oldLowWater = getLowWater(oldVal);
350             oldHighWater = getHighWater(oldVal);
351             newLowWater = reqNewLowWater == -1 ? oldLowWater : reqNewLowWater;
352             newHighWater = reqNewHighWater == -1 ? oldHighWater : reqNewHighWater;
353             // Make sure the new values make sense
354             if (reqNewLowWater != -1 && newLowWater > newHighWater) {
355                 newHighWater = newLowWater;
356             } else if (reqNewHighWater != -1 && newHighWater < newLowWater) {
357                 newLowWater = newHighWater;
358             }
359             // See if the change would be redundant
360             if (oldLowWater == newLowWater && oldHighWater == newHighWater) {
361                 return oldVal;
362             }
363             newVal = (long)newLowWater << CONN_LOW_BIT | (long)newHighWater << CONN_HIGH_BIT;
364         } while (! connectionStatusUpdater.compareAndSet(this, oldVal, newVal));
365
366         final NioTcpServerHandle[] conduits = handles;
367         final int threadCount = conduits.length;
368
369         int perThreadLow, perThreadLowRem;
370         int perThreadHigh, perThreadHighRem;
371
372         perThreadLow = newLowWater / threadCount;
373         perThreadLowRem = newLowWater % threadCount;
374         perThreadHigh = newHighWater / threadCount;
375         perThreadHighRem = newHighWater % threadCount;
376
377         for (int i = 0; i < conduits.length; i++) {
378             NioTcpServerHandle conduit = conduits[i];
379             conduit.executeSetTask(i < perThreadHighRem ? perThreadHigh + 1 : perThreadHigh, i < perThreadLowRem ? perThreadLow + 1 : perThreadLow);
380         }
381
382         return oldVal;
383     }
384
385     private static int getHighWater(final long value) {
386         return (int) ((value & CONN_HIGH_MASK) >> CONN_HIGH_BIT);
387     }
388
389     private static int getLowWater(final long value) {
390         return (int) ((value & CONN_LOW_MASK) >> CONN_LOW_BIT);
391     }
392
393     public NioSocketStreamConnection accept() throws ClosedChannelException {
394         final WorkerThread current = WorkerThread.getCurrent();
395         if (current == null) {
396             return null;
397         }
398         final NioTcpServerHandle handle;
399         if (handles.length == 1) {
400             handle = handles[0];
401         } else {
402             handle = handles[current.getNumber()];
403         }
404         if (! handle.getConnection()) {
405             return null;
406         }
407         final SocketChannel accepted;
408         boolean ok = false;
409         try {
410             accepted = channel.accept();
411             if (accepted != nulltry {
412                 int hash = ThreadLocalRandom.current().nextInt();
413                 accepted.configureBlocking(false);
414                 final Socket socket = accepted.socket();
415                 socket.setKeepAlive(keepAlive != 0);
416                 socket.setOOBInline(oobInline != 0);
417                 socket.setTcpNoDelay(tcpNoDelay != 0);
418                 final int sendBuffer = this.sendBuffer;
419                 if (sendBuffer > 0) socket.setSendBufferSize(sendBuffer);
420                 final WorkerThread ioThread = worker.getIoThread(hash);
421                 final SelectionKey selectionKey = ioThread.registerChannel(accepted);
422                 final NioSocketStreamConnection newConnection = new NioSocketStreamConnection(ioThread, selectionKey, handle);
423                 newConnection.setOption(Options.READ_TIMEOUT, Integer.valueOf(readTimeout));
424                 newConnection.setOption(Options.WRITE_TIMEOUT, Integer.valueOf(writeTimeout));
425                 ok = true;
426                 handle.resetBackOff();
427                 return newConnection;
428             } finally {
429                 if (! ok) safeClose(accepted);
430             }
431         } catch (ClosedChannelException e) {
432             throw e;
433         } catch (IOException e) {
434             // something went wrong with the accept
435             // it could be due to running out of file descriptors, or due to closed channel, or other things
436             handle.startBackOff();
437             log.acceptFailed(e, handle.getBackOffTime());
438             return null;
439         } finally {
440             if (! ok) {
441                 handle.freeConnection();
442             }
443         }
444         // by contract, only a resume will do
445         return null;
446     }
447
448     public String toString() {
449         return String.format("TCP server (NIO) <%s>", Integer.toHexString(hashCode()));
450     }
451
452     public ChannelListener<? super NioTcpServer> getAcceptListener() {
453         return acceptListener;
454     }
455
456     public void setAcceptListener(final ChannelListener<? super NioTcpServer> acceptListener) {
457         this.acceptListener = acceptListener;
458     }
459
460     public ChannelListener.Setter<NioTcpServer> getAcceptSetter() {
461         return new AcceptListenerSettable.Setter<NioTcpServer>(this);
462     }
463
464     public boolean isOpen() {
465         return channel.isOpen();
466     }
467
468     public SocketAddress getLocalAddress() {
469         return socket.getLocalSocketAddress();
470     }
471
472     public <A extends SocketAddress> A getLocalAddress(final Class<A> type) {
473         final SocketAddress address = getLocalAddress();
474         return type.isInstance(address) ? type.cast(address) : null;
475     }
476
477     public void suspendAccepts() {
478         resumed = false;
479         doResume(0);
480     }
481
482     public void resumeAccepts() {
483         resumed = true;
484         doResume(SelectionKey.OP_ACCEPT);
485     }
486
487     public boolean isAcceptResumed() {
488         return resumed;
489     }
490
491     private void doResume(final int op) {
492         if (op == 0) {
493             for (NioTcpServerHandle handle : handles) {
494                 handle.suspend();
495             }
496         } else {
497             for (NioTcpServerHandle handle : handles) {
498                 handle.resume();
499             }
500         }
501     }
502
503     public void wakeupAccepts() {
504         tcpServerLog.logf(FQCN, Logger.Level.TRACE, null"Wake up accepts on %s"this);
505         resumeAccepts();
506         final NioTcpServerHandle[] handles = this.handles;
507         final int idx = IoUtils.getThreadLocalRandom().nextInt(handles.length);
508         handles[idx].wakeup(SelectionKey.OP_ACCEPT);
509     }
510
511     public void awaitAcceptable() throws IOException {
512         throw log.unsupported("awaitAcceptable");
513     }
514
515     public void awaitAcceptable(final long time, final TimeUnit timeUnit) throws IOException {
516         throw log.unsupported("awaitAcceptable");
517     }
518
519     @Deprecated
520     public XnioExecutor getAcceptThread() {
521         return getIoThread();
522     }
523
524     NioTcpServerHandle getHandle(final int number) {
525         return handles[number];
526     }
527
528     int getTokenConnectionCount() {
529         return tokenConnectionCount;
530     }
531 }
532