1 /*
2  * Copyright 2013-2020 the original author or authors.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */

16
17 package org.springframework.cloud.aws.core.io.s3;
18
19 import java.lang.reflect.Field;
20 import java.net.URI;
21 import java.net.URISyntaxException;
22 import java.util.concurrent.ConcurrentHashMap;
23
24 import com.amazonaws.auth.AWSCredentialsProvider;
25 import com.amazonaws.regions.Regions;
26 import com.amazonaws.services.s3.AmazonS3;
27 import com.amazonaws.services.s3.AmazonS3Client;
28 import com.amazonaws.services.s3.AmazonS3ClientBuilder;
29 import com.amazonaws.services.s3.AmazonS3URI;
30
31 import org.springframework.aop.framework.Advised;
32 import org.springframework.aop.support.AopUtils;
33 import org.springframework.cloud.aws.core.SpringCloudClientConfiguration;
34 import org.springframework.util.Assert;
35 import org.springframework.util.ReflectionUtils;
36
37 /**
38  * {@link AmazonS3} client factory that create clients for other regions based on the
39  * source client and a endpoint url. Caches clients per region to enable re-use on a
40  * region base.
41  *
42  * @author Agim Emruli
43  * @author EddĂș MelĂ©ndez
44  * @since 1.2
45  */

46 public class AmazonS3ClientFactory {
47
48     private static final String CREDENTIALS_PROVIDER_FIELD_NAME = "awsCredentialsProvider";
49
50     private final ConcurrentHashMap<String, AmazonS3> clientCache = new ConcurrentHashMap<>(
51             Regions.values().length);
52
53     private final Field credentialsProviderField;
54
55     public AmazonS3ClientFactory() {
56         this.credentialsProviderField = ReflectionUtils.findField(AmazonS3Client.class,
57                 CREDENTIALS_PROVIDER_FIELD_NAME);
58         Assert.notNull(this.credentialsProviderField,
59                 "Credentials Provider field not found, this class does not work with the current "
60                         + "AWS SDK release");
61         ReflectionUtils.makeAccessible(this.credentialsProviderField);
62     }
63
64     private static String getRegion(String endpointUrl) {
65         Assert.notNull(endpointUrl, "Endpoint Url must not be null");
66         try {
67             URI uri = new URI(endpointUrl);
68             if ("s3.amazonaws.com".equals(uri.getHost())) {
69                 return Regions.DEFAULT_REGION.getName();
70             }
71             else {
72                 return new AmazonS3URI(endpointUrl).getRegion();
73             }
74         }
75         catch (URISyntaxException e) {
76             throw new RuntimeException("Malformed URL received for endpoint", e);
77         }
78     }
79
80     private static AmazonS3Client getAmazonS3ClientFromProxy(AmazonS3 amazonS3) {
81         if (AopUtils.isAopProxy(amazonS3)) {
82             Advised advised = (Advised) amazonS3;
83             Object target = null;
84             try {
85                 target = advised.getTargetSource().getTarget();
86             }
87             catch (Exception e) {
88                 return null;
89             }
90             return target instanceof AmazonS3Client ? (AmazonS3Client) target : null;
91         }
92         else {
93             return amazonS3 instanceof AmazonS3Client ? (AmazonS3Client) amazonS3 : null;
94         }
95     }
96
97     public AmazonS3 createClientForEndpointUrl(AmazonS3 prototype, String endpointUrl) {
98         return createClientForEndpointUrl(prototype, endpointUrl, null);
99     }
100
101     AmazonS3 createClientForEndpointUrl(AmazonS3 prototype, String endpointUrl,
102             Regions bucketRegion) {
103         Assert.notNull(prototype, "AmazonS3 must not be null");
104         Assert.notNull(endpointUrl, "Endpoint Url must not be null");
105
106         String region = bucketRegion != null ? bucketRegion.getName()
107                 : getRegion(endpointUrl);
108         Assert.notNull(region,
109                 "Error detecting region from endpoint url:'" + endpointUrl + "'");
110
111         if (!this.clientCache.containsKey(region)) {
112             AmazonS3ClientBuilder amazonS3ClientBuilder = buildAmazonS3ForRegion(
113                     prototype, region);
114             this.clientCache.putIfAbsent(region, amazonS3ClientBuilder.build());
115         }
116
117         return this.clientCache.get(region);
118     }
119
120     private AmazonS3ClientBuilder buildAmazonS3ForRegion(AmazonS3 prototype,
121             String region) {
122         AmazonS3ClientBuilder clientBuilder = AmazonS3ClientBuilder.standard()
123                 .withClientConfiguration(
124                         SpringCloudClientConfiguration.getClientConfiguration());
125
126         AmazonS3Client target = getAmazonS3ClientFromProxy(prototype);
127         if (target != null) {
128             AWSCredentialsProvider awsCredentialsProvider = (AWSCredentialsProvider) ReflectionUtils
129                     .getField(this.credentialsProviderField, target);
130             clientBuilder.withCredentials(awsCredentialsProvider);
131         }
132
133         return clientBuilder.withRegion(region);
134     }
135
136 }
137