1 /*
2  * Copyright 2013-2019 the original author or authors.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */

16
17 package org.springframework.cloud.aws.messaging.core;
18
19 import java.nio.ByteBuffer;
20 import java.util.HashMap;
21 import java.util.Map;
22 import java.util.concurrent.ExecutionException;
23 import java.util.concurrent.Future;
24 import java.util.concurrent.TimeUnit;
25 import java.util.concurrent.TimeoutException;
26
27 import com.amazonaws.AmazonServiceException;
28 import com.amazonaws.services.sqs.AmazonSQSAsync;
29 import com.amazonaws.services.sqs.model.DeleteMessageRequest;
30 import com.amazonaws.services.sqs.model.MessageAttributeValue;
31 import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
32 import com.amazonaws.services.sqs.model.ReceiveMessageResult;
33 import com.amazonaws.services.sqs.model.SendMessageRequest;
34 import com.amazonaws.services.sqs.model.SendMessageResult;
35
36 import org.springframework.messaging.Message;
37 import org.springframework.messaging.MessageDeliveryException;
38 import org.springframework.messaging.MessageHeaders;
39 import org.springframework.messaging.PollableChannel;
40 import org.springframework.messaging.support.AbstractMessageChannel;
41 import org.springframework.util.Assert;
42 import org.springframework.util.MimeType;
43 import org.springframework.util.NumberUtils;
44
45 import static org.springframework.cloud.aws.messaging.core.QueueMessageUtils.createMessage;
46
47 /**
48  * @author Agim Emruli
49  * @author Alain Sahli
50  * @since 1.0
51  */

