1
18
19 package io.undertow.conduits;
20
21 import io.undertow.UndertowMessages;
22 import io.undertow.server.Connectors;
23 import io.undertow.server.HttpServerExchange;
24 import org.xnio.IoUtils;
25 import org.xnio.channels.StreamSinkChannel;
26 import org.xnio.conduits.AbstractStreamSourceConduit;
27 import org.xnio.conduits.StreamSourceConduit;
28
29 import java.io.IOException;
30 import java.nio.ByteBuffer;
31 import java.nio.channels.FileChannel;
32 import java.util.concurrent.TimeUnit;
33
34 import static java.lang.Math.min;
35 import static org.xnio.Bits.allAreClear;
36 import static org.xnio.Bits.allAreSet;
37 import static org.xnio.Bits.anyAreClear;
38 import static org.xnio.Bits.anyAreSet;
39 import static org.xnio.Bits.longBitMask;
40
41
48
57 public final class FixedLengthStreamSourceConduit extends AbstractStreamSourceConduit<StreamSourceConduit> {
58
59 private final ConduitListener<? super FixedLengthStreamSourceConduit> finishListener;
60
61 @SuppressWarnings("unused")
62 private long state;
63
64 private static final long FLAG_CLOSED = 1L << 63L;
65 private static final long FLAG_FINISHED = 1L << 62L;
66 private static final long FLAG_LENGTH_CHECKED = 1L << 61L;
67 private static final long MASK_COUNT = longBitMask(0, 60);
68
69 private final HttpServerExchange exchange;
70
71
85 public FixedLengthStreamSourceConduit(final StreamSourceConduit next, final long contentLength, final ConduitListener<? super FixedLengthStreamSourceConduit> finishListener, final HttpServerExchange exchange) {
86 super(next);
87 this.finishListener = finishListener;
88 if (contentLength < 0L) {
89 throw new IllegalArgumentException("Content length must be greater than or equal to zero");
90 } else if (contentLength > MASK_COUNT) {
91 throw new IllegalArgumentException("Content length is too long");
92 }
93 state = contentLength;
94 this.exchange = exchange;
95 }
96
97
110 public FixedLengthStreamSourceConduit(final StreamSourceConduit next, final long contentLength, final ConduitListener<? super FixedLengthStreamSourceConduit> finishListener) {
111 this(next, contentLength, finishListener, null);
112 }
113
114 public long transferTo(final long position, final long count, final FileChannel target) throws IOException {
115 long val = state;
116 checkMaxSize(val);
117 if (anyAreSet(val, FLAG_CLOSED | FLAG_FINISHED) || allAreClear(val, MASK_COUNT)) {
118 if (allAreClear(val, FLAG_FINISHED)) {
119 invokeFinishListener();
120 }
121 return -1L;
122 }
123 long res = 0L;
124 try {
125 return res = next.transferTo(position, min(count, val & MASK_COUNT), target);
126 } catch (IOException | RuntimeException | Error e) {
127 IoUtils.safeClose(exchange.getConnection());
128 throw e;
129 } finally {
130 exitRead(res);
131 }
132 }
133
134 public long transferTo(final long count, final ByteBuffer throughBuffer, final StreamSinkChannel target) throws IOException {
135 if (count == 0L) {
136 return 0L;
137 }
138 long val = state;
139 checkMaxSize(val);
140 if (anyAreSet(val, FLAG_CLOSED | FLAG_FINISHED) || allAreClear(val, MASK_COUNT)) {
141 if (allAreClear(val, FLAG_FINISHED)) {
142 invokeFinishListener();
143 }
144 return -1;
145 }
146 long res = 0L;
147 try {
148 return res = next.transferTo(min(count, val & MASK_COUNT), throughBuffer, target);
149 } catch (IOException | RuntimeException | Error e) {
150 IoUtils.safeClose(exchange.getConnection());
151 throw e;
152 } finally {
153 exitRead(res + throughBuffer.remaining());
154 }
155 }
156
157 private void checkMaxSize(long state) throws IOException {
158 if (anyAreClear(state, FLAG_LENGTH_CHECKED)) {
159 HttpServerExchange exchange = this.exchange;
160 if (exchange != null) {
161 if (exchange.getMaxEntitySize() > 0 && exchange.getMaxEntitySize() < (state & MASK_COUNT)) {
162
163
164 Connectors.terminateRequest(exchange);
165 exchange.setPersistent(false);
166 finishListener.handleEvent(this);
167 this.state |= FLAG_FINISHED | FLAG_CLOSED;
168 throw UndertowMessages.MESSAGES.requestEntityWasTooLarge(exchange.getMaxEntitySize());
169 }
170 }
171 this.state |= FLAG_LENGTH_CHECKED;
172 }
173 }
174
175 public long read(final ByteBuffer[] dsts, final int offset, final int length) throws IOException {
176 if (length == 0) {
177 return 0L;
178 } else if (length == 1) {
179 return read(dsts[offset]);
180 }
181 long val = state;
182 checkMaxSize(val);
183 if (allAreSet(val, FLAG_CLOSED) || allAreClear(val, MASK_COUNT)) {
184 if (allAreClear(val, FLAG_FINISHED)) {
185 invokeFinishListener();
186 }
187 return -1;
188 }
189 long res = 0L;
190 try {
191 if ((val & MASK_COUNT) == 0L) {
192 return -1L;
193 }
194 int lim;
195
196 long t = 0L;
197 for (int i = 0; i < length; i++) {
198 final ByteBuffer buffer = dsts[i + offset];
199
200
201 t += (lim = buffer.limit()) - buffer.position();
202 if (t > (val & MASK_COUNT)) {
203
204 buffer.limit(lim - (int) (t - (val & MASK_COUNT)));
205 try {
206 return res = next.read(dsts, offset, i + 1);
207 } finally {
208
209 buffer.limit(lim);
210 }
211 }
212 }
213
214 return res = next.read(dsts, offset, length);
215 } catch (IOException | RuntimeException | Error e) {
216 IoUtils.safeClose(exchange.getConnection());
217 throw e;
218 } finally {
219 exitRead(res);
220 }
221 }
222
223 public long read(final ByteBuffer[] dsts) throws IOException {
224 return read(dsts, 0, dsts.length);
225 }
226
227 public int read(final ByteBuffer dst) throws IOException {
228 long val = state;
229 checkMaxSize(val);
230 if (allAreSet(val, FLAG_CLOSED) || allAreClear(val, MASK_COUNT)) {
231 if (allAreClear(val, FLAG_FINISHED)) {
232 invokeFinishListener();
233 }
234 return -1;
235 }
236 int res = 0;
237 final long remaining = val & MASK_COUNT;
238 try {
239 final int lim = dst.limit();
240 final int pos = dst.position();
241 if (lim - pos > remaining) {
242 dst.limit((int) (remaining + (long) pos));
243 try {
244 return res = next.read(dst);
245 } finally {
246 dst.limit(lim);
247 }
248 } else {
249 return res = next.read(dst);
250 }
251 } catch (IOException | RuntimeException | Error e) {
252 IoUtils.safeClose(exchange.getConnection());
253 throw e;
254 } finally {
255 exitRead(res);
256 }
257 }
258
259 public boolean isReadResumed() {
260 return allAreClear(state, FLAG_CLOSED) && next.isReadResumed();
261 }
262
263 public void wakeupReads() {
264 long val = state;
265 if (anyAreSet(val, FLAG_CLOSED | FLAG_FINISHED)) {
266 return;
267 }
268 next.wakeupReads();
269 }
270
271 @Override
272 public void terminateReads() throws IOException {
273 long val = enterShutdownReads();
274 if (allAreSet(val, FLAG_CLOSED)) {
275 return;
276 }
277 exitShutdownReads(val);
278 }
279
280 public void awaitReadable() throws IOException {
281 final long val = state;
282 if (allAreSet(val, FLAG_CLOSED) || val == 0L) {
283 return;
284 }
285 next.awaitReadable();
286 }
287
288 public void awaitReadable(final long time, final TimeUnit timeUnit) throws IOException {
289 final long val = state;
290 if (allAreSet(val, FLAG_CLOSED) || val == 0L) {
291 return;
292 }
293 try {
294 next.awaitReadable(time, timeUnit);
295 } catch (IOException | RuntimeException | Error e) {
296 IoUtils.safeClose(exchange.getConnection());
297 throw e;
298 }
299 }
300
301
306 public long getRemaining() {
307 return state & MASK_COUNT;
308 }
309
310 private long enterShutdownReads() {
311 long oldVal, newVal;
312 oldVal = state;
313 if (anyAreSet(oldVal, FLAG_CLOSED)) {
314 return oldVal;
315 }
316 newVal = oldVal | FLAG_CLOSED;
317 state = newVal;
318 return oldVal;
319 }
320
321 private void exitShutdownReads(long oldVal) {
322 if (!allAreClear(oldVal, MASK_COUNT)) {
323 invokeFinishListener();
324 }
325 }
326
327
332 private void exitRead(long consumed) throws IOException {
333 long oldVal = state;
334 if(consumed == -1) {
335 if (anyAreSet(oldVal, MASK_COUNT)) {
336 invokeFinishListener();
337 state &= ~MASK_COUNT;
338 throw UndertowMessages.MESSAGES.couldNotReadContentLengthData();
339 }
340 return;
341 }
342 long newVal = oldVal - consumed;
343 state = newVal;
344 }
345
346 private void invokeFinishListener() {
347 this.state |= FLAG_FINISHED;
348 finishListener.handleEvent(this);
349 }
350
351 }
352