1 /*
2  * Copyright 2010-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */

15 package com.amazonaws.services.sqs;
16
17 import static com.amazonaws.util.StringUtils.UTF8;
18
19 import java.io.UnsupportedEncodingException;
20 import java.nio.ByteBuffer;
21 import java.security.MessageDigest;
22 import java.util.ArrayList;
23 import java.util.Collections;
24 import java.util.HashMap;
25 import java.util.List;
26 import java.util.Map;
27
28 import org.apache.commons.logging.Log;
29 import org.apache.commons.logging.LogFactory;
30
31 import com.amazonaws.AmazonClientException;
32 import com.amazonaws.Request;
33 import com.amazonaws.handlers.AbstractRequestHandler;
34 import com.amazonaws.services.sqs.model.Message;
35 import com.amazonaws.services.sqs.model.MessageAttributeValue;
36 import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
37 import com.amazonaws.services.sqs.model.ReceiveMessageResult;
38 import com.amazonaws.services.sqs.model.SendMessageBatchRequest;
39 import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry;
40 import com.amazonaws.services.sqs.model.SendMessageBatchResult;
41 import com.amazonaws.services.sqs.model.SendMessageBatchResultEntry;
42 import com.amazonaws.services.sqs.model.SendMessageRequest;
43 import com.amazonaws.services.sqs.model.SendMessageResult;
44 import com.amazonaws.util.BinaryUtils;
45 import com.amazonaws.util.Md5Utils;
46 import com.amazonaws.util.TimingInfo;
47
48 /**
49  * SQS operations on sending and receiving messages will return the MD5 digest of the message body.
50  * This custom request handler will verify that the message is correctly received by SQS, by
51  * comparing the returned MD5 with the calculation according to the original request.
52  */

53 public class MessageMD5ChecksumHandler extends AbstractRequestHandler {
54
55     private static final int INTEGER_SIZE_IN_BYTES = 4;
56     private static final byte STRING_TYPE_FIELD_INDEX = 1;
57     private static final byte BINARY_TYPE_FIELD_INDEX = 2;
58     private static final byte STRING_LIST_TYPE_FIELD_INDEX = 3;
59     private static final byte BINARY_LIST_TYPE_FIELD_INDEX = 4;
60
61     /*
62      * Constant strings for composing error message.
63      */

64     private static final String MD5_MISMATCH_ERROR_MESSAGE = "MD5 returned by SQS does not match the calculation on the original request. "
65             + "(MD5 calculated by the %s: \"%s\", MD5 checksum returned: \"%s\")";
66     private static final String MD5_MISMATCH_ERROR_MESSAGE_WITH_ID = "MD5 returned by SQS does not match the calculation on the original request. "
67             + "(Message ID: %s, MD5 calculated by the %s: \"%s\", MD5 checksum returned: \"%s\")";
68     private static final String MESSAGE_BODY = "message body";
69     private static final String MESSAGE_ATTRIBUTES = "message attributes";
70
71     private static final Log log = LogFactory.getLog(MessageMD5ChecksumHandler.class);
72
73     @Override
74     public void afterResponse(Request<?> request, Object response, TimingInfo timingInfo) {
75         if (request != null && response != null) {
76             // SendMessage
77             if (request.getOriginalRequest() instanceof SendMessageRequest && response instanceof SendMessageResult) {
78                 SendMessageRequest sendMessageRequest = (SendMessageRequest) request.getOriginalRequest();
79                 SendMessageResult sendMessageResult = (SendMessageResult) response;
80                 sendMessageOperationMd5Check(sendMessageRequest, sendMessageResult);
81             }
82
83             // ReceiveMessage
84             else if (request.getOriginalRequest() instanceof ReceiveMessageRequest
85                     && response instanceof ReceiveMessageResult) {
86                 ReceiveMessageResult receiveMessageResult = (ReceiveMessageResult) response;
87                 receiveMessageResultMd5Check(receiveMessageResult);
88             }
89
90             // SendMessageBatch
91             else if (request.getOriginalRequest() instanceof SendMessageBatchRequest
92                     && response instanceof SendMessageBatchResult) {
93                 SendMessageBatchRequest sendMessageBatchRequest = (SendMessageBatchRequest) request
94                         .getOriginalRequest();
95                 SendMessageBatchResult sendMessageBatchResult = (SendMessageBatchResult) response;
96                 sendMessageBatchOperationMd5Check(sendMessageBatchRequest, sendMessageBatchResult);
97             }
98         }
99     }
100
101     /**
102      * Throw an exception if the MD5 checksums returned in the SendMessageResult do not match the
103      * client-side calculation based on the original message in the SendMessageRequest.
104      */

105     private static void sendMessageOperationMd5Check(SendMessageRequest sendMessageRequest,
106                                                      SendMessageResult sendMessageResult) {
107         String messageBodySent = sendMessageRequest.getMessageBody();
108         String bodyMd5Returned = sendMessageResult.getMD5OfMessageBody();
109         String clientSideBodyMd5 = calculateMessageBodyMd5(messageBodySent);
110         if (!clientSideBodyMd5.equals(bodyMd5Returned)) {
111             throw new AmazonClientException(String.format(MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_BODY, clientSideBodyMd5,
112                     bodyMd5Returned));
113         }
114
115         Map<String, MessageAttributeValue> messageAttrSent = sendMessageRequest.getMessageAttributes();
116         if (messageAttrSent != null && !messageAttrSent.isEmpty()) {
117             String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttrSent);
118             String attrMd5Returned = sendMessageResult.getMD5OfMessageAttributes();
119             if (!clientSideAttrMd5.equals(attrMd5Returned)) {
120                 throw new AmazonClientException(String.format(MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_ATTRIBUTES,
121                         clientSideAttrMd5, attrMd5Returned));
122             }
123         }
124     }
125
126     /**
127      * Throw an exception if the MD5 checksums included in the ReceiveMessageResult do not match the
128      * client-side calculation on the received messages.
129      */

