1 /*
2  * JBoss, Home of Professional Open Source.
3  * Copyright 2014 Red Hat, Inc., and individual contributors
4  * as indicated by the @author tags.
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  *  Unless required by applicable law or agreed to in writing, software
13  *  distributed under the License is distributed on an "AS IS" BASIS,
14  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  *  See the License for the specific language governing permissions and
16  *  limitations under the License.
17  */

18
19 package io.undertow.conduits;
20
21 import io.undertow.UndertowMessages;
22 import io.undertow.server.Connectors;
23 import io.undertow.server.HttpServerExchange;
24 import org.xnio.IoUtils;
25 import org.xnio.channels.StreamSinkChannel;
26 import org.xnio.conduits.AbstractStreamSourceConduit;
27 import org.xnio.conduits.StreamSourceConduit;
28
29 import java.io.IOException;
30 import java.nio.ByteBuffer;
31 import java.nio.channels.FileChannel;
32 import java.util.concurrent.TimeUnit;
33
34 import static java.lang.Math.min;
35 import static org.xnio.Bits.allAreClear;
36 import static org.xnio.Bits.allAreSet;
37 import static org.xnio.Bits.anyAreClear;
38 import static org.xnio.Bits.anyAreSet;
39 import static org.xnio.Bits.longBitMask;
40
41 /**
42  * A channel which reads data of a fixed length and calls a finish listener.  When the finish listener is called,
43  * it should examine the result of {@link #getRemaining()} to see if more bytes were pending when the channel was
44  * closed.
45  *
46  * @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a>
47  */

48 /*
49  * Implementation notes
50  * --------------------
51  * The {@code exhausted} flag is set once a method returns -1 and signifies that the read listener should no longer be
52  * called.  The {@code finishListener} is called when remaining is reduced to 0 or when the channel is closed explicitly.
53  * If there are 0 remaining bytes but {@code FLAG_FINISHED} has not yet been set, the channel is considered "ready" until
54  * the EOF -1 value is read or the channel is closed.  Since this is a half-duplex channel, shutting down reads is
55  * identical to closing the channel.
56  */

