1 /*
2  * JBoss, Home of Professional Open Source.
3  * Copyright 2014 Red Hat, Inc., and individual contributors
4  * as indicated by the @author tags.
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  *  Unless required by applicable law or agreed to in writing, software
13  *  distributed under the License is distributed on an "AS IS" BASIS,
14  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  *  See the License for the specific language governing permissions and
16  *  limitations under the License.
17  */

18
19 package io.undertow.conduits;
20
21 import io.undertow.UndertowLogger;
22 import org.xnio.Buffers;
23 import org.xnio.channels.FixedLengthOverflowException;
24 import org.xnio.channels.StreamSourceChannel;
25 import org.xnio.conduits.AbstractStreamSinkConduit;
26 import org.xnio.conduits.Conduits;
27 import org.xnio.conduits.StreamSinkConduit;
28
29 import java.io.IOException;
30 import java.nio.ByteBuffer;
31 import java.nio.channels.ClosedChannelException;
32 import java.nio.channels.FileChannel;
33 import java.util.concurrent.TimeUnit;
34
35 import static java.lang.Math.min;
36 import static org.xnio.Bits.allAreClear;
37 import static org.xnio.Bits.allAreSet;
38 import static org.xnio.Bits.anyAreSet;
39 import static org.xnio.Bits.longBitMask;
40
41 /**
42  * A channel which writes a fixed amount of data.  A listener is called once the data has been written.
43  *
44  * @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a>
45  */

46 public abstract class AbstractFixedLengthStreamSinkConduit extends AbstractStreamSinkConduit<StreamSinkConduit> {
47     private int config;
48
49     private long state;
50
51     private boolean broken = false;
52
53     private static final int CONF_FLAG_CONFIGURABLE = 1 << 0;
54     private static final int CONF_FLAG_PASS_CLOSE = 1 << 1;
55
56     private static final long FLAG_CLOSE_REQUESTED = 1L << 63L;
57     private static final long FLAG_CLOSE_COMPLETE = 1L << 62L;
58     private static final long FLAG_FINISHED_CALLED = 1L << 61L;
59     private static final long MASK_COUNT = longBitMask(0, 60);
60
61     /**
62      * Construct a new instance.
63      *
64      * @param next           the next channel
65      * @param contentLength  the content length
66      * @param configurable   {@code trueif this instance should pass configuration to the next
67      * @param propagateClose {@code trueif this instance should pass close to the next
68      */

69     public AbstractFixedLengthStreamSinkConduit(final StreamSinkConduit next, final long contentLength, final boolean configurable, final boolean propagateClose) {
70         super(next);
71         if (contentLength < 0L) {
72             throw new IllegalArgumentException("Content length must be greater than or equal to zero");
73         } else if (contentLength > MASK_COUNT) {
74             throw new IllegalArgumentException("Content length is too long");
75         }
76         config = (configurable ? CONF_FLAG_CONFIGURABLE : 0) | (propagateClose ? CONF_FLAG_PASS_CLOSE : 0);
77         this.state = contentLength;
78     }
79
80     protected void reset(long contentLength, boolean propagateClose) {
81         this.state = contentLength;
82         if (propagateClose) {
83             config |= CONF_FLAG_PASS_CLOSE;
84         } else {
85             config &= ~CONF_FLAG_PASS_CLOSE;
86         }
87     }
88
89     public int write(final ByteBuffer src) throws IOException {
90         long val = state;
91         final long remaining = val & MASK_COUNT;
92         if (!src.hasRemaining()) {
93             return 0;
94         }
95         if (allAreSet(val, FLAG_CLOSE_REQUESTED)) {
96             throw new ClosedChannelException();
97         }
98         int oldLimit = src.limit();
99         if (remaining == 0) {
100             throw new FixedLengthOverflowException();
101         } else if (src.remaining() > remaining) {
102             src.limit((int) (src.position() + remaining));
103         }
104         int res = 0;
105         try {
106             return res = next.write(src);
107         } catch (IOException | RuntimeException | Error e) {
108             broken = true;
109             throw e;
110         } finally {
111             src.limit(oldLimit);
112             exitWrite(val, (long) res);
113         }
114     }
115
116     public long write(final ByteBuffer[] srcs, final int offset, final int length) throws IOException {
117         if (length == 0) {
118             return 0L;
119         } else if (length == 1) {
120             return write(srcs[offset]);
121         }
122         long val = state;
123         final long remaining = val & MASK_COUNT;
124         if (allAreSet(val, FLAG_CLOSE_REQUESTED)) {
125             throw new ClosedChannelException();
126         }
127         long toWrite = Buffers.remaining(srcs, offset, length);
128         if (remaining == 0) {
129             throw new FixedLengthOverflowException();
130         }
131         int[] limits = null;
132         if (toWrite > remaining) {
133             limits = new int[length];
134             long r = remaining;
135             for (int i = offset; i < offset + length; ++i) {
136                 limits[i - offset] = srcs[i].limit();
137                 int br = srcs[i].remaining();
138                 if(br < r) {
139                     r -= br;
140                 } else {
141                     srcs[i].limit((int) (srcs[i].position() + r));
142                     r = 0;
143                 }
144             }
145         }
146         long res = 0L;
147         try {
148             return res = next.write(srcs, offset, length);
149         } catch (IOException | RuntimeException | Error e) {
150             broken = true;
151             throw e;
152         } finally {
153             if (limits != null) {
154                 for (int i = offset; i < offset + length; ++i) {
155                     srcs[i].limit(limits[i - offset]);
156                 }
157             }
158             exitWrite(val, res);
159         }
160     }
161
162     @Override
163     public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException {
164         try {
165             return Conduits.writeFinalBasic(this, srcs, offset, length);
166         } catch (IOException | RuntimeException | Error e) {
167             broken = true;
168             throw e;
169         }
170     }
171
172     @Override
173     public int writeFinal(ByteBuffer src) throws IOException {
174         try {
175             return Conduits.writeFinalBasic(this, src);
176         } catch (IOException | RuntimeException | Error e) {
177             broken = true;
178             throw e;
179         }
180     }
181
182     public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
183         if (count == 0L) return 0L;
184         long val = state;
185         if (allAreSet(val, FLAG_CLOSE_REQUESTED)) {
186             throw new ClosedChannelException();
187         }
188         if (allAreClear(val, MASK_COUNT)) {
189             throw new FixedLengthOverflowException();
190         }
191         long res = 0L;
192         try {
193             return res = next.transferFrom(src, position, min(count, (val & MASK_COUNT)));
194         } catch (IOException | RuntimeException | Error e) {
195             broken = true;
196             throw e;
197         } finally {
198             exitWrite(val, res);
199         }
200     }
201
202     public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
203         if (count == 0L) return 0L;
204         long val = state;
205         if (allAreSet(val, FLAG_CLOSE_REQUESTED)) {
206             throw new ClosedChannelException();
207         }
208         if (allAreClear(val, MASK_COUNT)) {
209             throw new FixedLengthOverflowException();
210         }
211         long res = 0L;
212         try {
213             return res = next.transferFrom(source, min(count, (val & MASK_COUNT)), throughBuffer);
214         } catch (IOException | RuntimeException | Error e) {
215             broken = true;
216             throw e;
217         } finally {
218             exitWrite(val, res);
219         }
220     }
221
222     public boolean flush() throws IOException {
223         long val = state;
224         if (anyAreSet(val, FLAG_CLOSE_COMPLETE)) {
225             return true;
226         }
227         boolean flushed = false;
228         try {
229             return flushed = next.flush();
230         } catch (IOException | RuntimeException | Error e) {
231             broken = true;
232             throw e;
233         } finally {
234             exitFlush(val, flushed);
235         }
236     }
237
238     public boolean isWriteResumed() {
239         // not perfect but not provably wrong either...
240         return allAreClear(state, FLAG_CLOSE_COMPLETE) && next.isWriteResumed();
241     }
242
243     public void wakeupWrites() {
244         long val = state;
245         if (anyAreSet(val, FLAG_CLOSE_COMPLETE)) {
246             return;
247         }
248         next.wakeupWrites();
249     }
250
251     public void terminateWrites() throws IOException {
252         final long val = enterShutdown();
253         if (anyAreSet(val, MASK_COUNT) && !broken) {
254             UndertowLogger.REQUEST_IO_LOGGER.debugf("Fixed length stream closed with with %s bytes remaining", val & MASK_COUNT);
255             try {
256                 next.truncateWrites();
257             } finally {
258                 if (!anyAreSet(state, FLAG_FINISHED_CALLED)) {
259                     state |= FLAG_FINISHED_CALLED;
260                     channelFinished();
261                 }
262             }
263         } else if (allAreSet(config, CONF_FLAG_PASS_CLOSE)) {
264             next.terminateWrites();
265         }
266
267     }
268
269     @Override
270     public void truncateWrites() throws IOException {
271         try {
272             if (!anyAreSet(state, FLAG_FINISHED_CALLED)) {
273                 state |= FLAG_FINISHED_CALLED;
274                 channelFinished();
275             }
276         } finally {
277             super.truncateWrites();
278         }
279     }
280
281     public void awaitWritable() throws IOException {
282         next.awaitWritable();
283     }
284
285     public void awaitWritable(final long time, final TimeUnit timeUnit) throws IOException {
286         next.awaitWritable(time, timeUnit);
287     }
288
289
290     /**
291      * Get the number of remaining bytes in this fixed length channel.
292      *
293      * @return the number of remaining bytes
294      */

