1 /*
2  * Copyright 2013-2019 the original author or authors.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */

16
17 package org.springframework.cloud.aws.messaging.listener;
18
19 import java.util.HashMap;
20 import java.util.Map;
21 import java.util.concurrent.ConcurrentHashMap;
22 import java.util.concurrent.CountDownLatch;
23 import java.util.concurrent.ExecutionException;
24 import java.util.concurrent.Future;
25 import java.util.concurrent.TimeUnit;
26 import java.util.concurrent.TimeoutException;
27
28 import com.amazonaws.services.sqs.model.DeleteMessageRequest;
29 import com.amazonaws.services.sqs.model.Message;
30 import com.amazonaws.services.sqs.model.ReceiveMessageResult;
31
32 import org.springframework.core.task.AsyncTaskExecutor;
33 import org.springframework.messaging.MessagingException;
34 import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
35 import org.springframework.util.Assert;
36 import org.springframework.util.ClassUtils;
37
38 import static org.springframework.cloud.aws.messaging.core.QueueMessageUtils.createMessage;
39
40 /**
41  * @author Agim Emruli
42  * @author Alain Sahli
43  * @author Mete Alpaslan Katircioglu
44  * @since 1.0
45  */

46 public class SimpleMessageListenerContainer extends AbstractMessageListenerContainer {
47
48     private static final int DEFAULT_WORKER_THREADS = 2;
49
50     private static final String DEFAULT_THREAD_NAME_PREFIX = ClassUtils
51             .getShortName(SimpleMessageListenerContainer.class) + "-";
52
53     private boolean defaultTaskExecutor;
54
55     private long backOffTime = 10000;
56
57     private long queueStopTimeout = 10000;
58
59     private AsyncTaskExecutor taskExecutor;
60
61     private ConcurrentHashMap<String, Future<?>> scheduledFutureByQueue;
62
63     private ConcurrentHashMap<String, Boolean> runningStateByQueue;
64
65     protected AsyncTaskExecutor getTaskExecutor() {
66         return this.taskExecutor;
67     }
68
69     public void setTaskExecutor(AsyncTaskExecutor taskExecutor) {
70         this.taskExecutor = taskExecutor;
71     }
72
73     /**
74      * @return The number of milliseconds the polling thread must wait before trying to
75      * recover when an error occurs (e.g. connection timeout)
76      */

77     public long getBackOffTime() {
78         return this.backOffTime;
79     }
80
81     /**
82      * The number of milliseconds the polling thread must wait before trying to recover
83      * when an error occurs (e.g. connection timeout). Default is 10000 milliseconds.
84      * @param backOffTime in milliseconds
85      */

86     public void setBackOffTime(long backOffTime) {
87         this.backOffTime = backOffTime;
88     }
89
90     /**
91      * @return The number of milliseconds the
92      * {@link SimpleMessageListenerContainer#stop(String)} method waits for a queue to
93      * stop before interrupting the current thread. Default value is 10000 milliseconds
94      * (10 seconds).
95      */

96     public long getQueueStopTimeout() {
97         return this.queueStopTimeout;
98     }
99
100     /**
101      * The number of milliseconds the {@link SimpleMessageListenerContainer#stop(String)}
102      * method waits for a queue to stop before interrupting the current thread. Default
103      * value is 10000 milliseconds (10 seconds).
104      * @param queueStopTimeout in milliseconds
105      */

106     public void setQueueStopTimeout(long queueStopTimeout) {
107         this.queueStopTimeout = queueStopTimeout;
108     }
109
110     @Override
111     protected void initialize() {
112         super.initialize();
113
114         if (this.taskExecutor == null) {
115             this.defaultTaskExecutor = true;
116             this.taskExecutor = createDefaultTaskExecutor();
117         }
118
119         initializeRunningStateByQueue();
120         this.scheduledFutureByQueue = new ConcurrentHashMap<>(
121                 getRegisteredQueues().size());
122     }
123
124     private void initializeRunningStateByQueue() {
125         this.runningStateByQueue = new ConcurrentHashMap<>(getRegisteredQueues().size());
126         for (String queueName : getRegisteredQueues().keySet()) {
127             this.runningStateByQueue.put(queueName, false);
128         }
129     }
130
131     @Override
132     protected void doStart() {
133         synchronized (this.getLifecycleMonitor()) {
134             scheduleMessageListeners();
135         }
136     }
137
138     @Override
139     protected void doStop() {
140         notifyRunningQueuesToStop();
141         waitForRunningQueuesToStop();
142     }
143
144     private void notifyRunningQueuesToStop() {
145         for (Map.Entry<String, Boolean> runningStateByQueue : this.runningStateByQueue
146                 .entrySet()) {
147             if (runningStateByQueue.getValue()) {
148                 stopQueue(runningStateByQueue.getKey());
149             }
150         }
151     }
152
153     private void waitForRunningQueuesToStop() {
154         for (Map.Entry<String, Boolean> queueRunningState : this.runningStateByQueue
155                 .entrySet()) {
156             String logicalQueueName = queueRunningState.getKey();
157             Future<?> queueSpinningThread = this.scheduledFutureByQueue
158                     .get(logicalQueueName);
159
160             if (queueSpinningThread != null) {
161                 try {
162                     queueSpinningThread.get(getQueueStopTimeout(), TimeUnit.MILLISECONDS);
163                 }
164                 catch (ExecutionException | TimeoutException e) {
165                     getLogger().warn("An exception occurred while stopping queue '"
166                             + logicalQueueName + "'", e);
167                 }
168                 catch (InterruptedException e) {
169                     Thread.currentThread().interrupt();
170                 }
171             }
172         }
173     }
174
175     @Override
176     protected void doDestroy() {
177         if (this.defaultTaskExecutor) {
178             ((ThreadPoolTaskExecutor) this.taskExecutor).destroy();
179         }
180     }
181
182     /**
183      * Create a default TaskExecutor. Called if no explicit TaskExecutor has been
184      * specified.
185      * <p>
186      * The default implementation builds a
187      * {@link org.springframework.core.task.SimpleAsyncTaskExecutor} with the specified
188      * bean name (or the class name, if no bean name specified) as thread name prefix.
189      * @return a {@link org.springframework.core.task.SimpleAsyncTaskExecutor} configured
190      * with the thread name prefix
191      * @see org.springframework.core.task.SimpleAsyncTaskExecutor#SimpleAsyncTaskExecutor(String)
192      */

193     protected AsyncTaskExecutor createDefaultTaskExecutor() {
194         String beanName = getBeanName();
195         ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
196         threadPoolTaskExecutor.setThreadNamePrefix(
197                 beanName != null ? beanName + "-" : DEFAULT_THREAD_NAME_PREFIX);
198         int spinningThreads = this.getRegisteredQueues().size();
199
200         if (spinningThreads > 0) {
201             threadPoolTaskExecutor
202                     .setCorePoolSize(spinningThreads * DEFAULT_WORKER_THREADS);
203
204             int maxNumberOfMessagePerBatch = getMaxNumberOfMessages() != null
205                     ? getMaxNumberOfMessages() : DEFAULT_MAX_NUMBER_OF_MESSAGES;
206             threadPoolTaskExecutor
207                     .setMaxPoolSize(spinningThreads * (maxNumberOfMessagePerBatch + 1));
208         }
209
210         // No use of a thread pool executor queue to avoid retaining message to long in
211         // memory
212         threadPoolTaskExecutor.setQueueCapacity(0);
213         threadPoolTaskExecutor.afterPropertiesSet();
214
215         return threadPoolTaskExecutor;
216
217     }
218
219     private void scheduleMessageListeners() {
220         for (Map.Entry<String, QueueAttributes> registeredQueue : getRegisteredQueues()
221                 .entrySet()) {
222             startQueue(registeredQueue.getKey(), registeredQueue.getValue());
223         }
224     }
225
226     protected void executeMessage(
227             org.springframework.messaging.Message<String> stringMessage) {
228         getMessageHandler().handleMessage(stringMessage);
229     }
230
231     /**
232      * Stops and waits until the specified queue has stopped. If the wait timeout
233      * specified by {@link SimpleMessageListenerContainer#getQueueStopTimeout()} is
234      * reached, the current thread is interrupted.
235      * @param logicalQueueName the name as defined on the listener method
236      */

237     public void stop(String logicalQueueName) {
238         stopQueue(logicalQueueName);
239
240         try {
241             if (isRunning(logicalQueueName)) {
242                 Future<?> future = this.scheduledFutureByQueue.remove(logicalQueueName);
243                 if (future != null) {
244                     future.get(this.queueStopTimeout, TimeUnit.MILLISECONDS);
245                 }
246             }
247         }
248         catch (InterruptedException e) {
249             Thread.currentThread().interrupt();
250         }
251         catch (ExecutionException | TimeoutException e) {
252             getLogger().warn("Error stopping queue with name: '" + logicalQueueName + "'",
253                     e);
254         }
255     }
256
257     protected void stopQueue(String logicalQueueName) {
258         Assert.isTrue(this.runningStateByQueue.containsKey(logicalQueueName),
259                 "Queue with name '" + logicalQueueName + "' does not exist");
260         this.runningStateByQueue.put(logicalQueueName, false);
261     }
262
263     public void start(String logicalQueueName) {
264         Assert.isTrue(this.runningStateByQueue.containsKey(logicalQueueName),
265                 "Queue with name '" + logicalQueueName + "' does not exist");
266
267         QueueAttributes queueAttributes = this.getRegisteredQueues()
268                 .get(logicalQueueName);
269         startQueue(logicalQueueName, queueAttributes);
270     }
271
272     /**
273      * Checks if the spinning thread for the specified queue {@code logicalQueueName} is
274      * still running (polling for new messages) or not.
275      * @param logicalQueueName the name as defined on the listener method
276      * @return {@code trueif the spinning thread for the specified queue is running
277      * otherwise {@code false}.
278      */

279     public boolean isRunning(String logicalQueueName) {
280         Future<?> future = this.scheduledFutureByQueue.get(logicalQueueName);
281         return future != null && !future.isCancelled() && !future.isDone();
282     }
283
284     protected void startQueue(String queueName, QueueAttributes queueAttributes) {
285         if (this.runningStateByQueue.containsKey(queueName)
286                 && this.runningStateByQueue.get(queueName)) {
287             return;
288         }
289
290         this.runningStateByQueue.put(queueName, true);
291         Future<?> future = getTaskExecutor()
292                 .submit(new AsynchronousMessageListener(queueName, queueAttributes));
293         this.scheduledFutureByQueue.put(queueName, future);
294     }
295
296     private static final class SignalExecutingRunnable implements Runnable {
297
298         private final CountDownLatch countDownLatch;
299
300         private final Runnable runnable;
301
302         private SignalExecutingRunnable(CountDownLatch endSignal, Runnable runnable) {
303             this.countDownLatch = endSignal;
304             this.runnable = runnable;
305         }
306
307         @Override
308         public void run() {
309             try {
310                 this.runnable.run();
311             }
312             finally {
313                 this.countDownLatch.countDown();
314             }
315         }
316
317     }
318
319     private final class AsynchronousMessageListener implements Runnable {
320
321         private final QueueAttributes queueAttributes;
322
323         private final String logicalQueueName;
324
325         private AsynchronousMessageListener(String logicalQueueName,
326                 QueueAttributes queueAttributes) {
327             this.logicalQueueName = logicalQueueName;
328             this.queueAttributes = queueAttributes;
329         }
330
331         @Override
332         public void run() {
333             while (isQueueRunning()) {
334                 try {
335                     ReceiveMessageResult receiveMessageResult = getAmazonSqs()
336                             .receiveMessage(
337                                     this.queueAttributes.getReceiveMessageRequest());
338                     CountDownLatch messageBatchLatch = new CountDownLatch(
339                             receiveMessageResult.getMessages().size());
340                     for (Message message : receiveMessageResult.getMessages()) {
341                         if (isQueueRunning()) {
342                             MessageExecutor messageExecutor = new MessageExecutor(
343                                     this.logicalQueueName, message, this.queueAttributes);
344                             getTaskExecutor().execute(new SignalExecutingRunnable(
345                                     messageBatchLatch, messageExecutor));
346                         }
347                         else {
348                             messageBatchLatch.countDown();
349                         }
350                     }
351                     try {
352                         messageBatchLatch.await();
353                     }
354                     catch (InterruptedException e) {
355                         Thread.currentThread().interrupt();
356                     }
357                 }
358                 catch (Exception e) {
359                     getLogger().warn(
360                             "An Exception occurred while polling queue '{}'. The failing operation will be "
361                                     + "retried in {} milliseconds",
362                             this.logicalQueueName, getBackOffTime(), e);
363                     try {
364                         // noinspection BusyWait
365                         Thread.sleep(getBackOffTime());
366                     }
367                     catch (InterruptedException ie) {
368                         Thread.currentThread().interrupt();
369                     }
370                 }
371             }
372
373             SimpleMessageListenerContainer.this.scheduledFutureByQueue
374                     .remove(this.logicalQueueName);
375         }
376
377         private boolean isQueueRunning() {
378             if (SimpleMessageListenerContainer.this.runningStateByQueue
379                     .containsKey(this.logicalQueueName)) {
380                 return SimpleMessageListenerContainer.this.runningStateByQueue
381                         .get(this.logicalQueueName);
382             }
383             else {
384                 getLogger().warn("Stopped queue '" + this.logicalQueueName
385                         + "' because it was not listed as running queue.");
386                 return false;
387             }
388         }
389
390     }
391
392     private final class MessageExecutor implements Runnable {
393
394         private final Message message;
395
396         private final String logicalQueueName;
397
398         private final String queueUrl;
399
400         private final boolean hasRedrivePolicy;
401
402         private final SqsMessageDeletionPolicy deletionPolicy;
403
404         private MessageExecutor(String logicalQueueName, Message message,
405                 QueueAttributes queueAttributes) {
406             this.logicalQueueName = logicalQueueName;
407             this.message = message;
408             this.queueUrl = queueAttributes.getReceiveMessageRequest().getQueueUrl();
409             this.hasRedrivePolicy = queueAttributes.hasRedrivePolicy();
410             this.deletionPolicy = queueAttributes.getDeletionPolicy();
411         }
412
413         @Override
414         public void run() {
415             String receiptHandle = this.message.getReceiptHandle();
416             org.springframework.messaging.Message<String> queueMessage = getMessageForExecution();
417             try {
418                 executeMessage(queueMessage);
419                 applyDeletionPolicyOnSuccess(receiptHandle);
420             }
421             catch (MessagingException messagingException) {
422                 applyDeletionPolicyOnError(receiptHandle);
423             }
424         }
425
426         private void applyDeletionPolicyOnSuccess(String receiptHandle) {
427             if (this.deletionPolicy == SqsMessageDeletionPolicy.ON_SUCCESS
428                     || this.deletionPolicy == SqsMessageDeletionPolicy.ALWAYS
429                     || this.deletionPolicy == SqsMessageDeletionPolicy.NO_REDRIVE) {
430                 deleteMessage(receiptHandle);
431             }
432         }
433
434         private void applyDeletionPolicyOnError(String receiptHandle) {
435             if (this.deletionPolicy == SqsMessageDeletionPolicy.ALWAYS
436                     || (this.deletionPolicy == SqsMessageDeletionPolicy.NO_REDRIVE
437                             && !this.hasRedrivePolicy)) {
438                 deleteMessage(receiptHandle);
439             }
440         }
441
442         private void deleteMessage(String receiptHandle) {
443             getAmazonSqs().deleteMessageAsync(
444                     new DeleteMessageRequest(this.queueUrl, receiptHandle),
445                     new DeleteMessageHandler(receiptHandle));
446         }
447
448         private org.springframework.messaging.Message<String> getMessageForExecution() {
449             HashMap<String, Object> additionalHeaders = new HashMap<>();
450             additionalHeaders.put(QueueMessageHandler.LOGICAL_RESOURCE_ID,
451                     this.logicalQueueName);
452             if (this.deletionPolicy == SqsMessageDeletionPolicy.NEVER) {
453                 String receiptHandle = this.message.getReceiptHandle();
454                 QueueMessageAcknowledgment acknowledgment = new QueueMessageAcknowledgment(
455                         SimpleMessageListenerContainer.this.getAmazonSqs(), this.queueUrl,
456                         receiptHandle);
457                 additionalHeaders.put(QueueMessageHandler.ACKNOWLEDGMENT, acknowledgment);
458             }
459             additionalHeaders.put(QueueMessageHandler.VISIBILITY,
460                     new QueueMessageVisibility(
461                             SimpleMessageListenerContainer.this.getAmazonSqs(),
462                             this.queueUrl, this.message.getReceiptHandle()));
463
464             return createMessage(this.message, additionalHeaders);
465         }
466
467     }
468
469 }
470