1 /*
2  * JBoss, Home of Professional Open Source.
3  * Copyright 2014 Red Hat, Inc., and individual contributors
4  * as indicated by the @author tags.
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  *  Unless required by applicable law or agreed to in writing, software
13  *  distributed under the License is distributed on an "AS IS" BASIS,
14  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  *  See the License for the specific language governing permissions and
16  *  limitations under the License.
17  */

18
19 package io.undertow.conduits;
20
21 import static org.xnio.Bits.allAreClear;
22 import static org.xnio.Bits.allAreSet;
23 import static org.xnio.Bits.anyAreSet;
24
25 import java.io.IOException;
26 import java.nio.ByteBuffer;
27 import java.nio.channels.ClosedChannelException;
28 import java.nio.channels.FileChannel;
29 import java.util.concurrent.TimeUnit;
30 import java.util.zip.Deflater;
31
32 import io.undertow.server.Connectors;
33 import org.xnio.IoUtils;
34 import io.undertow.connector.PooledByteBuffer;
35 import org.xnio.XnioIoThread;
36 import org.xnio.XnioWorker;
37 import org.xnio.channels.StreamSourceChannel;
38 import org.xnio.conduits.ConduitWritableByteChannel;
39 import org.xnio.conduits.Conduits;
40 import org.xnio.conduits.StreamSinkConduit;
41 import org.xnio.conduits.WriteReadyHandler;
42
43 import io.undertow.UndertowLogger;
44 import io.undertow.server.HttpServerExchange;
45 import io.undertow.util.ConduitFactory;
46 import io.undertow.util.NewInstanceObjectPool;
47 import io.undertow.util.ObjectPool;
48 import io.undertow.util.Headers;
49 import io.undertow.util.PooledObject;
50 import io.undertow.util.SimpleObjectPool;
51
52 /**
53  * Channel that handles deflate compression
54  *
55  * @author Stuart Douglas
56  */

