1
18
19 package io.undertow.conduits;
20
21 import static org.xnio.Bits.allAreClear;
22 import static org.xnio.Bits.allAreSet;
23 import static org.xnio.Bits.anyAreSet;
24
25 import java.io.IOException;
26 import java.nio.ByteBuffer;
27 import java.nio.channels.ClosedChannelException;
28 import java.nio.channels.FileChannel;
29 import java.util.concurrent.TimeUnit;
30 import java.util.zip.Deflater;
31
32 import io.undertow.server.Connectors;
33 import org.xnio.IoUtils;
34 import io.undertow.connector.PooledByteBuffer;
35 import org.xnio.XnioIoThread;
36 import org.xnio.XnioWorker;
37 import org.xnio.channels.StreamSourceChannel;
38 import org.xnio.conduits.ConduitWritableByteChannel;
39 import org.xnio.conduits.Conduits;
40 import org.xnio.conduits.StreamSinkConduit;
41 import org.xnio.conduits.WriteReadyHandler;
42
43 import io.undertow.UndertowLogger;
44 import io.undertow.server.HttpServerExchange;
45 import io.undertow.util.ConduitFactory;
46 import io.undertow.util.NewInstanceObjectPool;
47 import io.undertow.util.ObjectPool;
48 import io.undertow.util.Headers;
49 import io.undertow.util.PooledObject;
50 import io.undertow.util.SimpleObjectPool;
51
52
57 public class DeflatingStreamSinkConduit implements StreamSinkConduit {
58
59 protected volatile Deflater deflater;
60
61 protected final PooledObject<Deflater> pooledObject;
62 private final ConduitFactory<StreamSinkConduit> conduitFactory;
63 private final HttpServerExchange exchange;
64
65 private StreamSinkConduit next;
66 private WriteReadyHandler writeReadyHandler;
67
68
69
72 protected PooledByteBuffer currentBuffer;
73
76 private ByteBuffer additionalBuffer;
77
78 private int state = 0;
79
80 private static final int SHUTDOWN = 1;
81 private static final int NEXT_SHUTDOWN = 1 << 1;
82 private static final int FLUSHING_BUFFER = 1 << 2;
83 private static final int WRITES_RESUMED = 1 << 3;
84 private static final int CLOSED = 1 << 4;
85 private static final int WRITTEN_TRAILER = 1 << 5;
86
87 public DeflatingStreamSinkConduit(final ConduitFactory<StreamSinkConduit> conduitFactory, final HttpServerExchange exchange) {
88 this(conduitFactory, exchange, Deflater.DEFLATED);
89 }
90
91 public DeflatingStreamSinkConduit(final ConduitFactory<StreamSinkConduit> conduitFactory, final HttpServerExchange exchange, int deflateLevel) {
92 this(conduitFactory, exchange, newInstanceDeflaterPool(deflateLevel));
93 }
94
95 public DeflatingStreamSinkConduit(final ConduitFactory<StreamSinkConduit> conduitFactory, final HttpServerExchange exchange, ObjectPool<Deflater> deflaterPool) {
96 this.pooledObject = deflaterPool.allocate();
97 this.deflater = pooledObject.getObject();
98 this.currentBuffer = exchange.getConnection().getByteBufferPool().allocate();
99 this.exchange = exchange;
100 this.conduitFactory = conduitFactory;
101 setWriteReadyHandler(new WriteReadyHandler.ChannelListenerHandler<>(Connectors.getConduitSinkChannel(exchange)));
102 }
103
104 public static ObjectPool<Deflater> newInstanceDeflaterPool(int deflateLevel) {
105 return new NewInstanceObjectPool<Deflater>(() -> new Deflater(deflateLevel, true), Deflater::end);
106 }
107
108 public static ObjectPool<Deflater> simpleDeflaterPool(int poolSize, int deflateLevel) {
109 return new SimpleObjectPool<Deflater>(poolSize, () -> new Deflater(deflateLevel, true), Deflater::reset, Deflater::end);
110 }
111
112
113 @Override
114 public int write(final ByteBuffer src) throws IOException {
115 if (anyAreSet(state, SHUTDOWN | CLOSED) || currentBuffer == null) {
116 throw new ClosedChannelException();
117 }
118 try {
119 if (!performFlushIfRequired()) {
120 return 0;
121 }
122 if (src.remaining() == 0) {
123 return 0;
124 }
125
126 if (!deflater.needsInput()) {
127 deflateData(false);
128 if (!deflater.needsInput()) {
129 return 0;
130 }
131 }
132 byte[] data = new byte[src.remaining()];
133 src.get(data);
134 preDeflate(data);
135 deflater.setInput(data);
136 Connectors.updateResponseBytesSent(exchange, 0 - data.length);
137 deflateData(false);
138 return data.length;
139 } catch (IOException | RuntimeException | Error e) {
140 freeBuffer();
141 throw e;
142 }
143 }
144
145 protected void preDeflate(byte[] data) {
146
147 }
148
149 @Override
150 public long write(final ByteBuffer[] srcs, final int offset, final int length) throws IOException {
151 if (anyAreSet(state, SHUTDOWN | CLOSED) || currentBuffer == null) {
152 throw new ClosedChannelException();
153 }
154 try {
155 int total = 0;
156 for (int i = offset; i < offset + length; ++i) {
157 if (srcs[i].hasRemaining()) {
158 int ret = write(srcs[i]);
159 total += ret;
160 if (ret == 0) {
161 return total;
162 }
163 }
164 }
165 return total;
166 } catch (IOException | RuntimeException | Error e) {
167 freeBuffer();
168 throw e;
169 }
170 }
171
172 @Override
173 public int writeFinal(ByteBuffer src) throws IOException {
174 return Conduits.writeFinalBasic(this, src);
175 }
176
177 @Override
178 public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException {
179 return Conduits.writeFinalBasic(this, srcs, offset, length);
180 }
181
182 @Override
183 public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
184 if (anyAreSet(state, SHUTDOWN | CLOSED)) {
185 throw new ClosedChannelException();
186 }
187 if (!performFlushIfRequired()) {
188 return 0;
189 }
190 return src.transferTo(position, count, new ConduitWritableByteChannel(this));
191 }
192
193
194 @Override
195 public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
196 if (anyAreSet(state, SHUTDOWN | CLOSED)) {
197 throw new ClosedChannelException();
198 }
199 if (!performFlushIfRequired()) {
200 return 0;
201 }
202 return IoUtils.transfer(source, count, throughBuffer, new ConduitWritableByteChannel(this));
203 }
204
205 @Override
206 public XnioWorker getWorker() {
207 return exchange.getConnection().getWorker();
208 }
209
210 @Override
211 public void suspendWrites() {
212 if (next == null) {
213 state = state & ~WRITES_RESUMED;
214 } else {
215 next.suspendWrites();
216 }
217 }
218
219
220 @Override
221 public boolean isWriteResumed() {
222 if (next == null) {
223 return anyAreSet(state, WRITES_RESUMED);
224 } else {
225 return next.isWriteResumed();
226 }
227 }
228
229 @Override
230 public void wakeupWrites() {
231 if (next == null) {
232 resumeWrites();
233 } else {
234 next.wakeupWrites();
235 }
236 }
237
238 @Override
239 public void resumeWrites() {
240 if (next == null) {
241 state |= WRITES_RESUMED;
242 queueWriteListener();
243 } else {
244 next.resumeWrites();
245 }
246 }
247
248 private void queueWriteListener() {
249 exchange.getConnection().getIoThread().execute(new Runnable() {
250 @Override
251 public void run() {
252 if (writeReadyHandler != null) {
253 try {
254 writeReadyHandler.writeReady();
255 } finally {
256
257 if (next == null && isWriteResumed()) {
258 queueWriteListener();
259 }
260 }
261 }
262 }
263 });
264 }
265
266 @Override
267 public void terminateWrites() throws IOException {
268 if (deflater != null) {
269 deflater.finish();
270 }
271 state |= SHUTDOWN;
272 }
273
274 @Override
275 public boolean isWriteShutdown() {
276 return anyAreSet(state, SHUTDOWN);
277 }
278
279 @Override
280 public void awaitWritable() throws IOException {
281 if (next == null) {
282 return;
283 } else {
284 next.awaitWritable();
285 }
286 }
287
288 @Override
289 public void awaitWritable(final long time, final TimeUnit timeUnit) throws IOException {
290 if (next == null) {
291 return;
292 } else {
293 next.awaitWritable(time, timeUnit);
294 }
295 }
296
297 @Override
298 public XnioIoThread getWriteThread() {
299 return exchange.getConnection().getIoThread();
300 }
301
302 @Override
303 public void setWriteReadyHandler(final WriteReadyHandler handler) {
304 this.writeReadyHandler = handler;
305 }
306
307 @Override
308 public boolean flush() throws IOException {
309 if (currentBuffer == null) {
310 if (anyAreSet(state, NEXT_SHUTDOWN)) {
311 return next.flush();
312 } else {
313 return true;
314 }
315 }
316 try {
317 boolean nextCreated = false;
318 try {
319 if (anyAreSet(state, SHUTDOWN)) {
320 if (anyAreSet(state, NEXT_SHUTDOWN)) {
321 return next.flush();
322 } else {
323 if (!performFlushIfRequired()) {
324 return false;
325 }
326
327 if (!deflater.finished()) {
328 deflateData(false);
329
330 if (!deflater.finished()) {
331 return false;
332 }
333 }
334 final ByteBuffer buffer = currentBuffer.getBuffer();
335 if (allAreClear(state, WRITTEN_TRAILER)) {
336 state |= WRITTEN_TRAILER;
337 byte[] data = getTrailer();
338 if (data != null) {
339 Connectors.updateResponseBytesSent(exchange, data.length);
340 if(additionalBuffer != null) {
341 byte[] newData = new byte[additionalBuffer.remaining() + data.length];
342 int pos = 0;
343 while (additionalBuffer.hasRemaining()) {
344 newData[pos++] = additionalBuffer.get();
345 }
346 for (byte aData : data) {
347 newData[pos++] = aData;
348 }
349 this.additionalBuffer = ByteBuffer.wrap(newData);
350 } else if(anyAreSet(state, FLUSHING_BUFFER) && buffer.capacity() - buffer.remaining() >= data.length) {
351 buffer.compact();
352 buffer.put(data);
353 buffer.flip();
354 } else if (data.length <= buffer.remaining() && !anyAreSet(state, FLUSHING_BUFFER)) {
355 buffer.put(data);
356 } else {
357 additionalBuffer = ByteBuffer.wrap(data);
358 }
359 }
360 }
361
362
363 if (!anyAreSet(state, FLUSHING_BUFFER)) {
364 buffer.flip();
365 state |= FLUSHING_BUFFER;
366 if (next == null) {
367 nextCreated = true;
368 this.next = createNextChannel();
369 }
370 }
371 if (performFlushIfRequired()) {
372 state |= NEXT_SHUTDOWN;
373 freeBuffer();
374 next.terminateWrites();
375 return next.flush();
376 } else {
377 return false;
378 }
379 }
380 } else {
381 if(allAreClear(state, FLUSHING_BUFFER)) {
382 if (next == null) {
383 nextCreated = true;
384 this.next = createNextChannel();
385 }
386 deflateData(true);
387 if(allAreClear(state, FLUSHING_BUFFER)) {
388
389 currentBuffer.getBuffer().flip();
390 this.state |= FLUSHING_BUFFER;
391 }
392 }
393 if(!performFlushIfRequired()) {
394 return false;
395 }
396 return next.flush();
397 }
398 } finally {
399 if (nextCreated) {
400 if (anyAreSet(state, WRITES_RESUMED) && !anyAreSet(state ,NEXT_SHUTDOWN)) {
401 try {
402 next.resumeWrites();
403 } catch (Throwable e) {
404 UndertowLogger.REQUEST_LOGGER.debug("Failed to resume", e);
405 }
406 }
407 }
408 }
409 } catch (IOException | RuntimeException | Error e) {
410 freeBuffer();
411 throw e;
412 }
413 }
414
415
418 protected byte[] getTrailer() {
419 return null;
420 }
421
422
427 private boolean performFlushIfRequired() throws IOException {
428 if (anyAreSet(state, FLUSHING_BUFFER)) {
429 final ByteBuffer[] bufs = new ByteBuffer[additionalBuffer == null ? 1 : 2];
430 long totalLength = 0;
431 bufs[0] = currentBuffer.getBuffer();
432 totalLength += bufs[0].remaining();
433 if (additionalBuffer != null) {
434 bufs[1] = additionalBuffer;
435 totalLength += bufs[1].remaining();
436 }
437 if (totalLength > 0) {
438 long total = 0;
439 long res = 0;
440 do {
441 res = next.write(bufs, 0, bufs.length);
442 total += res;
443 if (res == 0) {
444 return false;
445 }
446 } while (total < totalLength);
447 }
448 additionalBuffer = null;
449 currentBuffer.getBuffer().clear();
450 state = state & ~FLUSHING_BUFFER;
451 }
452 return true;
453 }
454
455
456 private StreamSinkConduit createNextChannel() {
457 if (deflater.finished() && allAreSet(state, WRITTEN_TRAILER)) {
458
459
460 int remaining = currentBuffer.getBuffer().remaining();
461 if (additionalBuffer != null) {
462 remaining += additionalBuffer.remaining();
463 }
464 if(!exchange.getResponseHeaders().contains(Headers.TRANSFER_ENCODING)) {
465 exchange.getResponseHeaders().put(Headers.CONTENT_LENGTH, Integer.toString(remaining));
466 }
467 } else {
468 exchange.getResponseHeaders().remove(Headers.CONTENT_LENGTH);
469 }
470 return conduitFactory.create();
471 }
472
473
479 private void deflateData(boolean force) throws IOException {
480
481
482 boolean nextCreated = false;
483 try (PooledByteBuffer arrayPooled = this.exchange.getConnection().getByteBufferPool().getArrayBackedPool().allocate()) {
484 PooledByteBuffer pooled = this.currentBuffer;
485 final ByteBuffer outputBuffer = pooled.getBuffer();
486
487 final boolean shutdown = anyAreSet(state, SHUTDOWN);
488 ByteBuffer buf = arrayPooled.getBuffer();
489 while (force || !deflater.needsInput() || (shutdown && !deflater.finished())) {
490 int count = deflater.deflate(buf.array(), buf.arrayOffset(), buf.remaining(), force ? Deflater.SYNC_FLUSH: Deflater.NO_FLUSH);
491 Connectors.updateResponseBytesSent(exchange, count);
492 if (count != 0) {
493 int remaining = outputBuffer.remaining();
494 if (remaining > count) {
495 outputBuffer.put(buf.array(), buf.arrayOffset(), count);
496 } else {
497 if (remaining == count) {
498 outputBuffer.put(buf.array(), buf.arrayOffset(), count);
499 } else {
500 outputBuffer.put(buf.array(), buf.arrayOffset(), remaining);
501 additionalBuffer = ByteBuffer.allocate(count - remaining);
502 additionalBuffer.put(buf.array(), buf.arrayOffset() + remaining, count - remaining);
503 additionalBuffer.flip();
504 }
505 outputBuffer.flip();
506 this.state |= FLUSHING_BUFFER;
507 if (next == null) {
508 nextCreated = true;
509 this.next = createNextChannel();
510 }
511 if (!performFlushIfRequired()) {
512 return;
513 }
514 }
515 } else {
516 force = false;
517 }
518 }
519 } finally {
520 if (nextCreated) {
521 if (anyAreSet(state, WRITES_RESUMED)) {
522 next.resumeWrites();
523 }
524 }
525 }
526 }
527
528
529 @Override
530 public void truncateWrites() throws IOException {
531 freeBuffer();
532 state |= CLOSED;
533 next.truncateWrites();
534 }
535
536 private void freeBuffer() {
537 if (currentBuffer != null) {
538 currentBuffer.close();
539 currentBuffer = null;
540 state = state & ~FLUSHING_BUFFER;
541 }
542 if (deflater != null) {
543 deflater = null;
544 pooledObject.close();
545 }
546 }
547 }
548