1 /*
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */

15
16 package software.amazon.awssdk.core.internal.http;
17
18 import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely;
19
20 import java.io.InputStream;
21 import software.amazon.awssdk.annotations.SdkInternalApi;
22 import software.amazon.awssdk.core.Response;
23 import software.amazon.awssdk.core.internal.http.pipeline.RequestPipeline;
24 import software.amazon.awssdk.core.io.ReleasableInputStream;
25 import software.amazon.awssdk.http.ContentStreamProvider;
26 import software.amazon.awssdk.http.SdkHttpFullRequest;
27 import software.amazon.awssdk.utils.Logger;
28
29 /**
30  * Adds additional wrapping around the request {@link ContentStreamProvider}.
31  * <p>
32  * Currently, it ensures that the stream returned by the provider is not closeable.
33  *
34  * @param <OutputT> Type of unmarshalled response
35  */

36 @SdkInternalApi
37 public final class StreamManagingStage<OutputT> implements RequestPipeline<SdkHttpFullRequest, Response<OutputT>> {
38
39     private static final Logger log = Logger.loggerFor(StreamManagingStage.class);
40
41     private final RequestPipeline<SdkHttpFullRequest, Response<OutputT>> wrapped;
42
43     public StreamManagingStage(RequestPipeline<SdkHttpFullRequest, Response<OutputT>> wrapped) {
44         this.wrapped = wrapped;
45     }
46
47     @Override
48     public Response<OutputT> execute(SdkHttpFullRequest request, RequestExecutionContext context) throws Exception {
49         ClosingStreamProvider toBeClosed = null;
50         if (request.contentStreamProvider().isPresent()) {
51             toBeClosed = createManagedProvider(request.contentStreamProvider().get());
52             request = request.toBuilder().contentStreamProvider(toBeClosed).build();
53         }
54         try {
55             InterruptMonitor.checkInterrupted();
56             return wrapped.execute(request, context);
57         } finally {
58             if (toBeClosed != null) {
59                 toBeClosed.closeCurrentStream();
60             }
61         }
62     }
63
64     private static ClosingStreamProvider createManagedProvider(ContentStreamProvider contentStreamProvider) {
65         return new ClosingStreamProvider(contentStreamProvider);
66     }
67
68     private static class ClosingStreamProvider implements ContentStreamProvider {
69         private final ContentStreamProvider wrapped;
70         private InputStream currentStream;
71
72         ClosingStreamProvider(ContentStreamProvider wrapped) {
73             this.wrapped = wrapped;
74         }
75
76         @Override
77         public InputStream newStream() {
78             currentStream = wrapped.newStream();
79             return ReleasableInputStream.wrap(currentStream).disableClose();
80         }
81
82         void closeCurrentStream() {
83             if (currentStream != null) {
84                 invokeSafely(currentStream::close);
85                 currentStream = null;
86             }
87         }
88     }
89 }
90