57 public class DeflatingStreamSinkConduit implements StreamSinkConduit {
58
59     protected volatile Deflater deflater;
60
61     protected final PooledObject<Deflater> pooledObject;
62     private final ConduitFactory<StreamSinkConduit> conduitFactory;
63     private final HttpServerExchange exchange;
64
65     private StreamSinkConduit next;
66     private WriteReadyHandler writeReadyHandler;
67
68
69     /**
70      * The streams buffer. This is freed when the next is shutdown
71      */

72     protected PooledByteBuffer currentBuffer;
73     /**
74      * there may have been some additional data that did not fit into the first buffer
75      */

76     private ByteBuffer additionalBuffer;
77
78     private int state = 0;
79
80     private static final int SHUTDOWN = 1;
81     private static final int NEXT_SHUTDOWN = 1 << 1;
82     private static final int FLUSHING_BUFFER = 1 << 2;
83     private static final int WRITES_RESUMED = 1 << 3;
84     private static final int CLOSED = 1 << 4;
85     private static final int WRITTEN_TRAILER = 1 << 5;
86
87     public DeflatingStreamSinkConduit(final ConduitFactory<StreamSinkConduit> conduitFactory, final HttpServerExchange exchange) {
88         this(conduitFactory, exchange, Deflater.DEFLATED);
89     }
90
91     public DeflatingStreamSinkConduit(final ConduitFactory<StreamSinkConduit> conduitFactory, final HttpServerExchange exchange, int deflateLevel) {
92         this(conduitFactory, exchange, newInstanceDeflaterPool(deflateLevel));
93     }
94
95     public DeflatingStreamSinkConduit(final ConduitFactory<StreamSinkConduit> conduitFactory, final HttpServerExchange exchange, ObjectPool<Deflater> deflaterPool) {
96         this.pooledObject = deflaterPool.allocate();
97         this.deflater = pooledObject.getObject();
98         this.currentBuffer = exchange.getConnection().getByteBufferPool().allocate();
99         this.exchange = exchange;
100         this.conduitFactory = conduitFactory;
101         setWriteReadyHandler(new WriteReadyHandler.ChannelListenerHandler<>(Connectors.getConduitSinkChannel(exchange)));
102     }
103
104     public static ObjectPool<Deflater> newInstanceDeflaterPool(int deflateLevel) {
105         return new NewInstanceObjectPool<Deflater>(() -> new Deflater(deflateLevel, true), Deflater::end);
106     }
107
108     public static ObjectPool<Deflater> simpleDeflaterPool(int poolSize, int deflateLevel) {
109         return new SimpleObjectPool<Deflater>(poolSize, () -> new Deflater(deflateLevel, true), Deflater::reset, Deflater::end);
110     }
111
112
113     @Override
114     public int write(final ByteBuffer src) throws IOException {
115         if (anyAreSet(state, SHUTDOWN | CLOSED) || currentBuffer == null) {
116             throw new ClosedChannelException();
117         }
118         try {
119             if (!performFlushIfRequired()) {
120                 return 0;
121             }
122             if (src.remaining() == 0) {
123                 return 0;
124             }
125             //we may already have some input, if so compress it
126             if (!deflater.needsInput()) {
127                 deflateData(false);
128                 if (!deflater.needsInput()) {
129                     return 0;
130                 }
131             }
132             byte[] data = new byte[src.remaining()];
133             src.get(data);
134             preDeflate(data);
135             deflater.setInput(data);
136             Connectors.updateResponseBytesSent(exchange, 0 - data.length);
137             deflateData(false);
138             return data.length;
139         } catch (IOException | RuntimeException | Error e) {
140             freeBuffer();
141             throw e;
142         }
143     }
144
145     protected void preDeflate(byte[] data) {
146
147     }
148
149     @Override
150     public long write(final ByteBuffer[] srcs, final int offset, final int length) throws IOException {
151         if (anyAreSet(state, SHUTDOWN | CLOSED) || currentBuffer == null) {
152             throw new ClosedChannelException();
153         }
154         try {
155             int total = 0;
156             for (int i = offset; i < offset + length; ++i) {
157                 if (srcs[i].hasRemaining()) {
158                     int ret = write(srcs[i]);
159                     total += ret;
160                     if (ret == 0) {
161                         return total;
162                     }
163                 }
164             }
165             return total;
166         } catch (IOException | RuntimeException | Error e) {
167             freeBuffer();
168             throw e;
169         }
170     }
171
172     @Override
173     public int writeFinal(ByteBuffer src) throws IOException {
174         return Conduits.writeFinalBasic(this, src);
175     }
176
177     @Override
178     public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException {
179         return Conduits.writeFinalBasic(this, srcs, offset, length);
180     }
181
182     @Override
183     public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
184         if (anyAreSet(state, SHUTDOWN | CLOSED)) {
185             throw new ClosedChannelException();
186         }
187         if (!performFlushIfRequired()) {
188             return 0;
189         }
190         return src.transferTo(position, count, new ConduitWritableByteChannel(this));
191     }
192
193
194     @Override
195     public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
196         if (anyAreSet(state, SHUTDOWN | CLOSED)) {
197             throw new ClosedChannelException();
198         }
199         if (!performFlushIfRequired()) {
200             return 0;
201         }
202         return IoUtils.transfer(source, count, throughBuffer, new ConduitWritableByteChannel(this));
203     }
204
205     @Override
206     public XnioWorker getWorker() {
207         return exchange.getConnection().getWorker();
208     }
209
210     @Override
211     public void suspendWrites() {
212         if (next == null) {
213             state = state & ~WRITES_RESUMED;
214         } else {
215             next.suspendWrites();
216         }
217     }
218
219
220     @Override
221     public boolean isWriteResumed() {
222         if (next == null) {
223             return anyAreSet(state, WRITES_RESUMED);
224         } else {
225             return next.isWriteResumed();
226         }
227     }
228
229     @Override
230     public void wakeupWrites() {
231         if (next == null) {
232             resumeWrites();
233         } else {
234             next.wakeupWrites();
235         }
236     }
237
238     @Override
239     public void resumeWrites() {
240         if (next == null) {
241             state |= WRITES_RESUMED;
242             queueWriteListener();
243         } else {
244             next.resumeWrites();
245         }
246     }
247
248     private void queueWriteListener() {
249         exchange.getConnection().getIoThread().execute(new Runnable() {
250             @Override
251             public void run() {
252                 if (writeReadyHandler != null) {
253                     try {
254                         writeReadyHandler.writeReady();
255                     } finally {
256                         //if writes are still resumed queue up another one
257                         if (next == null && isWriteResumed()) {
258                             queueWriteListener();
259                         }
260                     }
261                 }
262             }
263         });
264     }
265
266     @Override
267     public void terminateWrites() throws IOException {
268         if (deflater != null) {
269             deflater.finish();
270         }
271         state |= SHUTDOWN;
272     }
273
274     @Override
275     public boolean isWriteShutdown() {
276         return anyAreSet(state, SHUTDOWN);
277     }
278
279     @Override
280     public void awaitWritable() throws IOException {
281         if (next == null) {
282             return;
283         } else {
284             next.awaitWritable();
285         }
286     }
287
288     @Override
289     public void awaitWritable(final long time, final TimeUnit timeUnit) throws IOException {
290         if (next == null) {
291             return;
292         } else {
293             next.awaitWritable(time, timeUnit);
294         }
295     }
296
297     @Override
298     public XnioIoThread getWriteThread() {
299         return exchange.getConnection().getIoThread();
300     }
301
302     @Override
303     public void setWriteReadyHandler(final WriteReadyHandler handler) {
304         this.writeReadyHandler = handler;
305     }
306
307     @Override
308     public boolean flush() throws IOException {
309         if (currentBuffer == null) {
310             if (anyAreSet(state, NEXT_SHUTDOWN)) {
311                 return next.flush();
312             } else {
313                 return true;
314             }
315         }
316         try {
317             boolean nextCreated = false;
318             try {
319                 if (anyAreSet(state, SHUTDOWN)) {
320                     if (anyAreSet(state, NEXT_SHUTDOWN)) {
321                         return next.flush();
322                     } else {
323                         if (!performFlushIfRequired()) {
324                             return false;
325                         }
326                         //if the deflater has not been fully flushed we need to flush it
327                         if (!deflater.finished()) {
328                             deflateData(false);
329                             //if could not fully flush
330                             if (!deflater.finished()) {
331                                 return false;
332                             }
333                         }
334                         final ByteBuffer buffer = currentBuffer.getBuffer();
335                         if (allAreClear(state, WRITTEN_TRAILER)) {
336                             state |= WRITTEN_TRAILER;
337                             byte[] data = getTrailer();
338                             if (data != null) {
339                                 Connectors.updateResponseBytesSent(exchange, data.length);
340                                 if(additionalBuffer != null) {
341                                     byte[] newData = new byte[additionalBuffer.remaining() + data.length];
342                                     int pos = 0;
343                                     while (additionalBuffer.hasRemaining()) {
344                                         newData[pos++] = additionalBuffer.get();
345                                     }
346                                     for (byte aData : data) {
347                                         newData[pos++] = aData;
348                                     }
349                                     this.additionalBuffer = ByteBuffer.wrap(newData);
350                                 } else if(anyAreSet(state, FLUSHING_BUFFER) && buffer.capacity() - buffer.remaining() >= data.length) {
351                                     buffer.compact();
352                                     buffer.put(data);
353                                     buffer.flip();
354                                 } else if (data.length <= buffer.remaining() && !anyAreSet(state, FLUSHING_BUFFER)) {
355                                     buffer.put(data);
356                                 } else {
357                                     additionalBuffer = ByteBuffer.wrap(data);
358                                 }
359                             }
360                         }
361
362                         //ok the deflater is flushed, now we need to flush the buffer
363                         if (!anyAreSet(state, FLUSHING_BUFFER)) {
364                             buffer.flip();
365                             state |= FLUSHING_BUFFER;
366                             if (next == null) {
367                                 nextCreated = true;
368                                 this.next = createNextChannel();
369                             }
370                         }
371                         if (performFlushIfRequired()) {
372                             state |= NEXT_SHUTDOWN;
373                             freeBuffer();
374                             next.terminateWrites();
375                             return next.flush();
376                         } else {
377                             return false;
378                         }
379                     }
380                 } else {
381                     if(allAreClear(state, FLUSHING_BUFFER)) {
382                         if (next == null) {
383                             nextCreated = true;
384                             this.next = createNextChannel();
385                         }
386                         deflateData(true);
387                         if(allAreClear(state, FLUSHING_BUFFER)) {
388                             //deflateData can cause this to be change
389                             currentBuffer.getBuffer().flip();
390                             this.state |= FLUSHING_BUFFER;
391                         }
392                     }
393                     if(!performFlushIfRequired()) {
394                         return false;
395                     }
396                     return next.flush();
397                 }
398             } finally {
399                 if (nextCreated) {
400                     if (anyAreSet(state, WRITES_RESUMED) && !anyAreSet(state ,NEXT_SHUTDOWN)) {
401                         try {
402                             next.resumeWrites();
403                         } catch (Throwable e) {
404                             UndertowLogger.REQUEST_LOGGER.debug("Failed to resume", e);
405                         }
406                     }
407                 }
408             }
409         } catch (IOException | RuntimeException | Error e) {
410             freeBuffer();
411             throw e;
412         }
413     }
414
415     /**
416      * called before the stream is finally flushed.
417      */

