1 /*
2  * Copyright 2014-2020 Amazon Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at:
7  *
8  *    http://aws.amazon.com/apache2.0
9  *
10  * This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
11  * OR CONDITIONS OF ANY KIND, either express or implied. See the
12  * License for the specific language governing permissions and
13  * limitations under the License.
14  */

15 package com.amazonaws.http.conn.ssl;
16
17 import com.amazonaws.annotation.ThreadSafe;
18 import com.amazonaws.http.apache.utils.HttpContextUtils;
19 import com.amazonaws.internal.SdkMetricsSocket;
20 import com.amazonaws.internal.SdkSSLMetricsSocket;
21 import com.amazonaws.internal.SdkSSLSocket;
22 import com.amazonaws.internal.SdkSocket;
23 import com.amazonaws.metrics.AwsSdkMetrics;
24 import com.amazonaws.util.JavaVersionParser;
25 import org.apache.commons.logging.Log;
26 import org.apache.commons.logging.LogFactory;
27 import org.apache.http.HttpHost;
28 import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
29 import org.apache.http.protocol.HttpContext;
30
31 import javax.net.ssl.HostnameVerifier;
32 import javax.net.ssl.SSLContext;
33 import javax.net.ssl.SSLException;
34 import javax.net.ssl.SSLSession;
35 import javax.net.ssl.SSLSessionContext;
36 import javax.net.ssl.SSLSocket;
37 import java.io.IOException;
38 import java.net.InetSocketAddress;
39 import java.net.Proxy;
40 import java.net.Socket;
41 import java.util.ArrayList;
42 import java.util.Arrays;
43 import java.util.Enumeration;
44 import java.util.List;
45
46 /**
47  * Used to enforce the preferred TLS protocol during SSL handshake.
48  */

