1
15 package com.amazonaws.http.conn.ssl;
16
17 import com.amazonaws.annotation.ThreadSafe;
18 import com.amazonaws.http.apache.utils.HttpContextUtils;
19 import com.amazonaws.internal.SdkMetricsSocket;
20 import com.amazonaws.internal.SdkSSLMetricsSocket;
21 import com.amazonaws.internal.SdkSSLSocket;
22 import com.amazonaws.internal.SdkSocket;
23 import com.amazonaws.metrics.AwsSdkMetrics;
24 import com.amazonaws.util.JavaVersionParser;
25 import org.apache.commons.logging.Log;
26 import org.apache.commons.logging.LogFactory;
27 import org.apache.http.HttpHost;
28 import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
29 import org.apache.http.protocol.HttpContext;
30
31 import javax.net.ssl.HostnameVerifier;
32 import javax.net.ssl.SSLContext;
33 import javax.net.ssl.SSLException;
34 import javax.net.ssl.SSLSession;
35 import javax.net.ssl.SSLSessionContext;
36 import javax.net.ssl.SSLSocket;
37 import java.io.IOException;
38 import java.net.InetSocketAddress;
39 import java.net.Proxy;
40 import java.net.Socket;
41 import java.util.ArrayList;
42 import java.util.Arrays;
43 import java.util.Enumeration;
44 import java.util.List;
45
46
49 @ThreadSafe
50 public class SdkTLSSocketFactory extends SSLConnectionSocketFactory {
51
52 private static final Log LOG = LogFactory.getLog(SdkTLSSocketFactory.class);
53 private final SSLContext sslContext;
54 private final MasterSecretValidators.MasterSecretValidator masterSecretValidator;
55 private final ShouldClearSslSessionPredicate shouldClearSslSessionsPredicate;
56
57 public SdkTLSSocketFactory(final SSLContext sslContext, final HostnameVerifier hostnameVerifier) {
58 super(sslContext, hostnameVerifier);
59 if (sslContext == null) {
60 throw new IllegalArgumentException(
61 "sslContext must not be null. " + "Use SSLContext.getDefault() if you are unsure.");
62 }
63 this.sslContext = sslContext;
64 this.masterSecretValidator = MasterSecretValidators.getMasterSecretValidator();
65 this.shouldClearSslSessionsPredicate = new ShouldClearSslSessionPredicate(JavaVersionParser.getCurrentJavaVersion());
66 }
67
68 @Override
69 public Socket createSocket(HttpContext ctx) throws IOException {
70 if (HttpContextUtils.disableSocketProxy(ctx)) {
71 return new Socket(Proxy.NO_PROXY);
72 }
73 return super.createSocket(ctx);
74 }
75
76
79 @Override
80 protected final void prepareSocket(final SSLSocket socket) {
81 String[] supported = socket.getSupportedProtocols();
82 String[] enabled = socket.getEnabledProtocols();
83 if (LOG.isDebugEnabled()) {
84 LOG.debug("socket.getSupportedProtocols(): " + Arrays.toString(supported)
85 + ", socket.getEnabledProtocols(): " + Arrays.toString(enabled));
86 }
87 List<String> target = new ArrayList<String>();
88 if (supported != null) {
89
90
91 TLSProtocol[] values = TLSProtocol.values();
92 for (int i = 0; i < values.length; i++) {
93 final String pname = values[i].getProtocolName();
94 if (existsIn(pname, supported)) {
95 target.add(pname);
96 }
97 }
98 }
99 if (enabled != null) {
100
101
102 for (String pname : enabled) {
103 if (!target.contains(pname)) {
104 target.add(pname);
105 }
106 }
107 }
108 if (target.size() > 0) {
109 String[] enabling = target.toArray(new String[target.size()]);
110 socket.setEnabledProtocols(enabling);
111 if (LOG.isDebugEnabled()) {
112 LOG.debug("TLS protocol enabled for SSL handshake: " + Arrays.toString(enabling));
113 }
114 }
115 }
116
117
120 private boolean existsIn(String element, String[] a) {
121 for (String s : a) {
122 if (element.equals(s)) {
123 return true;
124 }
125 }
126 return false;
127 }
128
129 public Socket connectSocket(
130 final int connectTimeout,
131 final Socket socket,
132 final HttpHost host,
133 final InetSocketAddress remoteAddress,
134 final InetSocketAddress localAddress,
135 final HttpContext context) throws IOException {
136 if (LOG.isDebugEnabled()) {
137 LOG.debug("connecting to " + remoteAddress.getAddress() + ":" + remoteAddress.getPort());
138 }
139 Socket connectedSocket;
140 try {
141 connectedSocket = super.connectSocket
142 (connectTimeout, socket, host, remoteAddress, localAddress, context);
143 if (!masterSecretValidator.isMasterSecretValid(connectedSocket)) {
144 throw log(new IllegalStateException("Invalid SSL master secret"));
145 }
146 } catch (final SSLException sslEx) {
147 if (shouldClearSslSessionsPredicate.test(sslEx)) {
148
149 if (LOG.isDebugEnabled()) {
150 LOG.debug("connection failed due to SSL error, clearing TLS session cache", sslEx);
151 }
152 clearSessionCache(sslContext.getClientSessionContext(), remoteAddress);
153 }
154 throw sslEx;
155 }
156
157 if (connectedSocket instanceof SSLSocket) {
158 SdkSSLSocket sslSocket = new SdkSSLSocket((SSLSocket) connectedSocket);
159 return AwsSdkMetrics.isHttpSocketReadMetricEnabled() ? new SdkSSLMetricsSocket(sslSocket) : sslSocket;
160 }
161 SdkSocket sdkSocket = new SdkSocket(connectedSocket);
162 return AwsSdkMetrics.isHttpSocketReadMetricEnabled() ? new SdkMetricsSocket(sdkSocket) : sdkSocket;
163 }
164
165
171 private void clearSessionCache(final SSLSessionContext sessionContext, final InetSocketAddress remoteAddress) {
172 final String hostName = remoteAddress.getHostName();
173 final int port = remoteAddress.getPort();
174 final Enumeration<byte[]> ids = sessionContext.getIds();
175
176 if (ids == null) {
177 return;
178 }
179
180 while (ids.hasMoreElements()) {
181 final byte[] id = ids.nextElement();
182 final SSLSession session = sessionContext.getSession(id);
183 if (session != null && session.getPeerHost() != null && session.getPeerHost().equalsIgnoreCase(hostName)
184 && session.getPeerPort() == port) {
185 session.invalidate();
186 if (LOG.isDebugEnabled()) {
187 LOG.debug("Invalidated session " + session);
188 }
189 }
190 }
191 }
192
193 private <T extends Throwable> T log(T t) {
194 if (LOG.isDebugEnabled()) {
195 LOG.debug("", t);
196 }
197 return t;
198 }
199 }
200