1
15
16 package software.amazon.awssdk.http.apache.internal.conn;
17
18 import java.io.IOException;
19 import java.net.InetSocketAddress;
20 import java.net.Socket;
21 import java.util.ArrayList;
22 import java.util.Arrays;
23 import java.util.List;
24 import javax.net.ssl.HostnameVerifier;
25 import javax.net.ssl.SSLContext;
26 import javax.net.ssl.SSLSocket;
27 import org.apache.http.HttpHost;
28 import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
29 import org.apache.http.protocol.HttpContext;
30 import software.amazon.awssdk.annotations.SdkInternalApi;
31 import software.amazon.awssdk.http.apache.internal.net.SdkSocket;
32 import software.amazon.awssdk.http.apache.internal.net.SdkSslSocket;
33 import software.amazon.awssdk.utils.Logger;
34
35
38 @SdkInternalApi
39 public class SdkTlsSocketFactory extends SSLConnectionSocketFactory {
40
41 private static final Logger log = Logger.loggerFor(SdkTlsSocketFactory.class);
42 private final SSLContext sslContext;
43
44 public SdkTlsSocketFactory(final SSLContext sslContext, final HostnameVerifier hostnameVerifier) {
45 super(sslContext, hostnameVerifier);
46 if (sslContext == null) {
47 throw new IllegalArgumentException(
48 "sslContext must not be null. " + "Use SSLContext.getDefault() if you are unsure.");
49 }
50 this.sslContext = sslContext;
51 }
52
53
56 @Override
57 protected final void prepareSocket(final SSLSocket socket) {
58 String[] supported = socket.getSupportedProtocols();
59 String[] enabled = socket.getEnabledProtocols();
60 log.debug(() -> String.format("socket.getSupportedProtocols(): %s, socket.getEnabledProtocols(): %s",
61 Arrays.toString(supported),
62 Arrays.toString(enabled)));
63 List<String> target = new ArrayList<>();
64 if (supported != null) {
65
66
67 TlsProtocol[] values = TlsProtocol.values();
68 for (TlsProtocol value : values) {
69 String pname = value.getProtocolName();
70 if (existsIn(pname, supported)) {
71 target.add(pname);
72 }
73 }
74 }
75 if (enabled != null) {
76
77
78 for (String pname : enabled) {
79 if (!target.contains(pname)) {
80 target.add(pname);
81 }
82 }
83 }
84 if (target.size() > 0) {
85 String[] enabling = target.toArray(new String[0]);
86 socket.setEnabledProtocols(enabling);
87 log.debug(() -> "TLS protocol enabled for SSL handshake: " + Arrays.toString(enabling));
88 }
89 }
90
91
94 private boolean existsIn(String element, String[] a) {
95 for (String s : a) {
96 if (element.equals(s)) {
97 return true;
98 }
99 }
100 return false;
101 }
102
103 @Override
104 public Socket connectSocket(
105 final int connectTimeout,
106 final Socket socket,
107 final HttpHost host,
108 final InetSocketAddress remoteAddress,
109 final InetSocketAddress localAddress,
110 final HttpContext context) throws IOException {
111 log.trace(() -> String.format("Connecting to %s:%s", remoteAddress.getAddress(), remoteAddress.getPort()));
112
113 Socket connectedSocket = super.connectSocket(connectTimeout, socket, host, remoteAddress, localAddress, context);
114
115 if (connectedSocket instanceof SSLSocket) {
116 return new SdkSslSocket((SSLSocket) connectedSocket);
117 }
118
119 return new SdkSocket(connectedSocket);
120 }
121
122 }
123