130     private static void receiveMessageResultMd5Check(ReceiveMessageResult receiveMessageResult) {
131         if (receiveMessageResult.getMessages() != null) {
132             for (Message messageReceived : receiveMessageResult.getMessages()) {
133                 String messageBody = messageReceived.getBody();
134                 String bodyMd5Returned = messageReceived.getMD5OfBody();
135                 String clientSideBodyMd5 = calculateMessageBodyMd5(messageBody);
136                 if (!clientSideBodyMd5.equals(bodyMd5Returned)) {
137                     throw new AmazonClientException(String.format(MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_BODY,
138                             clientSideBodyMd5, bodyMd5Returned));
139                 }
140
141                 Map<String, MessageAttributeValue> messageAttr = messageReceived.getMessageAttributes();
142                 if (messageAttr != null && !messageAttr.isEmpty()) {
143                     String attrMd5Returned = messageReceived.getMD5OfMessageAttributes();
144                     String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttr);
145                     if (!clientSideAttrMd5.equals(attrMd5Returned)) {
146                         throw new AmazonClientException(String.format(MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_ATTRIBUTES,
147                                 clientSideAttrMd5, attrMd5Returned));
148                     }
149                 }
150             }
151         }
152     }
153
154     /**
155      * Throw an exception if the MD5 checksums returned in the SendMessageBatchResult do not match
156      * the client-side calculation based on the original messages in the SendMessageBatchRequest.
157      */

158     private static void sendMessageBatchOperationMd5Check(SendMessageBatchRequest sendMessageBatchRequest,
159                                                           SendMessageBatchResult sendMessageBatchResult) {
160         Map<String, SendMessageBatchRequestEntry> idToRequestEntryMap = new HashMap<String, SendMessageBatchRequestEntry>();
161         if (sendMessageBatchRequest.getEntries() != null) {
162             for (SendMessageBatchRequestEntry entry : sendMessageBatchRequest.getEntries()) {
163                 idToRequestEntryMap.put(entry.getId(), entry);
164             }
165         }
166
167         if (sendMessageBatchResult.getSuccessful() != null) {
168             for (SendMessageBatchResultEntry entry : sendMessageBatchResult.getSuccessful()) {
169                 String messageBody = idToRequestEntryMap.get(entry.getId()).getMessageBody();
170                 String bodyMd5Returned = entry.getMD5OfMessageBody();
171                 String clientSideBodyMd5 = calculateMessageBodyMd5(messageBody);
172                 if (!clientSideBodyMd5.equals(bodyMd5Returned)) {
173                     throw new AmazonClientException(String.format(MD5_MISMATCH_ERROR_MESSAGE_WITH_ID, MESSAGE_BODY,
174                             entry.getId(), clientSideBodyMd5, bodyMd5Returned));
175                 }
176
177                 Map<String, MessageAttributeValue> messageAttr = idToRequestEntryMap.get(entry.getId())
178                         .getMessageAttributes();
179                 if (messageAttr != null && !messageAttr.isEmpty()) {
180                     String attrMd5Returned = entry.getMD5OfMessageAttributes();
181                     String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttr);
182                     if (!clientSideAttrMd5.equals(attrMd5Returned)) {
183                         throw new AmazonClientException(String.format(MD5_MISMATCH_ERROR_MESSAGE_WITH_ID,
184                                 MESSAGE_ATTRIBUTES, entry.getId(), clientSideAttrMd5, attrMd5Returned));
185                     }
186                 }
187             }
188         }
189     }
190
191     /**
192      * Returns the hex-encoded MD5 hash String of the given message body.
193      */