49 @ThreadSafe
50 public class SdkTLSSocketFactory extends SSLConnectionSocketFactory {
51
52     private static final Log LOG = LogFactory.getLog(SdkTLSSocketFactory.class);
53     private final SSLContext sslContext;
54     private final MasterSecretValidators.MasterSecretValidator masterSecretValidator;
55     private final ShouldClearSslSessionPredicate shouldClearSslSessionsPredicate;
56
57     public SdkTLSSocketFactory(final SSLContext sslContext, final HostnameVerifier hostnameVerifier) {
58         super(sslContext, hostnameVerifier);
59         if (sslContext == null) {
60             throw new IllegalArgumentException(
61                     "sslContext must not be null. " + "Use SSLContext.getDefault() if you are unsure.");
62         }
63         this.sslContext = sslContext;
64         this.masterSecretValidator = MasterSecretValidators.getMasterSecretValidator();
65         this.shouldClearSslSessionsPredicate = new ShouldClearSslSessionPredicate(JavaVersionParser.getCurrentJavaVersion());
66     }
67
68     @Override
69     public Socket createSocket(HttpContext ctx) throws IOException {
70         if (HttpContextUtils.disableSocketProxy(ctx)) {
71             return new Socket(Proxy.NO_PROXY);
72         }
73         return super.createSocket(ctx);
74     }
75
76     /**
77      * {@inheritDoc} Used to enforce the preferred TLS protocol during SSL handshake.
78      */

79     @Override
80     protected final void prepareSocket(final SSLSocket socket) {
81         String[] supported = socket.getSupportedProtocols();
82         String[] enabled = socket.getEnabledProtocols();
83         if (LOG.isDebugEnabled()) {
84             LOG.debug("socket.getSupportedProtocols(): " + Arrays.toString(supported)
85                     + ", socket.getEnabledProtocols(): " + Arrays.toString(enabled));
86         }
87         List<String> target = new ArrayList<String>();
88         if (supported != null) {
89             // Append the preferred protocols in descending order of preference
90             // but only do so if the protocols are supported
91             TLSProtocol[] values = TLSProtocol.values();
92             for (int i = 0; i < values.length; i++) {
93                 final String pname = values[i].getProtocolName();
94                 if (existsIn(pname, supported)) {
95                     target.add(pname);
96                 }
97             }
98         }
99         if (enabled != null) {
100             // Append the rest of the already enabled protocols to the end
101             // if not already included in the list
102             for (String pname : enabled) {
103                 if (!target.contains(pname)) {
104                     target.add(pname);
105                 }
106             }
107         }
108         if (target.size() > 0) {
109             String[] enabling = target.toArray(new String[target.size()]);
110             socket.setEnabledProtocols(enabling);
111             if (LOG.isDebugEnabled()) {
112                 LOG.debug("TLS protocol enabled for SSL handshake: " + Arrays.toString(enabling));
113             }
114         }
115     }
116
117     /**
118      * Returns true if the given element exists in the given array; false otherwise.
119      */

120     private boolean existsIn(String element, String[] a) {
121         for (String s : a) {
122             if (element.equals(s)) {
123                 return true;
124             }
125         }
126         return false;
127     }
128
129     public Socket connectSocket(
130             final int connectTimeout,
131             final Socket socket,
132             final HttpHost host,
133             final InetSocketAddress remoteAddress,
134             final InetSocketAddress localAddress,
135             final HttpContext context) throws IOException {
136         if (LOG.isDebugEnabled()) {
137             LOG.debug("connecting to " + remoteAddress.getAddress() + ":" + remoteAddress.getPort());
138         }
139         Socket connectedSocket;
140         try {
141             connectedSocket = super.connectSocket
142                     (connectTimeout, socket, host, remoteAddress, localAddress, context);
143             if (!masterSecretValidator.isMasterSecretValid(connectedSocket)) {
144                 throw log(new IllegalStateException("Invalid SSL master secret"));
145             }
146         } catch (final SSLException sslEx) {
147             if (shouldClearSslSessionsPredicate.test(sslEx)) {
148                 // clear any related sessions from our cache
149                 if (LOG.isDebugEnabled()) {
150                     LOG.debug("connection failed due to SSL error, clearing TLS session cache", sslEx);
151                 }
152                 clearSessionCache(sslContext.getClientSessionContext(), remoteAddress);
153             }
154             throw sslEx;
155         }
156
157         if (connectedSocket instanceof SSLSocket) {
158             SdkSSLSocket sslSocket = new SdkSSLSocket((SSLSocket) connectedSocket);
159             return AwsSdkMetrics.isHttpSocketReadMetricEnabled() ? new SdkSSLMetricsSocket(sslSocket) : sslSocket;
160         }
161         SdkSocket sdkSocket = new SdkSocket(connectedSocket);
162         return AwsSdkMetrics.isHttpSocketReadMetricEnabled() ? new SdkMetricsSocket(sdkSocket) : sdkSocket;
163     }
164
165     /**
166      * Invalidates all SSL/TLS sessions in {@code sessionContext} associated with {@code remoteAddress}.
167      *
168      * @param sessionContext collection of SSL/TLS sessions to be (potentially) invalidated
169      * @param remoteAddress  associated with sessions to invalidate
170      */

171     private void clearSessionCache(final SSLSessionContext sessionContext, final InetSocketAddress remoteAddress) {
172         final String hostName = remoteAddress.getHostName();
173         final int port = remoteAddress.getPort();
174         final Enumeration<byte[]> ids = sessionContext.getIds();
175
176         if (ids == null) {
177             return;
178         }
179
180         while (ids.hasMoreElements()) {
181             final byte[] id = ids.nextElement();
182             final SSLSession session = sessionContext.getSession(id);
183             if (session != null && session.getPeerHost() != null && session.getPeerHost().equalsIgnoreCase(hostName)
184                     && session.getPeerPort() == port) {
185                 session.invalidate();
186                 if (LOG.isDebugEnabled()) {
187                     LOG.debug("Invalidated session " + session);
188                 }
189             }
190         }
191     }
192
193     private <T extends Throwable> T log(T t) {
194         if (LOG.isDebugEnabled()) {
195             LOG.debug("", t);
196         }
197         return t;
198     }
199 }
200