295     public long getRemaining() {
296         return state & MASK_COUNT;
297     }
298
299     private void exitWrite(long oldVal, long consumed) {
300         long newVal = oldVal - consumed;
301         state = newVal;
302     }
303
304     private void exitFlush(long oldVal, boolean flushed) {
305         long newVal = oldVal;
306         boolean callFinish = false;
307         if ((anyAreSet(oldVal, FLAG_CLOSE_REQUESTED) || (newVal & MASK_COUNT) == 0L) && flushed) {
308             newVal |= FLAG_CLOSE_COMPLETE;
309
310             if (!anyAreSet(oldVal, FLAG_FINISHED_CALLED) && (newVal & MASK_COUNT) == 0L) {
311                 newVal |= FLAG_FINISHED_CALLED;
312                 callFinish = true;
313             }
314             state = newVal;
315             if (callFinish) {
316                 channelFinished();
317             }
318         }
319     }
320
321     protected void channelFinished() {
322     }
323
324     private long enterShutdown() {
325         long oldVal, newVal;
326         oldVal = state;
327         if (anyAreSet(oldVal, FLAG_CLOSE_REQUESTED | FLAG_CLOSE_COMPLETE)) {
328             // no action necessary
329             return oldVal;
330         }
331         newVal = oldVal | FLAG_CLOSE_REQUESTED;
332         if (anyAreSet(oldVal, MASK_COUNT)) {
333             // error: channel not filled.  set both close flags.
334             newVal |= FLAG_CLOSE_COMPLETE;
335         }
336         state = newVal;
337         return oldVal;
338     }
339
340 }
341