1 /*
2  * JBoss, Home of Professional Open Source.
3  * Copyright 2014 Red Hat, Inc., and individual contributors
4  * as indicated by the @author tags.
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  *  Unless required by applicable law or agreed to in writing, software
13  *  distributed under the License is distributed on an "AS IS" BASIS,
14  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  *  See the License for the specific language governing permissions and
16  *  limitations under the License.
17  */

18
19 package io.undertow.servlet.spec;
20
21 import static org.xnio.Bits.allAreClear;
22 import static org.xnio.Bits.anyAreClear;
23 import static org.xnio.Bits.anyAreSet;
24
25 import java.io.IOException;
26 import java.nio.ByteBuffer;
27 import java.nio.channels.FileChannel;
28 import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
29 import javax.servlet.DispatcherType;
30 import javax.servlet.ServletOutputStream;
31 import javax.servlet.ServletRequest;
32 import javax.servlet.WriteListener;
33
34 import io.undertow.UndertowLogger;
35 import org.xnio.Buffers;
36 import org.xnio.ChannelListener;
37 import org.xnio.IoUtils;
38 import org.xnio.channels.Channels;
39 import org.xnio.channels.StreamSinkChannel;
40 import io.undertow.connector.ByteBufferPool;
41 import io.undertow.connector.PooledByteBuffer;
42 import io.undertow.io.BufferWritableOutputStream;
43 import io.undertow.server.protocol.http.HttpAttachments;
44 import io.undertow.servlet.UndertowServletMessages;
45 import io.undertow.servlet.handlers.ServletRequestContext;
46 import io.undertow.util.Headers;
47
48 /**
49  * This stream essentially has two modes. When it is being used in standard blocking mode then
50  * it will buffer in the pooled buffer. If the stream is closed before the buffer is full it will
51  * set a content-length header if one has not been explicitly set.
52  * <p>
53  * If a content-length header was present when the stream was created then it will automatically
54  * close and flush itself once the appropriate amount of data has been written.
55  * <p>
 * Once the listener has been set it goes into async mode, and writes become non-blocking. Most methods
 * have two different code paths, depending on whether or not the listener has been set.
58  * <p>
59  * Once the write listener has been set operations must only be invoked on this stream from the write
60  * listener callback. Attempting to invoke from a different thread will result in an IllegalStateException.
61  * <p>
62  * Async listener tasks are queued in the {@link AsyncContextImpl}. At most one listener can be active at
63  * one time, which simplifies the thread safety requirements.
64  *
65  * @author Stuart Douglas
66  */

