1 /*
2  * JBoss, Home of Professional Open Source.
3  * Copyright 2014 Red Hat, Inc., and individual contributors
4  * as indicated by the @author tags.
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  *  Unless required by applicable law or agreed to in writing, software
13  *  distributed under the License is distributed on an "AS IS" BASIS,
14  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  *  See the License for the specific language governing permissions and
16  *  limitations under the License.
17  */

18
19 package io.undertow.servlet.spec;
20
21 import static org.xnio.Bits.allAreClear;
22 import static org.xnio.Bits.anyAreClear;
23 import static org.xnio.Bits.anyAreSet;
24
25 import java.io.IOException;
26 import java.nio.ByteBuffer;
27 import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
28 import javax.servlet.ReadListener;
29 import javax.servlet.ServletInputStream;
30
31 import org.xnio.Buffers;
32 import org.xnio.ChannelListener;
33 import org.xnio.IoUtils;
34 import org.xnio.channels.Channels;
35 import org.xnio.channels.EmptyStreamSourceChannel;
36 import org.xnio.channels.StreamSourceChannel;
37 import io.undertow.connector.ByteBufferPool;
38 import io.undertow.connector.PooledByteBuffer;
39 import io.undertow.servlet.UndertowServletMessages;
40
41 /**
42  * Servlet input stream implementation. This stream is non-buffered, and is used for both
43  * HTTP requests and for upgraded streams.
44  *
45  * @author Stuart Douglas
46  */

