1 /*
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */

15
16 package software.amazon.awssdk.http.apache.internal.conn;
17
18 import java.io.IOException;
19 import java.net.InetSocketAddress;
20 import java.net.Socket;
21 import java.util.ArrayList;
22 import java.util.Arrays;
23 import java.util.List;
24 import javax.net.ssl.HostnameVerifier;
25 import javax.net.ssl.SSLContext;
26 import javax.net.ssl.SSLSocket;
27 import org.apache.http.HttpHost;
28 import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
29 import org.apache.http.protocol.HttpContext;
30 import software.amazon.awssdk.annotations.SdkInternalApi;
31 import software.amazon.awssdk.http.apache.internal.net.SdkSocket;
32 import software.amazon.awssdk.http.apache.internal.net.SdkSslSocket;
33 import software.amazon.awssdk.utils.Logger;
34
35 /**
36  * Used to enforce the preferred TLS protocol during SSL handshake.
37  */

38 @SdkInternalApi
39 public class SdkTlsSocketFactory extends SSLConnectionSocketFactory {
40
41     private static final Logger log = Logger.loggerFor(SdkTlsSocketFactory.class);
42     private final SSLContext sslContext;
43
44     public SdkTlsSocketFactory(final SSLContext sslContext, final HostnameVerifier hostnameVerifier) {
45         super(sslContext, hostnameVerifier);
46         if (sslContext == null) {
47             throw new IllegalArgumentException(
48                     "sslContext must not be null. " + "Use SSLContext.getDefault() if you are unsure.");
49         }
50         this.sslContext = sslContext;
51     }
52
53     /**
54      * {@inheritDoc} Used to enforce the preferred TLS protocol during SSL handshake.
55      */

56     @Override
57     protected final void prepareSocket(final SSLSocket socket) {
58         String[] supported = socket.getSupportedProtocols();
59         String[] enabled = socket.getEnabledProtocols();
60         log.debug(() -> String.format("socket.getSupportedProtocols(): %s, socket.getEnabledProtocols(): %s",
61                                       Arrays.toString(supported),
62                                       Arrays.toString(enabled)));
63         List<String> target = new ArrayList<>();
64         if (supported != null) {
65             // Append the preferred protocols in descending order of preference
66             // but only do so if the protocols are supported
67             TlsProtocol[] values = TlsProtocol.values();
68             for (TlsProtocol value : values) {
69                 String pname = value.getProtocolName();
70                 if (existsIn(pname, supported)) {
71                     target.add(pname);
72                 }
73             }
74         }
75         if (enabled != null) {
76             // Append the rest of the already enabled protocols to the end
77             // if not already included in the list
78             for (String pname : enabled) {
79                 if (!target.contains(pname)) {
80                     target.add(pname);
81                 }
82             }
83         }
84         if (target.size() > 0) {
85             String[] enabling = target.toArray(new String[0]);
86             socket.setEnabledProtocols(enabling);
87             log.debug(() -> "TLS protocol enabled for SSL handshake: " + Arrays.toString(enabling));
88         }
89     }
90
91     /**
92      * Returns true if the given element exists in the given array; false otherwise.
93      */

94     private boolean existsIn(String element, String[] a) {
95         for (String s : a) {
96             if (element.equals(s)) {
97                 return true;
98             }
99         }
100         return false;
101     }
102
103     @Override
104     public Socket connectSocket(
105             final int connectTimeout,
106             final Socket socket,
107             final HttpHost host,
108             final InetSocketAddress remoteAddress,
109             final InetSocketAddress localAddress,
110             final HttpContext context) throws IOException {
111         log.trace(() -> String.format("Connecting to %s:%s", remoteAddress.getAddress(), remoteAddress.getPort()));
112
113         Socket connectedSocket = super.connectSocket(connectTimeout, socket, host, remoteAddress, localAddress, context);
114
115         if (connectedSocket instanceof SSLSocket) {
116             return new SdkSslSocket((SSLSocket) connectedSocket);
117         }
118
119         return new SdkSocket(connectedSocket);
120     }
121
122 }
123