57 public final class FixedLengthStreamSourceConduit extends AbstractStreamSourceConduit<StreamSourceConduit> {
58
59     private final ConduitListener<? super FixedLengthStreamSourceConduit> finishListener;
60
61     @SuppressWarnings("unused")
62     private long state;
63
64     private static final long FLAG_CLOSED = 1L << 63L;
65     private static final long FLAG_FINISHED = 1L << 62L;
66     private static final long FLAG_LENGTH_CHECKED = 1L << 61L;
67     private static final long MASK_COUNT = longBitMask(0, 60);
68
69     private final HttpServerExchange exchange;
70
71     /**
72      * Construct a new instance.  The given listener is called once all the bytes are read from the stream
73      * <b>or</b> the stream is closed.  This listener should cause the remaining data to be drained from the
74      * underlying stream if the underlying stream is to be reused.
75      * <p>
76      * Calling this constructor will replace the read listener of the underlying channel.  The listener should be
77      * restored from the {@code finishListener} object.  The underlying stream should not be closed while this wrapper
78      * stream is active.
79      *
80      * @param next           the stream source channel to read from
81      * @param contentLength  the amount of content to read
82      * @param finishListener the listener to call once the stream is exhausted or closed
83      * @param exchange       The server exchange. This is used to determine the max size
84      */

85     public FixedLengthStreamSourceConduit(final StreamSourceConduit next, final long contentLength, final ConduitListener<? super FixedLengthStreamSourceConduit> finishListener, final HttpServerExchange exchange) {
86         super(next);
87         this.finishListener = finishListener;
88         if (contentLength < 0L) {
89             throw new IllegalArgumentException("Content length must be greater than or equal to zero");
90         } else if (contentLength > MASK_COUNT) {
91             throw new IllegalArgumentException("Content length is too long");
92         }
93         state = contentLength;
94         this.exchange = exchange;
95     }
96
97     /**
98      * Construct a new instance.  The given listener is called once all the bytes are read from the stream
99      * <b>or</b> the stream is closed.  This listener should cause the remaining data to be drained from the
100      * underlying stream if the underlying stream is to be reused.
101      * <p>
102      * Calling this constructor will replace the read listener of the underlying channel.  The listener should be
103      * restored from the {@code finishListener} object.  The underlying stream should not be closed while this wrapper
104      * stream is active.
105      *
106      * @param next           the stream source channel to read from
107      * @param contentLength  the amount of content to read
108      * @param finishListener the listener to call once the stream is exhausted or closed
109      */

110     public FixedLengthStreamSourceConduit(final StreamSourceConduit next, final long contentLength, final ConduitListener<? super FixedLengthStreamSourceConduit> finishListener) {
111         this(next, contentLength, finishListener, null);
112     }
113
114     public long transferTo(final long position, final long count, final FileChannel target) throws IOException {
115         long val = state;
116         checkMaxSize(val);
117         if (anyAreSet(val, FLAG_CLOSED | FLAG_FINISHED) || allAreClear(val, MASK_COUNT)) {
118             if (allAreClear(val, FLAG_FINISHED)) {
119                 invokeFinishListener();
120             }
121             return -1L;
122         }
123         long res = 0L;
124         try {
125             return res = next.transferTo(position, min(count, val & MASK_COUNT), target);
126         } catch (IOException | RuntimeException | Error e) {
127             IoUtils.safeClose(exchange.getConnection());
128             throw e;
129         } finally {
130             exitRead(res);
131         }
132     }
133
134     public long transferTo(final long count, final ByteBuffer throughBuffer, final StreamSinkChannel target) throws IOException {
135         if (count == 0L) {
136             return 0L;
137         }
138         long val = state;
139         checkMaxSize(val);
140         if (anyAreSet(val, FLAG_CLOSED | FLAG_FINISHED) || allAreClear(val, MASK_COUNT)) {
141             if (allAreClear(val, FLAG_FINISHED)) {
142                 invokeFinishListener();
143             }
144             return -1;
145         }
146         long res = 0L;
147         try {
148             return res = next.transferTo(min(count, val & MASK_COUNT), throughBuffer, target);
149         } catch (IOException | RuntimeException | Error e) {
150             IoUtils.safeClose(exchange.getConnection());
151             throw e;
152         } finally {
153             exitRead(res + throughBuffer.remaining());
154         }
155     }
156
157     private void checkMaxSize(long state) throws IOException {
158         if (anyAreClear(state, FLAG_LENGTH_CHECKED)) {
159             HttpServerExchange exchange = this.exchange;
160             if (exchange != null) {
161                 if (exchange.getMaxEntitySize() > 0 && exchange.getMaxEntitySize() < (state & MASK_COUNT)) {
162                     //max entity size is exceeded
163                     //we need to forcibly close the read side
164                     Connectors.terminateRequest(exchange);
165                     exchange.setPersistent(false);
166                     finishListener.handleEvent(this);
167                     this.state |= FLAG_FINISHED | FLAG_CLOSED;
168                     throw UndertowMessages.MESSAGES.requestEntityWasTooLarge(exchange.getMaxEntitySize());
169                 }
170             }
171             this.state |= FLAG_LENGTH_CHECKED;
172         }
173     }
174
175     public long read(final ByteBuffer[] dsts, final int offset, final int length) throws IOException {
176         if (length == 0) {
177             return 0L;
178         } else if (length == 1) {
179             return read(dsts[offset]);
180         }
181         long val = state;
182         checkMaxSize(val);
183         if (allAreSet(val, FLAG_CLOSED) || allAreClear(val, MASK_COUNT)) {
184             if (allAreClear(val, FLAG_FINISHED)) {
185                 invokeFinishListener();
186             }
187             return -1;
188         }
189         long res = 0L;
190         try {
191             if ((val & MASK_COUNT) == 0L) {
192                 return -1L;
193             }
194             int lim;
195             // The total amount of buffer space discovered so far.
196             long t = 0L;
197             for (int i = 0; i < length; i++) {
198                 final ByteBuffer buffer = dsts[i + offset];
199                 // Grow the discovered buffer space by the remaining size of the current buffer.
200                 // We want to capture the limit so we calculate "remaining" ourselves.
201                 t += (lim = buffer.limit()) - buffer.position();
202                 if (t > (val & MASK_COUNT)) {
203                     // only read up to this point, and trim the last buffer by the number of extra bytes
204                     buffer.limit(lim - (int) (t - (val & MASK_COUNT)));
205                     try {
206                         return res = next.read(dsts, offset, i + 1);
207                     } finally {
208                         // restore the original limit
209                         buffer.limit(lim);
210                     }
211                 }
212             }
213             // the total buffer space is less than the remaining count.
214             return res = next.read(dsts, offset, length);
215         } catch (IOException | RuntimeException | Error e) {
216             IoUtils.safeClose(exchange.getConnection());
217             throw e;
218         } finally {
219             exitRead(res);
220         }
221     }
222
223     public long read(final ByteBuffer[] dsts) throws IOException {
224         return read(dsts, 0, dsts.length);
225     }
226
227     public int read(final ByteBuffer dst) throws IOException {
228         long val = state;
229         checkMaxSize(val);
230         if (allAreSet(val, FLAG_CLOSED) || allAreClear(val, MASK_COUNT)) {
231             if (allAreClear(val, FLAG_FINISHED)) {
232                 invokeFinishListener();
233             }
234             return -1;
235         }
236         int res = 0;
237         final long remaining = val & MASK_COUNT;
238         try {
239             final int lim = dst.limit();
240             final int pos = dst.position();
241             if (lim - pos > remaining) {
242                 dst.limit((int) (remaining + (long) pos));
243                 try {
244                     return res = next.read(dst);
245                 } finally {
246                     dst.limit(lim);
247                 }
248             } else {
249                 return res = next.read(dst);
250             }
251         } catch (IOException | RuntimeException | Error e) {
252             IoUtils.safeClose(exchange.getConnection());
253             throw e;
254         }  finally {
255             exitRead(res);
256         }
257     }
258
259     public boolean isReadResumed() {
260         return allAreClear(state, FLAG_CLOSED) && next.isReadResumed();
261     }
262
263     public void wakeupReads() {
264         long val = state;
265         if (anyAreSet(val, FLAG_CLOSED | FLAG_FINISHED)) {
266             return;
267         }
268         next.wakeupReads();
269     }
270
271     @Override
272     public void terminateReads() throws IOException {
273         long val = enterShutdownReads();
274         if (allAreSet(val, FLAG_CLOSED)) {
275             return;
276         }
277         exitShutdownReads(val);
278     }
279
280     public void awaitReadable() throws IOException {
281         final long val = state;
282         if (allAreSet(val, FLAG_CLOSED) || val == 0L) {
283             return;
284         }
285         next.awaitReadable();
286     }
287
288     public void awaitReadable(final long time, final TimeUnit timeUnit) throws IOException {
289         final long val = state;
290         if (allAreSet(val, FLAG_CLOSED) || val == 0L) {
291             return;
292         }
293         try {
294             next.awaitReadable(time, timeUnit);
295         } catch (IOException | RuntimeException | Error e) {
296             IoUtils.safeClose(exchange.getConnection());
297             throw e;
298         }
299     }
300
301     /**
302      * Get the number of remaining bytes.
303      *
304      * @return the number of remaining bytes
305      */

306     public long getRemaining() {
307         return state & MASK_COUNT;
308     }
309
310     private long enterShutdownReads() {
311         long oldVal, newVal;
312         oldVal = state;
313         if (anyAreSet(oldVal, FLAG_CLOSED)) {
314             return oldVal;
315         }
316         newVal = oldVal | FLAG_CLOSED;
317         state = newVal;
318         return oldVal;
319     }
320
321     private void exitShutdownReads(long oldVal) {
322         if (!allAreClear(oldVal, MASK_COUNT)) {
323             invokeFinishListener();
324         }
325     }
326
327     /**
328      * Exit a read method.
329      *
330      * @param consumed the number of bytes consumed by this call (may be 0)
331      */

332     private void exitRead(long consumed) throws IOException {
333         long oldVal = state;
334         if(consumed == -1) {
335             if (anyAreSet(oldVal, MASK_COUNT)) {
336                 invokeFinishListener();
337                 state &= ~MASK_COUNT;
338                 throw UndertowMessages.MESSAGES.couldNotReadContentLengthData();
339             }
340             return;
341         }
342         long newVal = oldVal - consumed;
343         state = newVal;
344     }
345
346     private void invokeFinishListener() {
347         this.state |= FLAG_FINISHED;
348         finishListener.handleEvent(this);
349     }
350
351 }
352