52 public class QueueMessageChannel extends AbstractMessageChannel
53         implements PollableChannel {
54
55     static final String ATTRIBUTE_NAMES = "All";
56
57     private static final String MESSAGE_ATTRIBUTE_NAMES = "All";
58
59     private final AmazonSQSAsync amazonSqs;
60
61     private final String queueUrl;
62
63     public QueueMessageChannel(AmazonSQSAsync amazonSqs, String queueUrl) {
64         this.amazonSqs = amazonSqs;
65         this.queueUrl = queueUrl;
66     }
67
68     private static boolean isSkipHeader(String headerName) {
69         return SqsMessageHeaders.SQS_DELAY_HEADER.equals(headerName)
70                 || SqsMessageHeaders.SQS_DEDUPLICATION_ID_HEADER.equals(headerName)
71                 || SqsMessageHeaders.SQS_GROUP_ID_HEADER.equals(headerName);
72     }
73
74     @Override
75     protected boolean sendInternal(Message<?> message, long timeout) {
76         try {
77             sendMessageAndWaitForResult(prepareSendMessageRequest(message), timeout);
78         }
79         catch (AmazonServiceException e) {
80             throw new MessageDeliveryException(message, e.getMessage(), e);
81         }
82         catch (ExecutionException e) {
83             throw new MessageDeliveryException(message, e.getMessage(), e.getCause());
84         }
85         catch (TimeoutException e) {
86             return false;
87         }
88
89         return true;
90     }
91
92     private SendMessageRequest prepareSendMessageRequest(Message<?> message) {
93         SendMessageRequest sendMessageRequest = new SendMessageRequest(this.queueUrl,
94                 String.valueOf(message.getPayload()));
95
96         if (message.getHeaders().containsKey(SqsMessageHeaders.SQS_GROUP_ID_HEADER)) {
97             sendMessageRequest.setMessageGroupId(message.getHeaders()
98                     .get(SqsMessageHeaders.SQS_GROUP_ID_HEADER, String.class));
99         }
100
101         if (message.getHeaders()
102                 .containsKey(SqsMessageHeaders.SQS_DEDUPLICATION_ID_HEADER)) {
103             sendMessageRequest.setMessageDeduplicationId(message.getHeaders()
104                     .get(SqsMessageHeaders.SQS_DEDUPLICATION_ID_HEADER, String.class));
105         }
106
107         if (message.getHeaders().containsKey(SqsMessageHeaders.SQS_DELAY_HEADER)) {
108             sendMessageRequest.setDelaySeconds(message.getHeaders()
109                     .get(SqsMessageHeaders.SQS_DELAY_HEADER, Integer.class));
110         }
111
112         Map<String, MessageAttributeValue> messageAttributes = getMessageAttributes(
113                 message);
114         if (!messageAttributes.isEmpty()) {
115             sendMessageRequest.withMessageAttributes(messageAttributes);
116         }
117
118         return sendMessageRequest;
119     }
120
121     private void sendMessageAndWaitForResult(SendMessageRequest sendMessageRequest,
122             long timeout) throws ExecutionException, TimeoutException {
123         if (timeout > 0) {
124             Future<SendMessageResult> sendMessageFuture = this.amazonSqs
125                     .sendMessageAsync(sendMessageRequest);
126
127             try {
128                 sendMessageFuture.get(timeout, TimeUnit.MILLISECONDS);
129             }
130             catch (InterruptedException e) {
131                 Thread.currentThread().interrupt();
132             }
133         }
134         else {
135             this.amazonSqs.sendMessage(sendMessageRequest);
136         }
137     }
138
139     private Map<String, MessageAttributeValue> getMessageAttributes(Message<?> message) {
140         HashMap<String, MessageAttributeValue> messageAttributes = new HashMap<>();
141         for (Map.Entry<String, Object> messageHeader : message.getHeaders().entrySet()) {
142             String messageHeaderName = messageHeader.getKey();
143             Object messageHeaderValue = messageHeader.getValue();
144
145             if (isSkipHeader(messageHeaderName)) {
146                 continue;
147             }
148
149             if (MessageHeaders.CONTENT_TYPE.equals(messageHeaderName)
150                     && messageHeaderValue != null) {
151                 messageAttributes.put(messageHeaderName,
152                         getContentTypeMessageAttribute(messageHeaderValue));
153             }
154             else if (MessageHeaders.ID.equals(messageHeaderName)
155                     && messageHeaderValue != null) {
156                 messageAttributes.put(messageHeaderName,
157                         getStringMessageAttribute(messageHeaderValue.toString()));
158             }
159             else if (messageHeaderValue instanceof String) {
160                 messageAttributes.put(messageHeaderName,
161                         getStringMessageAttribute((String) messageHeaderValue));
162             }
163             else if (messageHeaderValue instanceof Number) {
164                 messageAttributes.put(messageHeaderName,
165                         getNumberMessageAttribute(messageHeaderValue));
166             }
167             else if (messageHeaderValue instanceof ByteBuffer) {
168                 messageAttributes.put(messageHeaderName,
169                         getBinaryMessageAttribute((ByteBuffer) messageHeaderValue));
170             }
171             else {
172                 this.logger.warn(String.format(
173                         "Message header with name '%s' and type '%s' cannot be sent as"
174                                 + " message attribute because it is not supported by SQS.",
175                         messageHeaderName, messageHeaderValue != null
176                                 ? messageHeaderValue.getClass().getName() : ""));
177             }
178         }
179
180         return messageAttributes;
181     }
182
183     private MessageAttributeValue getBinaryMessageAttribute(
184             ByteBuffer messageHeaderValue) {
185         return new MessageAttributeValue().withDataType(MessageAttributeDataTypes.BINARY)
186                 .withBinaryValue(messageHeaderValue);
187     }
188
189     private MessageAttributeValue getContentTypeMessageAttribute(
190             Object messageHeaderValue) {
191         if (messageHeaderValue instanceof MimeType) {
192             return new MessageAttributeValue()
193                     .withDataType(MessageAttributeDataTypes.STRING)
194                     .withStringValue(messageHeaderValue.toString());
195         }
196         else if (messageHeaderValue instanceof String) {
197             return new MessageAttributeValue()
198                     .withDataType(MessageAttributeDataTypes.STRING)
199                     .withStringValue((String) messageHeaderValue);
200         }
201         return null;
202     }
203
204     private MessageAttributeValue getStringMessageAttribute(String messageHeaderValue) {
205         return new MessageAttributeValue().withDataType(MessageAttributeDataTypes.STRING)
206                 .withStringValue(messageHeaderValue);
207     }
208
209     private MessageAttributeValue getNumberMessageAttribute(Object messageHeaderValue) {
210         Assert.isTrue(
211                 NumberUtils.STANDARD_NUMBER_TYPES.contains(messageHeaderValue.getClass()),
212                 "Only standard number types are accepted as message header.");
213
214         return new MessageAttributeValue()
215                 .withDataType(MessageAttributeDataTypes.NUMBER + "."
216                         + messageHeaderValue.getClass().getName())
217                 .withStringValue(messageHeaderValue.toString());
218     }
219
220     @Override
221     public Message<String> receive() {
222         return this.receive(0);
223     }
224
225     @Override
226     public Message<String> receive(long timeout) {
227         ReceiveMessageResult receiveMessageResult = this.amazonSqs.receiveMessage(
228                 new ReceiveMessageRequest(this.queueUrl).withMaxNumberOfMessages(1)
229                         .withWaitTimeSeconds(Long.valueOf(timeout).intValue())
230                         .withAttributeNames(ATTRIBUTE_NAMES)
231                         .withMessageAttributeNames(MESSAGE_ATTRIBUTE_NAMES));
232         if (receiveMessageResult.getMessages().isEmpty()) {
233             return null;
234         }
235         com.amazonaws.services.sqs.model.Message amazonMessage = receiveMessageResult
236                 .getMessages().get(0);
237         Message<String> message = createMessage(amazonMessage);
238         this.amazonSqs.deleteMessage(new DeleteMessageRequest(this.queueUrl,
239                 amazonMessage.getReceiptHandle()));
240         return message;
241     }
242
243 }
244