1
18
19 package io.undertow.conduits;
20
21 import io.undertow.UndertowLogger;
22 import org.xnio.Buffers;
23 import org.xnio.channels.FixedLengthOverflowException;
24 import org.xnio.channels.StreamSourceChannel;
25 import org.xnio.conduits.AbstractStreamSinkConduit;
26 import org.xnio.conduits.Conduits;
27 import org.xnio.conduits.StreamSinkConduit;
28
29 import java.io.IOException;
30 import java.nio.ByteBuffer;
31 import java.nio.channels.ClosedChannelException;
32 import java.nio.channels.FileChannel;
33 import java.util.concurrent.TimeUnit;
34
35 import static java.lang.Math.min;
36 import static org.xnio.Bits.allAreClear;
37 import static org.xnio.Bits.allAreSet;
38 import static org.xnio.Bits.anyAreSet;
39 import static org.xnio.Bits.longBitMask;
40
41
46 public abstract class AbstractFixedLengthStreamSinkConduit extends AbstractStreamSinkConduit<StreamSinkConduit> {
47 private int config;
48
49 private long state;
50
51 private boolean broken = false;
52
53 private static final int CONF_FLAG_CONFIGURABLE = 1 << 0;
54 private static final int CONF_FLAG_PASS_CLOSE = 1 << 1;
55
56 private static final long FLAG_CLOSE_REQUESTED = 1L << 63L;
57 private static final long FLAG_CLOSE_COMPLETE = 1L << 62L;
58 private static final long FLAG_FINISHED_CALLED = 1L << 61L;
59 private static final long MASK_COUNT = longBitMask(0, 60);
60
61
69 public AbstractFixedLengthStreamSinkConduit(final StreamSinkConduit next, final long contentLength, final boolean configurable, final boolean propagateClose) {
70 super(next);
71 if (contentLength < 0L) {
72 throw new IllegalArgumentException("Content length must be greater than or equal to zero");
73 } else if (contentLength > MASK_COUNT) {
74 throw new IllegalArgumentException("Content length is too long");
75 }
76 config = (configurable ? CONF_FLAG_CONFIGURABLE : 0) | (propagateClose ? CONF_FLAG_PASS_CLOSE : 0);
77 this.state = contentLength;
78 }
79
80 protected void reset(long contentLength, boolean propagateClose) {
81 this.state = contentLength;
82 if (propagateClose) {
83 config |= CONF_FLAG_PASS_CLOSE;
84 } else {
85 config &= ~CONF_FLAG_PASS_CLOSE;
86 }
87 }
88
89 public int write(final ByteBuffer src) throws IOException {
90 long val = state;
91 final long remaining = val & MASK_COUNT;
92 if (!src.hasRemaining()) {
93 return 0;
94 }
95 if (allAreSet(val, FLAG_CLOSE_REQUESTED)) {
96 throw new ClosedChannelException();
97 }
98 int oldLimit = src.limit();
99 if (remaining == 0) {
100 throw new FixedLengthOverflowException();
101 } else if (src.remaining() > remaining) {
102 src.limit((int) (src.position() + remaining));
103 }
104 int res = 0;
105 try {
106 return res = next.write(src);
107 } catch (IOException | RuntimeException | Error e) {
108 broken = true;
109 throw e;
110 } finally {
111 src.limit(oldLimit);
112 exitWrite(val, (long) res);
113 }
114 }
115
116 public long write(final ByteBuffer[] srcs, final int offset, final int length) throws IOException {
117 if (length == 0) {
118 return 0L;
119 } else if (length == 1) {
120 return write(srcs[offset]);
121 }
122 long val = state;
123 final long remaining = val & MASK_COUNT;
124 if (allAreSet(val, FLAG_CLOSE_REQUESTED)) {
125 throw new ClosedChannelException();
126 }
127 long toWrite = Buffers.remaining(srcs, offset, length);
128 if (remaining == 0) {
129 throw new FixedLengthOverflowException();
130 }
131 int[] limits = null;
132 if (toWrite > remaining) {
133 limits = new int[length];
134 long r = remaining;
135 for (int i = offset; i < offset + length; ++i) {
136 limits[i - offset] = srcs[i].limit();
137 int br = srcs[i].remaining();
138 if(br < r) {
139 r -= br;
140 } else {
141 srcs[i].limit((int) (srcs[i].position() + r));
142 r = 0;
143 }
144 }
145 }
146 long res = 0L;
147 try {
148 return res = next.write(srcs, offset, length);
149 } catch (IOException | RuntimeException | Error e) {
150 broken = true;
151 throw e;
152 } finally {
153 if (limits != null) {
154 for (int i = offset; i < offset + length; ++i) {
155 srcs[i].limit(limits[i - offset]);
156 }
157 }
158 exitWrite(val, res);
159 }
160 }
161
162 @Override
163 public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException {
164 try {
165 return Conduits.writeFinalBasic(this, srcs, offset, length);
166 } catch (IOException | RuntimeException | Error e) {
167 broken = true;
168 throw e;
169 }
170 }
171
172 @Override
173 public int writeFinal(ByteBuffer src) throws IOException {
174 try {
175 return Conduits.writeFinalBasic(this, src);
176 } catch (IOException | RuntimeException | Error e) {
177 broken = true;
178 throw e;
179 }
180 }
181
182 public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
183 if (count == 0L) return 0L;
184 long val = state;
185 if (allAreSet(val, FLAG_CLOSE_REQUESTED)) {
186 throw new ClosedChannelException();
187 }
188 if (allAreClear(val, MASK_COUNT)) {
189 throw new FixedLengthOverflowException();
190 }
191 long res = 0L;
192 try {
193 return res = next.transferFrom(src, position, min(count, (val & MASK_COUNT)));
194 } catch (IOException | RuntimeException | Error e) {
195 broken = true;
196 throw e;
197 } finally {
198 exitWrite(val, res);
199 }
200 }
201
202 public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
203 if (count == 0L) return 0L;
204 long val = state;
205 if (allAreSet(val, FLAG_CLOSE_REQUESTED)) {
206 throw new ClosedChannelException();
207 }
208 if (allAreClear(val, MASK_COUNT)) {
209 throw new FixedLengthOverflowException();
210 }
211 long res = 0L;
212 try {
213 return res = next.transferFrom(source, min(count, (val & MASK_COUNT)), throughBuffer);
214 } catch (IOException | RuntimeException | Error e) {
215 broken = true;
216 throw e;
217 } finally {
218 exitWrite(val, res);
219 }
220 }
221
222 public boolean flush() throws IOException {
223 long val = state;
224 if (anyAreSet(val, FLAG_CLOSE_COMPLETE)) {
225 return true;
226 }
227 boolean flushed = false;
228 try {
229 return flushed = next.flush();
230 } catch (IOException | RuntimeException | Error e) {
231 broken = true;
232 throw e;
233 } finally {
234 exitFlush(val, flushed);
235 }
236 }
237
238 public boolean isWriteResumed() {
239
240 return allAreClear(state, FLAG_CLOSE_COMPLETE) && next.isWriteResumed();
241 }
242
243 public void wakeupWrites() {
244 long val = state;
245 if (anyAreSet(val, FLAG_CLOSE_COMPLETE)) {
246 return;
247 }
248 next.wakeupWrites();
249 }
250
251 public void terminateWrites() throws IOException {
252 final long val = enterShutdown();
253 if (anyAreSet(val, MASK_COUNT) && !broken) {
254 UndertowLogger.REQUEST_IO_LOGGER.debugf("Fixed length stream closed with with %s bytes remaining", val & MASK_COUNT);
255 try {
256 next.truncateWrites();
257 } finally {
258 if (!anyAreSet(state, FLAG_FINISHED_CALLED)) {
259 state |= FLAG_FINISHED_CALLED;
260 channelFinished();
261 }
262 }
263 } else if (allAreSet(config, CONF_FLAG_PASS_CLOSE)) {
264 next.terminateWrites();
265 }
266
267 }
268
269 @Override
270 public void truncateWrites() throws IOException {
271 try {
272 if (!anyAreSet(state, FLAG_FINISHED_CALLED)) {
273 state |= FLAG_FINISHED_CALLED;
274 channelFinished();
275 }
276 } finally {
277 super.truncateWrites();
278 }
279 }
280
281 public void awaitWritable() throws IOException {
282 next.awaitWritable();
283 }
284
285 public void awaitWritable(final long time, final TimeUnit timeUnit) throws IOException {
286 next.awaitWritable(time, timeUnit);
287 }
288
289
290
295 public long getRemaining() {
296 return state & MASK_COUNT;
297 }
298
299 private void exitWrite(long oldVal, long consumed) {
300 long newVal = oldVal - consumed;
301 state = newVal;
302 }
303
304 private void exitFlush(long oldVal, boolean flushed) {
305 long newVal = oldVal;
306 boolean callFinish = false;
307 if ((anyAreSet(oldVal, FLAG_CLOSE_REQUESTED) || (newVal & MASK_COUNT) == 0L) && flushed) {
308 newVal |= FLAG_CLOSE_COMPLETE;
309
310 if (!anyAreSet(oldVal, FLAG_FINISHED_CALLED) && (newVal & MASK_COUNT) == 0L) {
311 newVal |= FLAG_FINISHED_CALLED;
312 callFinish = true;
313 }
314 state = newVal;
315 if (callFinish) {
316 channelFinished();
317 }
318 }
319 }
320
321 protected void channelFinished() {
322 }
323
324 private long enterShutdown() {
325 long oldVal, newVal;
326 oldVal = state;
327 if (anyAreSet(oldVal, FLAG_CLOSE_REQUESTED | FLAG_CLOSE_COMPLETE)) {
328
329 return oldVal;
330 }
331 newVal = oldVal | FLAG_CLOSE_REQUESTED;
332 if (anyAreSet(oldVal, MASK_COUNT)) {
333
334 newVal |= FLAG_CLOSE_COMPLETE;
335 }
336 state = newVal;
337 return oldVal;
338 }
339
340 }
341