418     protected byte[] getTrailer() {
419         return null;
420     }
421
422     /**
423      * The we are in the flushing state then we flush to the underlying stream, otherwise just return true
424      *
425      * @return false if there is still more to flush
426      */

427     private boolean performFlushIfRequired() throws IOException {
428         if (anyAreSet(state, FLUSHING_BUFFER)) {
429             final ByteBuffer[] bufs = new ByteBuffer[additionalBuffer == null ? 1 : 2];
430             long totalLength = 0;
431             bufs[0] = currentBuffer.getBuffer();
432             totalLength += bufs[0].remaining();
433             if (additionalBuffer != null) {
434                 bufs[1] = additionalBuffer;
435                 totalLength += bufs[1].remaining();
436             }
437             if (totalLength > 0) {
438                 long total = 0;
439                 long res = 0;
440                 do {
441                     res = next.write(bufs, 0, bufs.length);
442                     total += res;
443                     if (res == 0) {
444                         return false;
445                     }
446                 } while (total < totalLength);
447             }
448             additionalBuffer = null;
449             currentBuffer.getBuffer().clear();
450             state = state & ~FLUSHING_BUFFER;
451         }
452         return true;
453     }
454
455
456     private StreamSinkConduit createNextChannel() {
457         if (deflater.finished() && allAreSet(state, WRITTEN_TRAILER)) {
458             //the deflater was fully flushed before we created the channel. This means that what is in the buffer is
459             //all there is
460             int remaining = currentBuffer.getBuffer().remaining();
461             if (additionalBuffer != null) {
462                 remaining += additionalBuffer.remaining();
463             }
464             if(!exchange.getResponseHeaders().contains(Headers.TRANSFER_ENCODING)) {
465                 exchange.getResponseHeaders().put(Headers.CONTENT_LENGTH, Integer.toString(remaining));
466             }
467         } else {
468             exchange.getResponseHeaders().remove(Headers.CONTENT_LENGTH);
469         }
470         return conduitFactory.create();
471     }
472
473     /**
474      * Runs the current data through the deflater. As much as possible this will be buffered in the current output
475      * stream.
476      *
477      * @throws IOException
478      */

