1
16 package io.netty.channel.embedded;
17
18 import java.net.SocketAddress;
19 import java.nio.channels.ClosedChannelException;
20 import java.util.ArrayDeque;
21 import java.util.Queue;
22
23 import io.netty.channel.AbstractChannel;
24 import io.netty.channel.Channel;
25 import io.netty.channel.ChannelConfig;
26 import io.netty.channel.ChannelFuture;
27 import io.netty.channel.ChannelFutureListener;
28 import io.netty.channel.ChannelHandler;
29 import io.netty.channel.ChannelHandlerContext;
30 import io.netty.channel.ChannelId;
31 import io.netty.channel.ChannelInitializer;
32 import io.netty.channel.ChannelMetadata;
33 import io.netty.channel.ChannelOutboundBuffer;
34 import io.netty.channel.ChannelPipeline;
35 import io.netty.channel.ChannelPromise;
36 import io.netty.channel.DefaultChannelConfig;
37 import io.netty.channel.DefaultChannelPipeline;
38 import io.netty.channel.EventLoop;
39 import io.netty.channel.RecvByteBufAllocator;
40 import io.netty.util.ReferenceCountUtil;
41 import io.netty.util.internal.ObjectUtil;
42 import io.netty.util.internal.PlatformDependent;
43 import io.netty.util.internal.RecyclableArrayList;
44 import io.netty.util.internal.logging.InternalLogger;
45 import io.netty.util.internal.logging.InternalLoggerFactory;
46
47
50 public class EmbeddedChannel extends AbstractChannel {
51
// Placeholder addresses returned by localAddress0() / remoteAddress0() while the channel is active.
private static final SocketAddress LOCAL_ADDRESS = new EmbeddedSocketAddress();
private static final SocketAddress REMOTE_ADDRESS = new EmbeddedSocketAddress();

// Shared empty handler array used by the constructors that take no handlers.
private static final ChannelHandler[] EMPTY_HANDLERS = new ChannelHandler[0];
// Lifecycle marker. Within this class only ACTIVE (doRegister) and CLOSED (doClose) are ever assigned.
private enum State { OPEN, ACTIVE, CLOSED }

private static final InternalLogger logger = InternalLoggerFactory.getInstance(EmbeddedChannel.class);

// The two metadata variants; the hasDisconnect constructor flag picks one of them.
private static final ChannelMetadata METADATA_NO_DISCONNECT = new ChannelMetadata(false);
private static final ChannelMetadata METADATA_DISCONNECT = new ChannelMetadata(true);

// Event loop this channel registers with and whose tasks runPendingTasks() executes.
private final EmbeddedEventLoop loop = new EmbeddedEventLoop();
// Attached by writeOutbound() to write futures that are not done yet, so a later failure is still recorded.
private final ChannelFutureListener recordExceptionListener = new ChannelFutureListener() {
    @Override
    public void operationComplete(ChannelFuture future) throws Exception {
        recordException(future);
    }
};

private final ChannelMetadata metadata;
private final ChannelConfig config;

// Both queues are created lazily by inboundMessages() / outboundMessages() and are null until then.
private Queue<Object> inboundMessages;
private Queue<Object> outboundMessages;
// First exception recorded since the last checkException(...); later ones are only logged.
private Throwable lastException;
// Stays null until doRegister() runs; isOpen() treats every value other than CLOSED as open.
private State state;
78
79
/**
 * Create a new instance with {@link EmbeddedChannelId#INSTANCE} as id and no user-supplied handlers.
 */
public EmbeddedChannel() {
    this(EMPTY_HANDLERS);
}
85
86
/**
 * Create a new instance with the specified id and no user-supplied handlers.
 *
 * @param channelId the {@link ChannelId} that will be used to identify this channel
 */
public EmbeddedChannel(ChannelId channelId) {
    this(channelId, EMPTY_HANDLERS);
}
94
95
/**
 * Create a new instance with {@link EmbeddedChannelId#INSTANCE} as id and the given handlers.
 *
 * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline}
 */
public EmbeddedChannel(ChannelHandler... handlers) {
    this(EmbeddedChannelId.INSTANCE, handlers);
}
103
104
/**
 * Create a new instance with {@link EmbeddedChannelId#INSTANCE} as id and the given handlers.
 *
 * @param hasDisconnect {@code false} if a disconnect on this channel should also close it
 *                      (see {@code doDisconnect()}), {@code true} otherwise
 * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline}
 */
public EmbeddedChannel(boolean hasDisconnect, ChannelHandler... handlers) {
    this(EmbeddedChannelId.INSTANCE, hasDisconnect, handlers);
}
114
115
/**
 * Create a new instance with {@link EmbeddedChannelId#INSTANCE} as id and the given handlers.
 *
 * @param register {@code true} to register this channel with its event loop during construction;
 *                 {@code false} leaves registration to a later {@link #register()} call
 * @param hasDisconnect {@code false} if a disconnect on this channel should also close it
 *                      (see {@code doDisconnect()}), {@code true} otherwise
 * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline}
 */
public EmbeddedChannel(boolean register, boolean hasDisconnect, ChannelHandler... handlers) {
    this(EmbeddedChannelId.INSTANCE, register, hasDisconnect, handlers);
}
127
128
/**
 * Create a new instance with the given id and handlers; {@code hasDisconnect} is {@code false}.
 *
 * @param channelId the {@link ChannelId} that will be used to identify this channel
 * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline}
 */
public EmbeddedChannel(ChannelId channelId, ChannelHandler... handlers) {
    this(channelId, false, handlers);
}
138
139
/**
 * Create a new instance with the given id and handlers; the channel is registered during construction.
 *
 * @param channelId the {@link ChannelId} that will be used to identify this channel
 * @param hasDisconnect {@code false} if a disconnect on this channel should also close it
 *                      (see {@code doDisconnect()}), {@code true} otherwise
 * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline}
 */
public EmbeddedChannel(ChannelId channelId, boolean hasDisconnect, ChannelHandler... handlers) {
    this(channelId, true, hasDisconnect, handlers);
}
151
152
/**
 * Create a new instance with the given id and handlers and no parent channel.
 *
 * @param channelId the {@link ChannelId} that will be used to identify this channel
 * @param register {@code true} to register this channel with its event loop during construction;
 *                 {@code false} leaves registration to a later {@link #register()} call
 * @param hasDisconnect {@code false} if a disconnect on this channel should also close it
 *                      (see {@code doDisconnect()}), {@code true} otherwise
 * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline}
 */
public EmbeddedChannel(ChannelId channelId, boolean register, boolean hasDisconnect,
                       ChannelHandler... handlers) {
    this(null, channelId, register, hasDisconnect, handlers);
}
167
168
/**
 * Create a new instance with the given parent, id and handlers, using a {@link DefaultChannelConfig}.
 *
 * @param parent the parent {@link Channel} passed to the superclass constructor (may be {@code null},
 *               as the delegating constructors do)
 * @param channelId the {@link ChannelId} that will be used to identify this channel
 * @param register {@code true} to register this channel with its event loop during construction;
 *                 {@code false} leaves registration to a later {@link #register()} call
 * @param hasDisconnect {@code false} if a disconnect on this channel should also close it
 *                      (see {@code doDisconnect()}), {@code true} otherwise
 * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline}
 */
public EmbeddedChannel(Channel parent, ChannelId channelId, boolean register, boolean hasDisconnect,
                       final ChannelHandler... handlers) {
    super(parent, channelId);
    // NOTE(review): metadata is assigned before the config is created; DefaultChannelConfig presumably
    // queries metadata() while it is constructed, so keep this order - confirm before reordering.
    metadata = metadata(hasDisconnect);
    config = new DefaultChannelConfig(this);
    setup(register, handlers);
}
187
188
/**
 * Create a new instance with the given id, caller-supplied config and handlers. The channel has no
 * parent and is registered during construction.
 *
 * @param channelId the {@link ChannelId} that will be used to identify this channel
 * @param hasDisconnect {@code false} if a disconnect on this channel should also close it
 *                      (see {@code doDisconnect()}), {@code true} otherwise
 * @param config the {@link ChannelConfig} returned by {@link #config()}; must not be {@code null}
 * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline}
 */
public EmbeddedChannel(ChannelId channelId, boolean hasDisconnect, final ChannelConfig config,
                       final ChannelHandler... handlers) {
    super(null, channelId);
    metadata = metadata(hasDisconnect);
    // Rejects a null config up front (checkNotNull is given the parameter name for its message).
    this.config = ObjectUtil.checkNotNull(config, "config");
    setup(true, handlers);
}
205
206 private static ChannelMetadata metadata(boolean hasDisconnect) {
207 return hasDisconnect ? METADATA_DISCONNECT : METADATA_NO_DISCONNECT;
208 }
209
210 private void setup(boolean register, final ChannelHandler... handlers) {
211 ObjectUtil.checkNotNull(handlers, "handlers");
212 ChannelPipeline p = pipeline();
213 p.addLast(new ChannelInitializer<Channel>() {
214 @Override
215 protected void initChannel(Channel ch) throws Exception {
216 ChannelPipeline pipeline = ch.pipeline();
217 for (ChannelHandler h: handlers) {
218 if (h == null) {
219 break;
220 }
221 pipeline.addLast(h);
222 }
223 }
224 });
225 if (register) {
226 ChannelFuture future = loop.register(this);
227 assert future.isDone();
228 }
229 }
230
231
234 public void register() throws Exception {
235 ChannelFuture future = loop.register(this);
236 assert future.isDone();
237 Throwable cause = future.cause();
238 if (cause != null) {
239 PlatformDependent.throwException(cause);
240 }
241 }
242
/**
 * Creates the pipeline for this channel: an {@link EmbeddedChannelPipeline}, which routes unhandled
 * inbound messages and exceptions back into this channel.
 */
@Override
protected final DefaultChannelPipeline newChannelPipeline() {
    return new EmbeddedChannelPipeline(this);
}
247
/**
 * Returns the {@link ChannelMetadata} chosen at construction time from the {@code hasDisconnect} flag.
 */
@Override
public ChannelMetadata metadata() {
    return metadata;
}
252
/**
 * Returns the {@link ChannelConfig} set at construction time (either a {@link DefaultChannelConfig}
 * or the instance supplied by the caller).
 */
@Override
public ChannelConfig config() {
    return config;
}
257
/**
 * Returns {@code true} unless the channel has been closed. A channel that was never registered
 * (state still unset) also counts as open.
 */
@Override
public boolean isOpen() {
    return state != State.CLOSED;
}
262
/**
 * Returns {@code true} only after registration ({@code doRegister()}) and before close.
 */
@Override
public boolean isActive() {
    return state == State.ACTIVE;
}
267
268
271 public Queue<Object> inboundMessages() {
272 if (inboundMessages == null) {
273 inboundMessages = new ArrayDeque<Object>();
274 }
275 return inboundMessages;
276 }
277
278
/**
 * Returns the same queue as {@link #inboundMessages()}.
 *
 * @deprecated use {@link #inboundMessages()}
 */
@Deprecated
public Queue<Object> lastInboundBuffer() {
    return inboundMessages();
}
285
286
289 public Queue<Object> outboundMessages() {
290 if (outboundMessages == null) {
291 outboundMessages = new ArrayDeque<Object>();
292 }
293 return outboundMessages;
294 }
295
296
/**
 * Returns the same queue as {@link #outboundMessages()}.
 *
 * @deprecated use {@link #outboundMessages()}
 */
@Deprecated
public Queue<Object> lastOutboundBuffer() {
    return outboundMessages();
}
303
304
307 @SuppressWarnings("unchecked")
308 public <T> T readInbound() {
309 T message = (T) poll(inboundMessages);
310 if (message != null) {
311 ReferenceCountUtil.touch(message, "Caller of readInbound() will handle the message from this point");
312 }
313 return message;
314 }
315
316
319 @SuppressWarnings("unchecked")
320 public <T> T readOutbound() {
321 T message = (T) poll(outboundMessages);
322 if (message != null) {
323 ReferenceCountUtil.touch(message, "Caller of readOutbound() will handle the message from this point.");
324 }
325 return message;
326 }
327
328
335 public boolean writeInbound(Object... msgs) {
336 ensureOpen();
337 if (msgs.length == 0) {
338 return isNotEmpty(inboundMessages);
339 }
340
341 ChannelPipeline p = pipeline();
342 for (Object m: msgs) {
343 p.fireChannelRead(m);
344 }
345
346 flushInbound(false, voidPromise());
347 return isNotEmpty(inboundMessages);
348 }
349
350
/**
 * Writes one message to the inbound of this {@link Channel} without flushing, using a new promise.
 *
 * @param msg the message to fire through the pipeline
 * @return the future completed by {@link #writeOneInbound(Object, ChannelPromise)}
 * @see #writeOneOutbound(Object)
 */
public ChannelFuture writeOneInbound(Object msg) {
    return writeOneInbound(msg, newPromise());
}
359
360
366 public ChannelFuture writeOneInbound(Object msg, ChannelPromise promise) {
367 if (checkOpen(true)) {
368 pipeline().fireChannelRead(msg);
369 }
370 return checkException(promise);
371 }
372
373
/**
 * Flushes the inbound of this {@link Channel}: fires read-complete and runs pending tasks. A closed
 * channel or a recorded exception is rethrown, since the void promise is used.
 *
 * @return this channel
 * @see #flushOutbound()
 */
public EmbeddedChannel flushInbound() {
    flushInbound(true, voidPromise());
    return this;
}
382
383 private ChannelFuture flushInbound(boolean recordException, ChannelPromise promise) {
384 if (checkOpen(recordException)) {
385 pipeline().fireChannelReadComplete();
386 runPendingTasks();
387 }
388
389 return checkException(promise);
390 }
391
392
398 public boolean writeOutbound(Object... msgs) {
399 ensureOpen();
400 if (msgs.length == 0) {
401 return isNotEmpty(outboundMessages);
402 }
403
404 RecyclableArrayList futures = RecyclableArrayList.newInstance(msgs.length);
405 try {
406 for (Object m: msgs) {
407 if (m == null) {
408 break;
409 }
410 futures.add(write(m));
411 }
412
413 flushOutbound0();
414
415 int size = futures.size();
416 for (int i = 0; i < size; i++) {
417 ChannelFuture future = (ChannelFuture) futures.get(i);
418 if (future.isDone()) {
419 recordException(future);
420 } else {
421
422 future.addListener(recordExceptionListener);
423 }
424 }
425
426 checkException();
427 return isNotEmpty(outboundMessages);
428 } finally {
429 futures.recycle();
430 }
431 }
432
433
/**
 * Writes one message to the outbound of this {@link Channel} without flushing, using a new promise.
 *
 * @param msg the message to write
 * @return the future returned by {@link #writeOneOutbound(Object, ChannelPromise)}
 * @see #writeOneInbound(Object)
 */
public ChannelFuture writeOneOutbound(Object msg) {
    return writeOneOutbound(msg, newPromise());
}
442
443
449 public ChannelFuture writeOneOutbound(Object msg, ChannelPromise promise) {
450 if (checkOpen(true)) {
451 return write(msg, promise);
452 }
453 return checkException(promise);
454 }
455
456
/**
 * Flushes the outbound of this {@link Channel}. A closed channel is recorded as a
 * {@link ClosedChannelException}; any recorded exception is then handed to
 * {@code checkException(voidPromise())}.
 *
 * @return this channel
 * @see #flushInbound()
 */
public EmbeddedChannel flushOutbound() {
    if (checkOpen(true)) {
        flushOutbound0();
    }
    checkException(voidPromise());
    return this;
}
468
/**
 * Runs pending event-loop tasks and then flushes the channel.
 */
private void flushOutbound0() {
    // Pending tasks run first. NOTE(review): presumably because an outbound handler may have deferred
    // a write via the event loop, which must execute before the flush - confirm.
    runPendingTasks();

    flush();
}
476
477
/**
 * Mark this {@link Channel} as finished by closing it. Queued messages are left in place.
 *
 * @return {@code true} if either the inbound or the outbound queue still holds something to read
 */
public boolean finish() {
    return finish(false);
}
485
486
/**
 * Mark this {@link Channel} as finished by closing it, then release every message still queued
 * in the inbound and outbound queues.
 *
 * @return {@code true} if either queue held something to read before the release
 */
public boolean finishAndReleaseAll() {
    return finish(true);
}
495
496
502 private boolean finish(boolean releaseAll) {
503 close();
504 try {
505 checkException();
506 return isNotEmpty(inboundMessages) || isNotEmpty(outboundMessages);
507 } finally {
508 if (releaseAll) {
509 releaseAll(inboundMessages);
510 releaseAll(outboundMessages);
511 }
512 }
513 }
514
515
/**
 * Release all buffered inbound messages.
 *
 * @return {@code true} if the inbound queue held any message, {@code false} otherwise
 */
public boolean releaseInbound() {
    return releaseAll(inboundMessages);
}
522
523
/**
 * Release all buffered outbound messages.
 *
 * @return {@code true} if the outbound queue held any message, {@code false} otherwise
 */
public boolean releaseOutbound() {
    return releaseAll(outboundMessages);
}
530
531 private static boolean releaseAll(Queue<Object> queue) {
532 if (isNotEmpty(queue)) {
533 for (;;) {
534 Object msg = queue.poll();
535 if (msg == null) {
536 break;
537 }
538 ReferenceCountUtil.release(msg);
539 }
540 return true;
541 }
542 return false;
543 }
544
545 private void finishPendingTasks(boolean cancel) {
546 runPendingTasks();
547 if (cancel) {
548
549 loop.cancelScheduledTasks();
550 }
551 }
552
/**
 * Closes the channel using a new promise; see {@link #close(ChannelPromise)}.
 */
@Override
public final ChannelFuture close() {
    return close(newPromise());
}
557
/**
 * Disconnects the channel using a new promise; see {@link #disconnect(ChannelPromise)}.
 */
@Override
public final ChannelFuture disconnect() {
    return disconnect(newPromise());
}
562
/**
 * Closes the channel, running pending tasks both before and after the close and cancelling any
 * scheduled tasks that are left.
 *
 * @param promise the promise passed to the superclass close
 * @return the future returned by the superclass close
 */
@Override
public final ChannelFuture close(ChannelPromise promise) {
    // Run whatever is already queued on the loop before the close itself,
    // so those tasks still execute against a channel that has not been closed yet.
    runPendingTasks();
    ChannelFuture future = super.close(promise);

    // Then run the tasks produced by the close and cancel the scheduled tasks that remain.
    finishPendingTasks(true);
    return future;
}
574
575 @Override
576 public final ChannelFuture disconnect(ChannelPromise promise) {
577 ChannelFuture future = super.disconnect(promise);
578 finishPendingTasks(!metadata.hasDisconnect());
579 return future;
580 }
581
582 private static boolean isNotEmpty(Queue<Object> queue) {
583 return queue != null && !queue.isEmpty();
584 }
585
586 private static Object poll(Queue<Object> queue) {
587 return queue != null ? queue.poll() : null;
588 }
589
590
/**
 * Run all tasks, and then the scheduled tasks, pending in the {@link EventLoop} of this
 * {@link Channel}. An exception thrown by either step is recorded rather than propagated.
 */
public void runPendingTasks() {
    try {
        loop.runTasks();
    } catch (Exception e) {
        recordException(e);
    }

    // Kept in a separate try block so scheduled tasks still run when runTasks() has thrown.
    try {
        loop.runScheduledTasks();
    } catch (Exception e) {
        recordException(e);
    }
}
607
608
/**
 * Run the scheduled tasks pending in the {@link EventLoop} of this {@link Channel}.
 *
 * @return the value of {@code loop.runScheduledTasks()}; if that call throws, the exception is
 *         recorded and {@code loop.nextScheduledTask()} is returned instead. Presumably this is the
 *         delay until the next scheduled task - TODO confirm unit and sentinel against
 *         {@code EmbeddedEventLoop}.
 */
public long runScheduledPendingTasks() {
    try {
        return loop.runScheduledTasks();
    } catch (Exception e) {
        recordException(e);
        return loop.nextScheduledTask();
    }
}
621
622 private void recordException(ChannelFuture future) {
623 if (!future.isSuccess()) {
624 recordException(future.cause());
625 }
626 }
627
628 private void recordException(Throwable cause) {
629 if (lastException == null) {
630 lastException = cause;
631 } else {
632 logger.warn(
633 "More than one exception was raised. " +
634 "Will report only the first one and log others.", cause);
635 }
636 }
637
638
641 private ChannelFuture checkException(ChannelPromise promise) {
642 Throwable t = lastException;
643 if (t != null) {
644 lastException = null;
645
646 if (promise.isVoid()) {
647 PlatformDependent.throwException(t);
648 }
649
650 return promise.setFailure(t);
651 }
652
653 return promise.setSuccess();
654 }
655
656
/**
 * Check whether an {@link Exception} was recorded and hand it to
 * {@code checkException(ChannelPromise)} with the void promise, which clears it.
 */
public void checkException() {
    checkException(voidPromise());
}
662
663
667 private boolean checkOpen(boolean recordException) {
668 if (!isOpen()) {
669 if (recordException) {
670 recordException(new ClosedChannelException());
671 }
672 return false;
673 }
674
675 return true;
676 }
677
678
/**
 * Ensure the {@link Channel} is open. If it is closed, a {@link ClosedChannelException} is recorded
 * and {@link #checkException()} is called to report it.
 */
protected final void ensureOpen() {
    if (!checkOpen(true)) {
        checkException();
    }
}
686
/**
 * Accepts only an {@link EmbeddedEventLoop}.
 */
@Override
protected boolean isCompatible(EventLoop loop) {
    return loop instanceof EmbeddedEventLoop;
}
691
692 @Override
693 protected SocketAddress localAddress0() {
694 return isActive()? LOCAL_ADDRESS : null;
695 }
696
697 @Override
698 protected SocketAddress remoteAddress0() {
699 return isActive()? REMOTE_ADDRESS : null;
700 }
701
/**
 * Marks the channel as active; this is the only place the state becomes {@code ACTIVE}.
 */
@Override
protected void doRegister() throws Exception {
    state = State.ACTIVE;
}
706
/**
 * Intentionally does nothing with the given address.
 */
@Override
protected void doBind(SocketAddress localAddress) throws Exception {
    // NOOP
}
711
712 @Override
713 protected void doDisconnect() throws Exception {
714 if (!metadata.hasDisconnect()) {
715 doClose();
716 }
717 }
718
/**
 * Marks the channel as closed; {@link #isOpen()} and {@link #isActive()} both return {@code false}
 * afterwards.
 */
@Override
protected void doClose() throws Exception {
    state = State.CLOSED;
}
723
/**
 * Intentionally does nothing; inbound data is supplied by the caller through the
 * {@code writeInbound} / {@code writeOneInbound} methods.
 */
@Override
protected void doBeginRead() throws Exception {
    // NOOP
}
728
/**
 * Creates the {@link EmbeddedUnsafe} backing this channel.
 */
@Override
protected AbstractUnsafe newUnsafe() {
    return new EmbeddedUnsafe();
}
733
734 @Override
735 public Unsafe unsafe() {
736 return ((EmbeddedUnsafe) super.unsafe()).wrapped;
737 }
738
739 @Override
740 protected void doWrite(ChannelOutboundBuffer in) throws Exception {
741 for (;;) {
742 Object msg = in.current();
743 if (msg == null) {
744 break;
745 }
746
747 ReferenceCountUtil.retain(msg);
748 handleOutboundMessage(msg);
749 in.remove();
750 }
751 }
752
753
/**
 * Called for each outbound message drained by {@code doWrite}. This implementation appends the
 * message to the queue returned by {@link #outboundMessages()}; subclasses may override it.
 *
 * @param msg the outbound message
 */
protected void handleOutboundMessage(Object msg) {
    outboundMessages().add(msg);
}
761
762
/**
 * Called for each inbound message that no handler in the pipeline consumed. This implementation
 * appends the message to the queue returned by {@link #inboundMessages()}; subclasses may override it.
 *
 * @param msg the inbound message
 */
protected void handleInboundMessage(Object msg) {
    inboundMessages().add(msg);
}
768
/**
 * The {@code Unsafe} implementation of this channel. {@link #unsafe()} exposes only the
 * {@code wrapped} delegate declared below, never this object directly.
 */
private final class EmbeddedUnsafe extends AbstractUnsafe {

    // Delegates every call to the enclosing EmbeddedUnsafe. Each operation other than the plain
    // accessors (recvBufAllocHandle, localAddress, remoteAddress, voidPromise, outboundBuffer)
    // is followed by runPendingTasks(), so tasks queued by that operation execute right away.
    final Unsafe wrapped = new Unsafe() {
        @Override
        public RecvByteBufAllocator.Handle recvBufAllocHandle() {
            return EmbeddedUnsafe.this.recvBufAllocHandle();
        }

        @Override
        public SocketAddress localAddress() {
            return EmbeddedUnsafe.this.localAddress();
        }

        @Override
        public SocketAddress remoteAddress() {
            return EmbeddedUnsafe.this.remoteAddress();
        }

        @Override
        public void register(EventLoop eventLoop, ChannelPromise promise) {
            EmbeddedUnsafe.this.register(eventLoop, promise);
            runPendingTasks();
        }

        @Override
        public void bind(SocketAddress localAddress, ChannelPromise promise) {
            EmbeddedUnsafe.this.bind(localAddress, promise);
            runPendingTasks();
        }

        @Override
        public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
            EmbeddedUnsafe.this.connect(remoteAddress, localAddress, promise);
            runPendingTasks();
        }

        @Override
        public void disconnect(ChannelPromise promise) {
            EmbeddedUnsafe.this.disconnect(promise);
            runPendingTasks();
        }

        @Override
        public void close(ChannelPromise promise) {
            EmbeddedUnsafe.this.close(promise);
            runPendingTasks();
        }

        @Override
        public void closeForcibly() {
            EmbeddedUnsafe.this.closeForcibly();
            runPendingTasks();
        }

        @Override
        public void deregister(ChannelPromise promise) {
            EmbeddedUnsafe.this.deregister(promise);
            runPendingTasks();
        }

        @Override
        public void beginRead() {
            EmbeddedUnsafe.this.beginRead();
            runPendingTasks();
        }

        @Override
        public void write(Object msg, ChannelPromise promise) {
            EmbeddedUnsafe.this.write(msg, promise);
            runPendingTasks();
        }

        @Override
        public void flush() {
            EmbeddedUnsafe.this.flush();
            runPendingTasks();
        }

        @Override
        public ChannelPromise voidPromise() {
            return EmbeddedUnsafe.this.voidPromise();
        }

        @Override
        public ChannelOutboundBuffer outboundBuffer() {
            return EmbeddedUnsafe.this.outboundBuffer();
        }
    };

    /**
     * Completes the promise successfully without using either address.
     */
    @Override
    public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
        safeSetSuccess(promise);
    }
}
865
/**
 * Pipeline that hands whatever reaches its end unhandled back to the enclosing channel.
 */
private final class EmbeddedChannelPipeline extends DefaultChannelPipeline {
    EmbeddedChannelPipeline(EmbeddedChannel channel) {
        super(channel);
    }

    /**
     * Records the exception on the enclosing channel so a later {@code checkException(...)} reports it.
     */
    @Override
    protected void onUnhandledInboundException(Throwable cause) {
        recordException(cause);
    }

    /**
     * Passes the message to {@code handleInboundMessage}, which by default queues it for
     * {@code readInbound()}.
     */
    @Override
    protected void onUnhandledInboundMessage(ChannelHandlerContext ctx, Object msg) {
        handleInboundMessage(msg);
    }
}
881 }
882