1
18
19 package io.undertow.servlet.spec;
20
21 import static org.xnio.Bits.allAreClear;
22 import static org.xnio.Bits.anyAreClear;
23 import static org.xnio.Bits.anyAreSet;
24
25 import java.io.IOException;
26 import java.nio.ByteBuffer;
27 import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
28 import javax.servlet.ReadListener;
29 import javax.servlet.ServletInputStream;
30
31 import org.xnio.Buffers;
32 import org.xnio.ChannelListener;
33 import org.xnio.IoUtils;
34 import org.xnio.channels.Channels;
35 import org.xnio.channels.EmptyStreamSourceChannel;
36 import org.xnio.channels.StreamSourceChannel;
37 import io.undertow.connector.ByteBufferPool;
38 import io.undertow.connector.PooledByteBuffer;
39 import io.undertow.servlet.UndertowServletMessages;
40
41
47 public class ServletInputStreamImpl extends ServletInputStream {
48
49 private final HttpServletRequestImpl request;
50 private final StreamSourceChannel channel;
51 private final ByteBufferPool bufferPool;
52
53 private volatile ReadListener listener;
54 private volatile ServletInputStreamChannelListener internalListener;
55
56
59 private static final int FLAG_READY = 1;
60 private static final int FLAG_CLOSED = 1 << 1;
61 private static final int FLAG_FINISHED = 1 << 2;
62 private static final int FLAG_ON_DATA_READ_CALLED = 1 << 3;
63 private static final int FLAG_CALL_ON_ALL_DATA_READ = 1 << 4;
64 private static final int FLAG_BEING_INVOKED_IN_IO_THREAD = 1 << 5;
65 private static final int FLAG_IS_READY_CALLED = 1 << 6;
66
67 private volatile int state;
68 private volatile AsyncContextImpl asyncContext;
69 private volatile PooledByteBuffer pooled;
70 private volatile boolean asyncIoStarted;
71
72 private static final AtomicIntegerFieldUpdater<ServletInputStreamImpl> stateUpdater = AtomicIntegerFieldUpdater.newUpdater(ServletInputStreamImpl.class, "state");
73
74 public ServletInputStreamImpl(final HttpServletRequestImpl request) {
75 this.request = request;
76 if (request.getExchange().isRequestChannelAvailable()) {
77 this.channel = request.getExchange().getRequestChannel();
78 } else {
79 this.channel = new EmptyStreamSourceChannel(request.getExchange().getIoThread());
80 }
81 this.bufferPool = request.getExchange().getConnection().getByteBufferPool();
82 }
83
84
85 @Override
86 public boolean isFinished() {
87 return anyAreSet(state, FLAG_FINISHED);
88 }
89
90 @Override
91 public boolean isReady() {
92 if (!asyncContext.isInitialRequestDone()) {
93 return false;
94 }
95 boolean finished = anyAreSet(state, FLAG_FINISHED);
96 if(finished) {
97 if (anyAreClear(state, FLAG_ON_DATA_READ_CALLED)) {
98 if(allAreClear(state, FLAG_BEING_INVOKED_IN_IO_THREAD)) {
99 setFlags(FLAG_ON_DATA_READ_CALLED);
100 request.getServletContext().invokeOnAllDataRead(request.getExchange(), listener);
101 } else {
102 setFlags(FLAG_CALL_ON_ALL_DATA_READ);
103 }
104 }
105 }
106 if (!asyncIoStarted) {
107
108 return false;
109 }
110 boolean ready = anyAreSet(state, FLAG_READY) && !finished;
111 if(!ready && listener != null && !finished) {
112 channel.resumeReads();
113 }
114 if(ready) {
115 setFlags(FLAG_IS_READY_CALLED);
116 }
117 return ready;
118 }
119
120 @Override
121 public void setReadListener(final ReadListener readListener) {
122 if (readListener == null) {
123 throw UndertowServletMessages.MESSAGES.listenerCannotBeNull();
124 }
125 if (listener != null) {
126 throw UndertowServletMessages.MESSAGES.listenerAlreadySet();
127 }
128 if (!request.isAsyncStarted()) {
129 throw UndertowServletMessages.MESSAGES.asyncNotStarted();
130 }
131
132 asyncContext = request.getAsyncContext();
133 listener = readListener;
134 channel.getReadSetter().set(internalListener = new ServletInputStreamChannelListener());
135
136
137 asyncContext.addAsyncTask(new Runnable() {
138 @Override
139 public void run() {
140 channel.getIoThread().execute(new Runnable() {
141 @Override
142 public void run() {
143 asyncIoStarted = true;
144 internalListener.handleEvent(channel);
145 }
146 });
147 }
148 });
149 }
150
151 @Override
152 public int read() throws IOException {
153 byte[] b = new byte[1];
154 int read = read(b);
155 if (read == -1) {
156 return -1;
157 }
158 return b[0] & 0xff;
159 }
160
161 @Override
162 public int read(final byte[] b) throws IOException {
163 return read(b, 0, b.length);
164 }
165
166 @Override
167 public int read(final byte[] b, final int off, final int len) throws IOException {
168 if (anyAreSet(state, FLAG_CLOSED)) {
169 throw UndertowServletMessages.MESSAGES.streamIsClosed();
170 }
171 if (listener != null) {
172 if (anyAreClear(state, FLAG_READY | FLAG_IS_READY_CALLED) ) {
173 throw UndertowServletMessages.MESSAGES.streamNotReady();
174 }
175 clearFlags(FLAG_IS_READY_CALLED);
176 } else {
177 readIntoBuffer();
178 }
179 if (anyAreSet(state, FLAG_FINISHED)) {
180 return -1;
181 }
182 if (len == 0) {
183 return 0;
184 }
185 ByteBuffer buffer = pooled.getBuffer();
186 int copied = Buffers.copy(ByteBuffer.wrap(b, off, len), buffer);
187 if (!buffer.hasRemaining()) {
188 pooled.close();
189 pooled = null;
190 if (listener != null) {
191 readIntoBufferNonBlocking();
192 }
193 }
194 return copied;
195 }
196
197 private void readIntoBuffer() throws IOException {
198 if (pooled == null && !anyAreSet(state, FLAG_FINISHED)) {
199 pooled = bufferPool.allocate();
200
201 int res = Channels.readBlocking(channel, pooled.getBuffer());
202 pooled.getBuffer().flip();
203 if (res == -1) {
204 setFlags(FLAG_FINISHED);
205 pooled.close();
206 pooled = null;
207 }
208 }
209 }
210
211 private void readIntoBufferNonBlocking() throws IOException {
212 if (pooled == null && !anyAreSet(state, FLAG_FINISHED)) {
213 pooled = bufferPool.allocate();
214 if (listener == null) {
215 int res = channel.read(pooled.getBuffer());
216 if (res == 0) {
217 pooled.close();
218 pooled = null;
219 return;
220 }
221 pooled.getBuffer().flip();
222 if (res == -1) {
223 setFlags(FLAG_FINISHED);
224 pooled.close();
225 pooled = null;
226 }
227 } else {
228 int res = channel.read(pooled.getBuffer());
229 pooled.getBuffer().flip();
230 if (res == -1) {
231 setFlags(FLAG_FINISHED);
232 pooled.close();
233 pooled = null;
234 } else if (res == 0) {
235 clearFlags(FLAG_READY);
236 pooled.close();
237 pooled = null;
238 }
239 }
240 }
241 }
242
243 @Override
244 public int available() throws IOException {
245 if (anyAreSet(state, FLAG_CLOSED)) {
246 throw UndertowServletMessages.MESSAGES.streamIsClosed();
247 }
248 readIntoBufferNonBlocking();
249 if (anyAreSet(state, FLAG_FINISHED)) {
250 return 0;
251 }
252 if (pooled == null) {
253 return 0;
254 }
255 return pooled.getBuffer().remaining();
256 }
257
258 @Override
259 public void close() throws IOException {
260 if (anyAreSet(state, FLAG_CLOSED)) {
261 return;
262 }
263 setFlags(FLAG_CLOSED);
264 try {
265 while (allAreClear(state, FLAG_FINISHED)) {
266 readIntoBuffer();
267 if (pooled != null) {
268 pooled.close();
269 pooled = null;
270 }
271 }
272 } finally {
273 setFlags(FLAG_FINISHED);
274 if (pooled != null) {
275 pooled.close();
276 pooled = null;
277 }
278 channel.shutdownReads();
279 }
280 }
281
282 private class ServletInputStreamChannelListener implements ChannelListener<StreamSourceChannel> {
283 @Override
284 public void handleEvent(final StreamSourceChannel channel) {
285 try {
286 if (asyncContext.isDispatched()) {
287
288
289
290 channel.suspendReads();
291 return;
292 }
293 if (anyAreSet(state, FLAG_FINISHED)) {
294 channel.suspendReads();
295 return;
296 }
297 readIntoBufferNonBlocking();
298 if (pooled != null) {
299 channel.suspendReads();
300 setFlags(FLAG_READY);
301 if (!anyAreSet(state, FLAG_FINISHED)) {
302 setFlags(FLAG_BEING_INVOKED_IN_IO_THREAD);
303 try {
304 request.getServletContext().invokeOnDataAvailable(request.getExchange(), listener);
305 } finally {
306 clearFlags(FLAG_BEING_INVOKED_IN_IO_THREAD);
307 }
308 if(anyAreSet(state, FLAG_CALL_ON_ALL_DATA_READ) && allAreClear(state, FLAG_ON_DATA_READ_CALLED)) {
309 setFlags(FLAG_ON_DATA_READ_CALLED);
310 request.getServletContext().invokeOnAllDataRead(request.getExchange(), listener);
311 }
312 }
313 } else if(anyAreSet(state, FLAG_FINISHED)) {
314 if (allAreClear(state, FLAG_ON_DATA_READ_CALLED)) {
315 setFlags(FLAG_ON_DATA_READ_CALLED);
316 request.getServletContext().invokeOnAllDataRead(request.getExchange(), listener);
317 }
318 } else {
319 channel.resumeReads();
320 }
321 } catch (final Throwable e) {
322 try {
323 request.getServletContext().invokeRunnable(request.getExchange(), new Runnable() {
324 @Override
325 public void run() {
326 listener.onError(e);
327 }
328 });
329 } finally {
330 if (pooled != null) {
331 pooled.close();
332 pooled = null;
333 }
334 IoUtils.safeClose(channel);
335 }
336 }
337 }
338 }
339
340 private void setFlags(int flags) {
341 int old;
342 do {
343 old = state;
344 } while (!stateUpdater.compareAndSet(this, old, old | flags));
345 }
346
347 private void clearFlags(int flags) {
348 int old;
349 do {
350 old = state;
351 } while (!stateUpdater.compareAndSet(this, old, old & ~flags));
352 }
353 }
354