67 public class ServletOutputStreamImpl extends ServletOutputStream implements BufferWritableOutputStream {
68
69     private final ServletRequestContext servletRequestContext;
70     private PooledByteBuffer pooledBuffer;
71     private ByteBuffer buffer;
72     private Integer bufferSize;
73     private StreamSinkChannel channel;
74     private long written;
75     private volatile int state;
76     private volatile boolean asyncIoStarted;
77     private AsyncContextImpl asyncContext;
78
79     private WriteListener listener;
80     private WriteChannelListener internalListener;
81
82
83     /**
84      * buffers that are queued up to be written via async writes. This will include
85      * {@link #buffer} as the first element, and maybe a user supplied buffer that
86      * did not fit
87      */

88     private ByteBuffer[] buffersToWrite;
89
90     private FileChannel pendingFile;
91
92     private static final int FLAG_CLOSED = 1;
93     private static final int FLAG_WRITE_STARTED = 1 << 1;
94     private static final int FLAG_READY = 1 << 2;
95     private static final int FLAG_DELEGATE_SHUTDOWN = 1 << 3;
96     private static final int FLAG_IN_CALLBACK = 1 << 4;
97
98     //TODO: should this be configurable?
99     private static final int MAX_BUFFERS_TO_ALLOCATE = 6;
100
101     private static final AtomicIntegerFieldUpdater<ServletOutputStreamImpl> stateUpdater = AtomicIntegerFieldUpdater.newUpdater(ServletOutputStreamImpl.class"state");
102
103
/**
 * Creates a stream for the response of the given request.
 * <p>
 * The staging buffer is taken from the connection's byte buffer pool the first time it is needed.
 * No write timeout is configured.
 *
 * @param servletRequestContext context of the request whose response this stream writes
 */
public ServletOutputStreamImpl(final ServletRequestContext servletRequestContext) {
    this.servletRequestContext = servletRequestContext;
}

/**
 * Creates a stream for the response of the given request that stages output in a buffer of
 * {@code bufferSize} bytes.
 * <p>
 * That buffer is allocated directly (not from the pool) the first time it is needed.
 * No write timeout is configured.
 *
 * @param servletRequestContext context of the request whose response this stream writes
 * @param bufferSize            size in bytes of the staging buffer
 */
public ServletOutputStreamImpl(final ServletRequestContext servletRequestContext, int bufferSize) {
    this(servletRequestContext);
    this.bufferSize = bufferSize;
}
118
119     /**
120      * {@inheritDoc}
121      */

122     public void write(final int b) throws IOException {
123         write(new byte[]{(byte) b}, 0, 1);
124     }
125
126     /**
127      * {@inheritDoc}
128      */

129     public void write(final byte[] b) throws IOException {
130         write(b, 0, b.length);
131     }
132
133     /**
134      * {@inheritDoc}
135      */

136     public void write(final byte[] b, final int off, final int len) throws IOException {
137         if (anyAreSet(state, FLAG_CLOSED) || servletRequestContext.getOriginalResponse().isTreatAsCommitted()) {
138             throw UndertowServletMessages.MESSAGES.streamIsClosed();
139         }
140         if (len < 1) {
141             return;
142         }
143
144         if (listener == null) {
145             ByteBuffer buffer = buffer();
146             if (buffer.remaining() < len) {
147                 writeTooLargeForBuffer(b, off, len, buffer);
148             } else {
149                 buffer.put(b, off, len);
150                 if (buffer.remaining() == 0) {
151                     writeBufferBlocking(false);
152                 }
153             }
154             updateWritten(len);
155         } else {
156             writeAsync(b, off, len);
157         }
158     }
159
160     private void writeTooLargeForBuffer(byte[] b, int off, int len, ByteBuffer buffer) throws IOException {
161         //so what we have will not fit.
162         //We allocate multiple buffers up to MAX_BUFFERS_TO_ALLOCATE
163         //and put it in them
164         //if it still dopes not fit we loop, re-using these buffers
165
166         StreamSinkChannel channel = this.channel;
167         if (channel == null) {
168             this.channel = channel = servletRequestContext.getExchange().getResponseChannel();
169         }
170         final ByteBufferPool bufferPool = servletRequestContext.getExchange().getConnection().getByteBufferPool();
171         ByteBuffer[] buffers = new ByteBuffer[MAX_BUFFERS_TO_ALLOCATE + 1];
172         PooledByteBuffer[] pooledBuffers = new PooledByteBuffer[MAX_BUFFERS_TO_ALLOCATE];
173         try {
174             buffers[0] = buffer;
175             int bytesWritten = 0;
176             int rem = buffer.remaining();
177             buffer.put(b, bytesWritten + off, rem);
178             buffer.flip();
179             bytesWritten += rem;
180             int bufferCount = 1;
181             for (int i = 0; i < MAX_BUFFERS_TO_ALLOCATE; ++i) {
182                 PooledByteBuffer pooled = bufferPool.allocate();
183                 pooledBuffers[bufferCount - 1] = pooled;
184                 buffers[bufferCount++] = pooled.getBuffer();
185                 ByteBuffer cb = pooled.getBuffer();
186                 int toWrite = len - bytesWritten;
187                 if (toWrite > cb.remaining()) {
188                     rem = cb.remaining();
189                     cb.put(b, bytesWritten + off, rem);
190                     cb.flip();
191                     bytesWritten += rem;
192                 } else {
193                     cb.put(b, bytesWritten + off, toWrite);
194                     bytesWritten = len;
195                     cb.flip();
196                     break;
197                 }
198             }
199             Channels.writeBlocking(channel, buffers, 0, bufferCount);
200             while (bytesWritten < len) {
201                 //ok, it did not fit, loop and loop and loop until it is done
202                 bufferCount = 0;
203                 for (int i = 0; i < MAX_BUFFERS_TO_ALLOCATE + 1; ++i) {
204                     ByteBuffer cb = buffers[i];
205                     cb.clear();
206                     bufferCount++;
207                     int toWrite = len - bytesWritten;
208                     if (toWrite > cb.remaining()) {
209                         rem = cb.remaining();
210                         cb.put(b, bytesWritten + off, rem);
211                         cb.flip();
212                         bytesWritten += rem;
213                     } else {
214                         cb.put(b, bytesWritten + off, toWrite);
215                         bytesWritten = len;
216                         cb.flip();
217                         break;
218                     }
219                 }
220                 Channels.writeBlocking(channel, buffers, 0, bufferCount);
221             }
222             buffer.clear();
223         } finally {
224             for (int i = 0; i < pooledBuffers.length; ++i) {
225                 PooledByteBuffer p = pooledBuffers[i];
226                 if (p == null) {
227                     break;
228                 }
229                 p.close();
230             }
231         }
232     }
233
234     private void writeAsync(byte[] b, int off, int len) throws IOException {
235         if (anyAreClear(state, FLAG_READY)) {
236             throw UndertowServletMessages.MESSAGES.streamNotReady();
237         }
238         //even though we are in async mode we are still buffering
239         try {
240             ByteBuffer buffer = buffer();
241             if (buffer.remaining() > len) {
242                 buffer.put(b, off, len);
243             } else {
244                 buffer.flip();
245                 final ByteBuffer userBuffer = ByteBuffer.wrap(b, off, len);
246                 final ByteBuffer[] bufs = new ByteBuffer[]{buffer, userBuffer};
247                 long toWrite = Buffers.remaining(bufs);
248                 long res;
249                 long written = 0;
250                 createChannel();
251                 setFlags(FLAG_WRITE_STARTED);
252                 do {
253                     res = channel.write(bufs);
254                     written += res;
255                     if (res == 0) {
256                         //write it out with a listener
257                         //but we need to copy any extra data
258                         final ByteBuffer copy = ByteBuffer.allocate(userBuffer.remaining());
259                         copy.put(userBuffer);
260                         copy.flip();
261
262                         this.buffersToWrite = new ByteBuffer[]{buffer, copy};
263                         clearFlags(FLAG_READY);
264                         return;
265                     }
266                 } while (written < toWrite);
267                 buffer.clear();
268             }
269         } finally {
270             updateWrittenAsync(len);
271         }
272     }
273
274
275     @Override
276     public void write(ByteBuffer[] buffers) throws IOException {
277         if (anyAreSet(state, FLAG_CLOSED) || servletRequestContext.getOriginalResponse().isTreatAsCommitted()) {
278             throw UndertowServletMessages.MESSAGES.streamIsClosed();
279         }
280         int len = 0;
281         for (ByteBuffer buf : buffers) {
282             len += buf.remaining();
283         }
284         if (len < 1) {
285             return;
286         }
287
288         if (listener == null) {
289             //if we have received the exact amount of content write it out in one go
290             //this is a common case when writing directly from a buffer cache.
291             if (this.written == 0 && len == servletRequestContext.getOriginalResponse().getContentLength()) {
292                 if (channel == null) {
293                     channel = servletRequestContext.getExchange().getResponseChannel();
294                 }
295                 Channels.writeBlocking(channel, buffers, 0, buffers.length);
296                 setFlags(FLAG_WRITE_STARTED);
297             } else {
298                 ByteBuffer buffer = buffer();
299                 if (len < buffer.remaining()) {
300                     Buffers.copy(buffer, buffers, 0, buffers.length);
301                 } else {
302                     if (channel == null) {
303                         channel = servletRequestContext.getExchange().getResponseChannel();
304                     }
305                     if (buffer.position() == 0) {
306                         Channels.writeBlocking(channel, buffers, 0, buffers.length);
307                     } else {
308                         final ByteBuffer[] newBuffers = new ByteBuffer[buffers.length + 1];
309                         buffer.flip();
310                         newBuffers[0] = buffer;
311                         System.arraycopy(buffers, 0, newBuffers, 1, buffers.length);
312                         Channels.writeBlocking(channel, newBuffers, 0, newBuffers.length);
313                         buffer.clear();
314                     }
315                     setFlags(FLAG_WRITE_STARTED);
316                 }
317             }
318
319             updateWritten(len);
320         } else {
321             if (anyAreClear(state, FLAG_READY)) {
322                 throw UndertowServletMessages.MESSAGES.streamNotReady();
323             }
324             //even though we are in async mode we are still buffering
325             try {
326                 ByteBuffer buffer = buffer();
327                 if (buffer.remaining() > len) {
328                     Buffers.copy(buffer, buffers, 0, buffers.length);
329                 } else {
330                     final ByteBuffer[] bufs = new ByteBuffer[buffers.length + 1];
331                     buffer.flip();
332                     bufs[0] = buffer;
333                     System.arraycopy(buffers, 0, bufs, 1, buffers.length);
334                     long toWrite = Buffers.remaining(bufs);
335                     long res;
336                     long written = 0;
337                     createChannel();
338                     setFlags(FLAG_WRITE_STARTED);
339                     do {
340                         res = channel.write(bufs);
341                         written += res;
342                         if (res == 0) {
343                             //write it out with a listener
344                             //but we need to copy any extra data
345                             //TODO: should really allocate from the pool here
346                             final ByteBuffer copy = ByteBuffer.allocate((int) Buffers.remaining(buffers));
347                             Buffers.copy(copy, buffers, 0, buffers.length);
348                             copy.flip();
349                             this.buffersToWrite = new ByteBuffer[]{buffer, copy};
350                             clearFlags(FLAG_READY);
351                             channel.resumeWrites();
352                             return;
353                         }
354                     } while (written < toWrite);
355                     buffer.clear();
356                 }
357             } finally {
358                 updateWrittenAsync(len);
359             }
360         }
361     }
362
363     @Override
364     public void write(ByteBuffer byteBuffer) throws IOException {
365         write(new ByteBuffer[]{byteBuffer});
366     }
367
368     void updateWritten(final long len) throws IOException {
369         this.written += len;
370         long contentLength = servletRequestContext.getOriginalResponse().getContentLength();
371         if (contentLength != -1 && this.written >= contentLength) {
372             close();
373         }
374     }
375
376     void updateWrittenAsync(final long len) throws IOException {
377         this.written += len;
378         long contentLength = servletRequestContext.getOriginalResponse().getContentLength();
379         if (contentLength != -1 && this.written >= contentLength) {
380             setFlags(FLAG_CLOSED);
381             //if buffersToWrite is set we are already flushing
382             //so we don't have to do anything
383             if (buffersToWrite == null && pendingFile == null) {
384                 if (flushBufferAsync(true)) {
385                     channel.shutdownWrites();
386                     setFlags(FLAG_DELEGATE_SHUTDOWN);
387                     channel.flush();
388                     if (pooledBuffer != null) {
389                         pooledBuffer.close();
390                         buffer = null;
391                         pooledBuffer = null;
392                     }
393                 }
394             }
395         }
396     }
397
398     private boolean flushBufferAsync(final boolean writeFinal) throws IOException {
399
400         ByteBuffer[] bufs = buffersToWrite;
401         if (bufs == null) {
402             ByteBuffer buffer = this.buffer;
403             if (buffer == null || buffer.position() == 0) {
404                 return true;
405             }
406             buffer.flip();
407             bufs = new ByteBuffer[]{buffer};
408         }
409         long toWrite = Buffers.remaining(bufs);
410         if (toWrite == 0) {
411             //we clear the buffer, so it can be written to again
412             buffer.clear();
413             return true;
414         }
415         setFlags(FLAG_WRITE_STARTED);
416         createChannel();
417         long res;
418         long written = 0;
419         do {
420             if (writeFinal) {
421                 res = channel.writeFinal(bufs);
422             } else {
423                 res = channel.write(bufs);
424             }
425             written += res;
426             if (res == 0) {
427                 //write it out with a listener
428                 clearFlags(FLAG_READY);
429                 buffersToWrite = bufs;
430                 channel.resumeWrites();
431                 return false;
432             }
433         } while (written < toWrite);
434         buffer.clear();
435         return true;
436     }
437
438
439     /**
440      * Returns the underlying buffer. If this has not been created yet then
441      * it is created.
442      * <p>
443      * Callers that use this method must call {@link #updateWritten(long)} to update the written
444      * amount.
445      * <p>
446      * This allows the buffer to be filled directly, which can be more efficient.
447      * <p>
448      * This method is basically a hack that should only be used by the print writer
449      *
450      * @return The underlying buffer
451      */

452     ByteBuffer underlyingBuffer() {
453         if (anyAreSet(state, FLAG_CLOSED)) {
454             return null;
455         }
456         return buffer();
457     }
458
459     /**
460      * {@inheritDoc}
461      */

462     public void flush() throws IOException {
463         //according to the servlet spec we ignore a flush from within an include
464         if (servletRequestContext.getOriginalRequest().getDispatcherType() == DispatcherType.INCLUDE ||
465                 servletRequestContext.getOriginalResponse().isTreatAsCommitted()) {
466             return;
467         }
468         if (servletRequestContext.getDeployment().getDeploymentInfo().isIgnoreFlush() &&
469                 servletRequestContext.getExchange().isRequestComplete() &&
470                 servletRequestContext.getOriginalResponse().getHeader(Headers.TRANSFER_ENCODING_STRING) == null) {
471             //we mark the stream as flushed, but don't actually flush
472             //because in most cases flush just kills performance
473             //we only do this if the request is fully read, so that http tunneling scenarios still work
474             servletRequestContext.getOriginalResponse().setIgnoredFlushPerformed(true);
475             return;
476         }
477         flushInternal();
478     }
479
480     /**
481      * {@inheritDoc}
482      */

483     public void flushInternal() throws IOException {
484         if (listener == null) {
485             if (anyAreSet(state, FLAG_CLOSED)) {
486                 //just return
487                 return;
488             }
489             if (buffer != null && buffer.position() != 0) {
490                 writeBufferBlocking(false);
491             }
492             if (channel == null) {
493                 channel = servletRequestContext.getExchange().getResponseChannel();
494             }
495             Channels.flushBlocking(channel);
496         } else {
497             if (anyAreClear(state, FLAG_READY)) {
498                 return;
499             }
500             createChannel();
501             if (buffer == null || buffer.position() == 0) {
502                 //nothing to flush, we just flush the underlying stream
503                 //it does not matter if this succeeds or not
504                 channel.flush();
505                 return;
506             }
507             //we have some data in the buffer, we can just write it out
508             //if the write fails we just compact, rather than changing the ready state
509             setFlags(FLAG_WRITE_STARTED);
510             buffer.flip();
511             long res;
512             do {
513                 res = channel.write(buffer);
514             } while (buffer.hasRemaining() && res != 0);
515             if (!buffer.hasRemaining()) {
516                 channel.flush();
517             }
518             buffer.compact();
519         }
520     }
521
522     @Override
523     public void transferFrom(FileChannel source) throws IOException {
524         if (anyAreSet(state, FLAG_CLOSED) || servletRequestContext.getOriginalResponse().isTreatAsCommitted()) {
525             throw UndertowServletMessages.MESSAGES.streamIsClosed();
526         }
527         if (listener == null) {
528             if (buffer != null && buffer.position() != 0) {
529                 writeBufferBlocking(false);
530             }
531             if (channel == null) {
532                 channel = servletRequestContext.getExchange().getResponseChannel();
533             }
534             long position = source.position();
535             long count = source.size() - position;
536             Channels.transferBlocking(channel, source, position, count);
537             updateWritten(count);
538         } else {
539             setFlags(FLAG_WRITE_STARTED);
540             createChannel();
541
542             long pos = 0;
543             try {
544                 long size = source.size();
545                 pos = source.position();
546
547                 while (size - pos > 0) {
548                     long ret = channel.transferFrom(pendingFile, pos, size - pos);
549                     if (ret <= 0) {
550                         clearFlags(FLAG_READY);
551                         pendingFile = source;
552                         source.position(pos);
553                         channel.resumeWrites();
554                         return;
555                     }
556                     pos += ret;
557                 }
558             } finally {
559                 updateWrittenAsync(pos - source.position());
560             }
561         }
562
563     }
564
565
566     private void writeBufferBlocking(final boolean writeFinal) throws IOException {
567         if (channel == null) {
568             channel = servletRequestContext.getExchange().getResponseChannel();
569         }
570         buffer.flip();
571         while (buffer.hasRemaining()) {
572             if (writeFinal) {
573                 channel.writeFinal(buffer);
574             } else {
575                 channel.write(buffer);
576             }
577             if (buffer.hasRemaining()) {
578                 channel.awaitWritable();
579             }
580         }
581         buffer.clear();
582         setFlags(FLAG_WRITE_STARTED);
583     }
584
585     /**
586      * {@inheritDoc}
587      */

588     public void close() throws IOException {
589         if (servletRequestContext.getOriginalRequest().getDispatcherType() == DispatcherType.INCLUDE ||
590                 servletRequestContext.getOriginalResponse().isTreatAsCommitted()) {
591             return;
592         }
593         if (listener == null) {
594             if (anyAreSet(state, FLAG_CLOSED)) return;
595             setFlags(FLAG_CLOSED);
596             clearFlags(FLAG_READY);
597             if (allAreClear(state, FLAG_WRITE_STARTED) && channel == null && servletRequestContext.getOriginalResponse().getHeader(Headers.CONTENT_LENGTH_STRING) == null) {
598                 if (servletRequestContext.getOriginalResponse().getHeader(Headers.TRANSFER_ENCODING_STRING) == null
599                         && servletRequestContext.getExchange().getAttachment(HttpAttachments.RESPONSE_TRAILER_SUPPLIER) == null
600                         && servletRequestContext.getExchange().getAttachment(HttpAttachments.RESPONSE_TRAILERS) == null) {
601                     if (buffer == null) {
602                         servletRequestContext.getExchange().getResponseHeaders().put(Headers.CONTENT_LENGTH, "0");
603                     } else {
604                         servletRequestContext.getExchange().getResponseHeaders().put(Headers.CONTENT_LENGTH, Integer.toString(buffer.position()));
605                     }
606                 }
607             }
608             try {
609                 if (buffer != null) {
610                     writeBufferBlocking(true);
611                 }
612                 if (channel == null) {
613                     channel = servletRequestContext.getExchange().getResponseChannel();
614                 }
615                 setFlags(FLAG_DELEGATE_SHUTDOWN);
616                 StreamSinkChannel channel = this.channel;
617                 if (channel != null) { //mock requests
618                     channel.shutdownWrites();
619                     Channels.flushBlocking(channel);
620                 }
621             } catch (IOException | RuntimeException | Error e) {
622                 IoUtils.safeClose(this.channel);
623                 throw e;
624             } finally {
625                 if (pooledBuffer != null) {
626                     pooledBuffer.close();
627                     buffer = null;
628                 } else {
629                     buffer = null;
630                 }
631             }
632         } else {
633             closeAsync();
634         }
635     }
636
637     /**
638      * Closes the channel, and flushes any data out using async IO
639      * <p>
640      * This is used in two situations, if an output stream is not closed when a
641      * request is done, and when performing a close on a stream that is in async
642      * mode
643      *
644      * @throws IOException
645      */

646     public void closeAsync() throws IOException {
647         if (anyAreSet(state, FLAG_CLOSED) || servletRequestContext.getOriginalResponse().isTreatAsCommitted()) {
648             return;
649         }
650         if (!servletRequestContext.getExchange().isInIoThread()) {
651             servletRequestContext.getExchange().getIoThread().execute(new Runnable() {
652                 @Override
653                 public void run() {
654                     try {
655                         closeAsync();
656                     } catch (IOException e) {
657                         UndertowLogger.REQUEST_IO_LOGGER.closeAsyncFailed(e);
658                     }
659                 }
660             });
661             return;
662         }
663         try {
664
665             setFlags(FLAG_CLOSED);
666             clearFlags(FLAG_READY);
667             if (allAreClear(state, FLAG_WRITE_STARTED) && channel == null) {
668
669                 if (servletRequestContext.getOriginalResponse().getHeader(Headers.TRANSFER_ENCODING_STRING) == null) {
670                     if (buffer == null) {
671                         servletRequestContext.getOriginalResponse().setHeader(Headers.CONTENT_LENGTH, "0");
672                     } else {
673                         servletRequestContext.getOriginalResponse().setHeader(Headers.CONTENT_LENGTH, Integer.toString(buffer.position()));
674                     }
675                 }
676             }
677             createChannel();
678             if (buffer != null) {
679                 if (!flushBufferAsync(true)) {
680                     return;
681                 }
682
683                 if (pooledBuffer != null) {
684                     pooledBuffer.close();
685                     buffer = null;
686                 } else {
687                     buffer = null;
688                 }
689             }
690             channel.shutdownWrites();
691             setFlags(FLAG_DELEGATE_SHUTDOWN);
692             if (!channel.flush()) {
693                 channel.resumeWrites();
694             }
695         } catch (IOException | RuntimeException | Error e) {
696             if (pooledBuffer != null) {
697                 pooledBuffer.close();
698                 pooledBuffer = null;
699                 buffer = null;
700             }
701             throw e;
702         }
703     }
704
705     private void createChannel() {
706         if (channel == null) {
707             channel = servletRequestContext.getExchange().getResponseChannel();
708             if (internalListener != null) {
709                 channel.getWriteSetter().set(internalListener);
710             }
711         }
712     }
713
714
715     private ByteBuffer buffer() {
716         ByteBuffer buffer = this.buffer;
717         if (buffer != null) {
718             return buffer;
719         }
720         if (bufferSize != null) {
721             this.buffer = ByteBuffer.allocateDirect(bufferSize);
722             return this.buffer;
723         } else {
724             this.pooledBuffer = servletRequestContext.getExchange().getConnection().getByteBufferPool().allocate();
725             this.buffer = pooledBuffer.getBuffer();
726             return this.buffer;
727         }
728     }
729
730     public void resetBuffer() {
731         if (allAreClear(state, FLAG_WRITE_STARTED)) {
732             if (pooledBuffer != null) {
733                 pooledBuffer.close();
734                 pooledBuffer = null;
735             }
736             buffer = null;
737             this.written = 0;
738         } else {
739             throw UndertowServletMessages.MESSAGES.responseAlreadyCommited();
740         }
741     }
742
743     public void setBufferSize(final int size) {
744         if (buffer != null || servletRequestContext.getOriginalResponse().isTreatAsCommitted()) {
745             throw UndertowServletMessages.MESSAGES.contentHasBeenWritten();
746         }
747         this.bufferSize = size;
748     }
749
750     public boolean isClosed() {
751         return anyAreSet(state, FLAG_CLOSED);
752     }
753
754     @Override
755     public boolean isReady() {
756         if (listener == null) {
757             //TODO: is this the correct behaviour?
758             throw UndertowServletMessages.MESSAGES.streamNotInAsyncMode();
759         }
760         if (!asyncIoStarted) {
761             //if we don't add this guard here calls to isReady could start async IO too soon
762             //resulting in a 'resuming + dispatched' message
763             return false;
764         }
765         if (!anyAreSet(state, FLAG_READY)) {
766             if (channel != null) {
767                 channel.resumeWrites();
768             }
769             return false;
770         }
771         return true;
772     }
773
774     @Override
775     public void setWriteListener(final WriteListener writeListener) {
776         if (writeListener == null) {
777             throw UndertowServletMessages.MESSAGES.listenerCannotBeNull();
778         }
779         if (listener != null) {
780             throw UndertowServletMessages.MESSAGES.listenerAlreadySet();
781         }
782         final ServletRequest servletRequest = servletRequestContext.getOriginalRequest();
783         if (!servletRequest.isAsyncStarted()) {
784             throw UndertowServletMessages.MESSAGES.asyncNotStarted();
785         }
786         asyncContext = (AsyncContextImpl) servletRequest.getAsyncContext();
787         listener = writeListener;
788         //we register the write listener on the underlying connection
789         //so we don't have to force the creation of the response channel
790         //under normal circumstances this will break write listener delegation
791         this.internalListener = new WriteChannelListener();
792         if (this.channel != null) {
793             this.channel.getWriteSetter().set(internalListener);
794         }
795         //we resume from an async task, after the request has been dispatched
796         asyncContext.addAsyncTask(new Runnable() {
797             @Override
798             public void run() {
799                 asyncIoStarted = true;
800                 if (channel == null) {
801                     servletRequestContext.getExchange().getIoThread().execute(new Runnable() {
802                         @Override
803                         public void run() {
804                             internalListener.handleEvent(null);
805                         }
806                     });
807                 } else {
808                     channel.resumeWrites();
809                 }
810             }
811         });
812     }
813
814     ServletRequestContext getServletRequestContext() {
815         return servletRequestContext;
816     }
817
818
819     private class WriteChannelListener implements ChannelListener<StreamSinkChannel> {
820
821         @Override
822         public void handleEvent(final StreamSinkChannel aChannel) {
823             //flush the channel if it is closed
824             if (anyAreSet(state, FLAG_DELEGATE_SHUTDOWN)) {
825                 try {
826                     //either it will work, and the channel is closed
827                     //or it won't, and we continue with writes resumed
828                     channel.flush();
829                     return;
830                 } catch (Throwable t) {
831                     handleError(t);
832                     return;
833                 }
834             }
835             //if there is data still to write
836             if (buffersToWrite != null) {
837                 long toWrite = Buffers.remaining(buffersToWrite);
838                 long written = 0;
839                 long res;
840                 if (toWrite > 0) { //should always be true, but just to be defensive
841                     do {
842                         try {
843                             res = channel.write(buffersToWrite);
844                             written += res;
845                             if (res == 0) {
846                                 return;
847                             }
848                         } catch (Throwable t) {
849                             handleError(t);
850                             return;
851                         }
852                     } while (written < toWrite);
853                 }
854                 buffersToWrite = null;
855                 buffer.clear();
856             }
857             if (pendingFile != null) {
858                 try {
859                     long size = pendingFile.size();
860                     long pos = pendingFile.position();
861
862                     while (size - pos > 0) {
863                         long ret = channel.transferFrom(pendingFile, pos, size - pos);
864                         if (ret <= 0) {
865                             pendingFile.position(pos);
866                             return;
867                         }
868                         pos += ret;
869                     }
870                     pendingFile = null;
871                 } catch (Throwable t) {
872                     handleError(t);
873                     return;
874                 }
875             }
876             if (anyAreSet(state, FLAG_CLOSED)) {
877                 try {
878
879                     if (pooledBuffer != null) {
880                         pooledBuffer.close();
881                         buffer = null;
882                     } else {
883                         buffer = null;
884                     }
885                     channel.shutdownWrites();
886                     setFlags(FLAG_DELEGATE_SHUTDOWN);
887                     channel.flush();
888                 } catch (Throwable t) {
889                     handleError(t);
890                     return;
891                 }
892             } else {
893
894                 if (asyncContext.isDispatched()) {
895                     //this is no longer an async request
896                     //we just return for now
897                     //TODO: what do we do here? Revert back to blocking mode?
898                     channel.suspendWrites();
899                     return;
900                 }
901
902                 setFlags(FLAG_READY);
903                 try {
904                     setFlags(FLAG_IN_CALLBACK);
905
906                     //if the stream is still ready then we do not resume writes
907                     //this is per spec, we only call the listener once for each time
908                     //isReady returns true
909                     if (channel != null) {
910                         channel.suspendWrites();
911                     }
912                     servletRequestContext.getCurrentServletContext().invokeOnWritePossible(servletRequestContext.getExchange(), listener);
913                 } catch (Throwable e) {
914                     IoUtils.safeClose(channel);
915                 } finally {
916                     clearFlags(FLAG_IN_CALLBACK);
917                 }
918             }
919
920         }
921
922         private void handleError(final Throwable t) {
923
924             try {
925                 servletRequestContext.getCurrentServletContext().invokeRunnable(servletRequestContext.getExchange(), new Runnable() {
926                     @Override
927                     public void run() {
928                         listener.onError(t);
929                     }
930                 });
931             } finally {
932                 IoUtils.safeClose(channel, servletRequestContext.getExchange().getConnection());
933                 if (pooledBuffer != null) {
934                     pooledBuffer.close();
935                     pooledBuffer = null;
936                     buffer = null;
937                 }
938             }
939         }
940     }
941
942     private void setFlags(int flags) {
943         int old;
944         do {
945             old = state;
946         } while (!stateUpdater.compareAndSet(this, old, old | flags));
947     }
948
949     private void clearFlags(int flags) {
950         int old;
951         do {
952             old = state;
953         } while (!stateUpdater.compareAndSet(this, old, old & ~flags));
954     }
955
956 }
957