194     private static String calculateMessageBodyMd5(String messageBody) {
195         if (log.isDebugEnabled()) {
196             log.debug("Message body: " + messageBody);
197         }
198         byte[] expectedMd5;
199         try {
200             expectedMd5 = Md5Utils.computeMD5Hash(messageBody.getBytes(UTF8));
201         } catch (Exception e) {
202             throw new AmazonClientException("Unable to calculate the MD5 hash of the message body. " + e.getMessage(),
203                     e);
204         }
205         String expectedMd5Hex = BinaryUtils.toHex(expectedMd5);
206         if (log.isDebugEnabled()) {
207             log.debug("Expected  MD5 of message body: " + expectedMd5Hex);
208         }
209         return expectedMd5Hex;
210     }
211
212     /**
213      * Returns the hex-encoded MD5 hash String of the given message attributes.
214      */

215     private static String calculateMessageAttributesMd5(final Map<String, MessageAttributeValue> messageAttributes) {
216         if (log.isDebugEnabled()) {
217             log.debug("Message attribtues: " + messageAttributes);
218         }
219         List<String> sortedAttributeNames = new ArrayList<String>(messageAttributes.keySet());
220         Collections.sort(sortedAttributeNames);
221
222         MessageDigest md5Digest = null;
223         try {
224             md5Digest = MessageDigest.getInstance("MD5");
225
226             for (String attrName : sortedAttributeNames) {
227                 MessageAttributeValue attrValue = messageAttributes.get(attrName);
228
229                 // Encoded Name
230                 updateLengthAndBytes(md5Digest, attrName);
231                 // Encoded Type
232                 updateLengthAndBytes(md5Digest, attrValue.getDataType());
233
234                 // Encoded Value
235                 if (attrValue.getStringValue() != null) {
236                     md5Digest.update(STRING_TYPE_FIELD_INDEX);
237                     updateLengthAndBytes(md5Digest, attrValue.getStringValue());
238                 } else if (attrValue.getBinaryValue() != null) {
239                     md5Digest.update(BINARY_TYPE_FIELD_INDEX);
240                     updateLengthAndBytes(md5Digest, attrValue.getBinaryValue());
241                 } else if (attrValue.getStringListValues().size() > 0) {
242                     md5Digest.update(STRING_LIST_TYPE_FIELD_INDEX);
243                     for (String strListMember : attrValue.getStringListValues()) {
244                         updateLengthAndBytes(md5Digest, strListMember);
245                     }
246                 } else if (attrValue.getBinaryListValues().size() > 0) {
247                     md5Digest.update(BINARY_LIST_TYPE_FIELD_INDEX);
248                     for (ByteBuffer byteListMember : attrValue.getBinaryListValues()) {
249                         updateLengthAndBytes(md5Digest, byteListMember);
250                     }
251                 }
252             }
253         } catch (Exception e) {
254             throw new AmazonClientException("Unable to calculate the MD5 hash of the message attributes. "
255                     + e.getMessage(), e);
256         }
257
258         String expectedMd5Hex = BinaryUtils.toHex(md5Digest.digest());
259         if (log.isDebugEnabled()) {
260             log.debug("Expected  MD5 of message attributes: " + expectedMd5Hex);
261         }
262         return expectedMd5Hex;
263     }
264
265     /**
266      * Update the digest using a sequence of bytes that consists of the length (in 4 bytes) of the
267      * input String and the actual utf8-encoded byte values.
268      */

269     private static void updateLengthAndBytes(MessageDigest digest, String str) throws UnsupportedEncodingException {
270         byte[] utf8Encoded = str.getBytes(UTF8);
271         ByteBuffer lengthBytes = ByteBuffer.allocate(INTEGER_SIZE_IN_BYTES).putInt(utf8Encoded.length);
272         digest.update(lengthBytes.array());
273         digest.update(utf8Encoded);
274     }
275
276     /**
277      * Update the digest using a sequence of bytes that consists of the length (in 4 bytes) of the
278      * input ByteBuffer and all the bytes it contains.
279      */

280     private static void updateLengthAndBytes(MessageDigest digest, ByteBuffer binaryValue) {
281         ByteBuffer readOnlyBuffer = binaryValue.asReadOnlyBuffer();
282         int size = readOnlyBuffer.remaining();
283         ByteBuffer lengthBytes = ByteBuffer.allocate(INTEGER_SIZE_IN_BYTES).putInt(size);
284         digest.update(lengthBytes.array());
285         digest.update(readOnlyBuffer);
286     }
287 }
288