479     private void deflateData(boolean force) throws IOException {
480         //we don't need to flush here, as this should have been called already by the time we get to
481         //this point
482         boolean nextCreated = false;
483         try (PooledByteBuffer arrayPooled = this.exchange.getConnection().getByteBufferPool().getArrayBackedPool().allocate()) {
484             PooledByteBuffer pooled = this.currentBuffer;
485             final ByteBuffer outputBuffer = pooled.getBuffer();
486
487             final boolean shutdown = anyAreSet(state, SHUTDOWN);
488             ByteBuffer buf = arrayPooled.getBuffer();
489             while (force || !deflater.needsInput() || (shutdown && !deflater.finished())) {
490                 int count = deflater.deflate(buf.array(), buf.arrayOffset(), buf.remaining(), force ? Deflater.SYNC_FLUSH: Deflater.NO_FLUSH);
491                 Connectors.updateResponseBytesSent(exchange, count);
492                 if (count != 0) {
493                     int remaining = outputBuffer.remaining();
494                     if (remaining > count) {
495                         outputBuffer.put(buf.array(), buf.arrayOffset(), count);
496                     } else {
497                         if (remaining == count) {
498                             outputBuffer.put(buf.array(), buf.arrayOffset(), count);
499                         } else {
500                             outputBuffer.put(buf.array(), buf.arrayOffset(), remaining);
501                             additionalBuffer = ByteBuffer.allocate(count - remaining);
502                             additionalBuffer.put(buf.array(), buf.arrayOffset() + remaining, count - remaining);
503                             additionalBuffer.flip();
504                         }
505                         outputBuffer.flip();
506                         this.state |= FLUSHING_BUFFER;
507                         if (next == null) {
508                             nextCreated = true;
509                             this.next = createNextChannel();
510                         }
511                         if (!performFlushIfRequired()) {
512                             return;
513                         }
514                     }
515                 } else {
516                     force = false;
517                 }
518             }
519         } finally {
520             if (nextCreated) {
521                 if (anyAreSet(state, WRITES_RESUMED)) {
522                     next.resumeWrites();
523                 }
524             }
525         }
526     }
527
528
529     @Override
530     public void truncateWrites() throws IOException {
531         freeBuffer();
532         state |= CLOSED;
533         next.truncateWrites();
534     }
535
536     private void freeBuffer() {
537         if (currentBuffer != null) {
538             currentBuffer.close();
539             currentBuffer = null;
540             state = state & ~FLUSHING_BUFFER;
541         }
542         if (deflater != null) {
543             deflater = null;
544             pooledObject.close();
545         }
546     }
547 }
548