47 public class ServletInputStreamImpl extends ServletInputStream {
48
49     private final HttpServletRequestImpl request;
50     private final StreamSourceChannel channel;
51     private final ByteBufferPool bufferPool;
52
53     private volatile ReadListener listener;
54     private volatile ServletInputStreamChannelListener internalListener;
55
56     /**
57      * If this stream is ready for a read
58      */

59     private static final int FLAG_READY = 1;
60     private static final int FLAG_CLOSED = 1 << 1;
61     private static final int FLAG_FINISHED = 1 << 2;
62     private static final int FLAG_ON_DATA_READ_CALLED = 1 << 3;
63     private static final int FLAG_CALL_ON_ALL_DATA_READ = 1 << 4;
64     private static final int FLAG_BEING_INVOKED_IN_IO_THREAD = 1 << 5;
65     private static final int FLAG_IS_READY_CALLED = 1 << 6;
66
67     private volatile int state;
68     private volatile AsyncContextImpl asyncContext;
69     private volatile PooledByteBuffer pooled;
70     private volatile boolean asyncIoStarted;
71
72     private static final AtomicIntegerFieldUpdater<ServletInputStreamImpl> stateUpdater = AtomicIntegerFieldUpdater.newUpdater(ServletInputStreamImpl.class"state");
73
74     public ServletInputStreamImpl(final HttpServletRequestImpl request) {
75         this.request = request;
76         if (request.getExchange().isRequestChannelAvailable()) {
77             this.channel = request.getExchange().getRequestChannel();
78         } else {
79             this.channel = new EmptyStreamSourceChannel(request.getExchange().getIoThread());
80         }
81         this.bufferPool = request.getExchange().getConnection().getByteBufferPool();
82     }
83
84
85     @Override
86     public boolean isFinished() {
87         return anyAreSet(state, FLAG_FINISHED);
88     }
89
90     @Override
91     public boolean isReady() {
92         if (!asyncContext.isInitialRequestDone()) {
93             return false;
94         }
95         boolean finished = anyAreSet(state, FLAG_FINISHED);
96         if(finished) {
97             if (anyAreClear(state, FLAG_ON_DATA_READ_CALLED)) {
98                 if(allAreClear(state, FLAG_BEING_INVOKED_IN_IO_THREAD)) {
99                     setFlags(FLAG_ON_DATA_READ_CALLED);
100                     request.getServletContext().invokeOnAllDataRead(request.getExchange(), listener);
101                 } else {
102                     setFlags(FLAG_CALL_ON_ALL_DATA_READ);
103                 }
104             }
105         }
106         if (!asyncIoStarted) {
107             //make sure we don't call resumeReads unless we have started async IO
108             return false;
109         }
110         boolean ready = anyAreSet(state, FLAG_READY) && !finished;
111         if(!ready && listener != null && !finished) {
112             channel.resumeReads();
113         }
114         if(ready) {
115             setFlags(FLAG_IS_READY_CALLED);
116         }
117         return ready;
118     }
119
120     @Override
121     public void setReadListener(final ReadListener readListener) {
122         if (readListener == null) {
123             throw UndertowServletMessages.MESSAGES.listenerCannotBeNull();
124         }
125         if (listener != null) {
126             throw UndertowServletMessages.MESSAGES.listenerAlreadySet();
127         }
128         if (!request.isAsyncStarted()) {
129             throw UndertowServletMessages.MESSAGES.asyncNotStarted();
130         }
131
132         asyncContext = request.getAsyncContext();
133         listener = readListener;
134         channel.getReadSetter().set(internalListener = new ServletInputStreamChannelListener());
135
136         //we resume from an async task, after the request has been dispatched
137         asyncContext.addAsyncTask(new Runnable() {
138             @Override
139             public void run() {
140                 channel.getIoThread().execute(new Runnable() {
141                     @Override
142                     public void run() {
143                         asyncIoStarted = true;
144                         internalListener.handleEvent(channel);
145                     }
146                 });
147             }
148         });
149     }
150
151     @Override
152     public int read() throws IOException {
153         byte[] b = new byte[1];
154         int read = read(b);
155         if (read == -1) {
156             return -1;
157         }
158         return b[0] & 0xff;
159     }
160
161     @Override
162     public int read(final byte[] b) throws IOException {
163         return read(b, 0, b.length);
164     }
165
166     @Override
167     public int read(final byte[] b, final int off, final int len) throws IOException {
168         if (anyAreSet(state, FLAG_CLOSED)) {
169             throw UndertowServletMessages.MESSAGES.streamIsClosed();
170         }
171         if (listener != null) {
172             if (anyAreClear(state, FLAG_READY | FLAG_IS_READY_CALLED) ) {
173                 throw UndertowServletMessages.MESSAGES.streamNotReady();
174             }
175             clearFlags(FLAG_IS_READY_CALLED);
176         } else {
177             readIntoBuffer();
178         }
179         if (anyAreSet(state, FLAG_FINISHED)) {
180             return -1;
181         }
182         if (len == 0) {
183             return 0;
184         }
185         ByteBuffer buffer = pooled.getBuffer();
186         int copied = Buffers.copy(ByteBuffer.wrap(b, off, len), buffer);
187         if (!buffer.hasRemaining()) {
188             pooled.close();
189             pooled = null;
190             if (listener != null) {
191                 readIntoBufferNonBlocking();
192             }
193         }
194         return copied;
195     }
196
197     private void readIntoBuffer() throws IOException {
198         if (pooled == null && !anyAreSet(state, FLAG_FINISHED)) {
199             pooled = bufferPool.allocate();
200
201             int res = Channels.readBlocking(channel, pooled.getBuffer());
202             pooled.getBuffer().flip();
203             if (res == -1) {
204                 setFlags(FLAG_FINISHED);
205                 pooled.close();
206                 pooled = null;
207             }
208         }
209     }
210
211     private void readIntoBufferNonBlocking() throws IOException {
212         if (pooled == null && !anyAreSet(state, FLAG_FINISHED)) {
213             pooled = bufferPool.allocate();
214             if (listener == null) {
215                 int res = channel.read(pooled.getBuffer());
216                 if (res == 0) {
217                     pooled.close();
218                     pooled = null;
219                     return;
220                 }
221                 pooled.getBuffer().flip();
222                 if (res == -1) {
223                     setFlags(FLAG_FINISHED);
224                     pooled.close();
225                     pooled = null;
226                 }
227             } else {
228                 int res = channel.read(pooled.getBuffer());
229                 pooled.getBuffer().flip();
230                 if (res == -1) {
231                     setFlags(FLAG_FINISHED);
232                     pooled.close();
233                     pooled = null;
234                 } else if (res == 0) {
235                     clearFlags(FLAG_READY);
236                     pooled.close();
237                     pooled = null;
238                 }
239             }
240         }
241     }
242
243     @Override
244     public int available() throws IOException {
245         if (anyAreSet(state, FLAG_CLOSED)) {
246             throw UndertowServletMessages.MESSAGES.streamIsClosed();
247         }
248         readIntoBufferNonBlocking();
249         if (anyAreSet(state, FLAG_FINISHED)) {
250             return 0;
251         }
252         if (pooled == null) {
253             return 0;
254         }
255         return pooled.getBuffer().remaining();
256     }
257
258     @Override
259     public void close() throws IOException {
260         if (anyAreSet(state, FLAG_CLOSED)) {
261             return;
262         }
263         setFlags(FLAG_CLOSED);
264         try {
265             while (allAreClear(state, FLAG_FINISHED)) {
266                 readIntoBuffer();
267                 if (pooled != null) {
268                     pooled.close();
269                     pooled = null;
270                 }
271             }
272         } finally {
273             setFlags(FLAG_FINISHED);
274             if (pooled != null) {
275                 pooled.close();
276                 pooled = null;
277             }
278             channel.shutdownReads();
279         }
280     }
281
282     private class ServletInputStreamChannelListener implements ChannelListener<StreamSourceChannel> {
283         @Override
284         public void handleEvent(final StreamSourceChannel channel) {
285             try {
286                 if (asyncContext.isDispatched()) {
287                     //this is no longer an async request
288                     //we just return
289                     //TODO: what do we do here? Revert back to blocking mode?
290                     channel.suspendReads();
291                     return;
292                 }
293                 if (anyAreSet(state, FLAG_FINISHED)) {
294                     channel.suspendReads();
295                     return;
296                 }
297                 readIntoBufferNonBlocking();
298                 if (pooled != null) {
299                     channel.suspendReads();
300                     setFlags(FLAG_READY);
301                     if (!anyAreSet(state, FLAG_FINISHED)) {
302                         setFlags(FLAG_BEING_INVOKED_IN_IO_THREAD);
303                         try {
304                             request.getServletContext().invokeOnDataAvailable(request.getExchange(), listener);
305                         } finally {
306                             clearFlags(FLAG_BEING_INVOKED_IN_IO_THREAD);
307                         }
308                         if(anyAreSet(state, FLAG_CALL_ON_ALL_DATA_READ) && allAreClear(state, FLAG_ON_DATA_READ_CALLED)) {
309                             setFlags(FLAG_ON_DATA_READ_CALLED);
310                             request.getServletContext().invokeOnAllDataRead(request.getExchange(), listener);
311                         }
312                     }
313                 } else if(anyAreSet(state, FLAG_FINISHED)) {
314                     if (allAreClear(state, FLAG_ON_DATA_READ_CALLED)) {
315                         setFlags(FLAG_ON_DATA_READ_CALLED);
316                         request.getServletContext().invokeOnAllDataRead(request.getExchange(), listener);
317                     }
318                 } else {
319                     channel.resumeReads();
320                 }
321             } catch (final Throwable e) {
322                 try {
323                     request.getServletContext().invokeRunnable(request.getExchange(), new Runnable() {
324                         @Override
325                         public void run() {
326                             listener.onError(e);
327                         }
328                     });
329                 } finally {
330                     if (pooled != null) {
331                         pooled.close();
332                         pooled = null;
333                     }
334                     IoUtils.safeClose(channel);
335                 }
336             }
337         }
338     }
339
340     private void setFlags(int flags) {
341         int old;
342         do {
343             old = state;
344         } while (!stateUpdater.compareAndSet(this, old, old | flags));
345     }
346
347     private void clearFlags(int flags) {
348         int old;
349         do {
350             old = state;
351         } while (!stateUpdater.compareAndSet(this, old, old & ~flags));
352     }
353 }
354