1
16 package org.springframework.data.repository.query;
17
18 import lombok.RequiredArgsConstructor;
19
20 import java.lang.reflect.Method;
21 import java.util.HashMap;
22 import java.util.List;
23 import java.util.Map;
24 import java.util.Optional;
25 import java.util.stream.Collectors;
26 import java.util.stream.Stream;
27
28 import org.aopalliance.intercept.MethodInterceptor;
29 import org.aopalliance.intercept.MethodInvocation;
30 import org.springframework.beans.factory.ListableBeanFactory;
31 import org.springframework.data.spel.ExtensionAwareEvaluationContextProvider;
32 import org.springframework.data.spel.spi.EvaluationContextExtension;
33 import org.springframework.expression.EvaluationContext;
34 import org.springframework.expression.spel.support.StandardEvaluationContext;
35 import org.springframework.lang.Nullable;
36 import org.springframework.util.Assert;
37 import org.springframework.util.ConcurrentReferenceHashMap;
38 import org.springframework.util.ReflectionUtils;
39 import org.springframework.util.StringUtils;
40
41
51 public class ExtensionAwareQueryMethodEvaluationContextProvider implements QueryMethodEvaluationContextProvider {
52
53 private final ExtensionAwareEvaluationContextProvider delegate;
54
55
61 public ExtensionAwareQueryMethodEvaluationContextProvider(ListableBeanFactory beanFactory) {
62
63 Assert.notNull(beanFactory, "ListableBeanFactory must not be null!");
64
65 this.delegate = new ExtensionAwareEvaluationContextProvider(beanFactory);
66 }
67
68
74 public ExtensionAwareQueryMethodEvaluationContextProvider(List<? extends EvaluationContextExtension> extensions) {
75
76 Assert.notNull(extensions, "EvaluationContextExtensions must not be null!");
77
78 this.delegate = new org.springframework.data.spel.ExtensionAwareEvaluationContextProvider(extensions);
79 }
80
81
85 @Override
86 public <T extends Parameters<?, ?>> EvaluationContext getEvaluationContext(T parameters, Object[] parameterValues) {
87
88 StandardEvaluationContext evaluationContext = delegate.getEvaluationContext(parameterValues);
89
90 evaluationContext.setVariables(collectVariables(parameters, parameterValues));
91
92 return evaluationContext;
93 }
94
95
103 private static Map<String, Object> collectVariables(Parameters<?, ?> parameters, Object[] arguments) {
104
105 Map<String, Object> variables = new HashMap<>();
106
107 parameters.stream()
108 .filter(Parameter::isSpecialParameter)
109 .forEach(it -> variables.put(
110 StringUtils.uncapitalize(it.getType().getSimpleName()),
111 arguments[it.getIndex()]));
112
113 parameters.stream()
114 .filter(Parameter::isNamedParameter)
115 .forEach(it -> variables.put(
116 it.getName().orElseThrow(() -> new IllegalStateException("Should never occur!")),
117 arguments[it.getIndex()]));
118
119 return variables;
120 }
121
122
130 private static List<EvaluationContextExtension> getExtensionsFrom(ListableBeanFactory beanFactory) {
131
132 Stream<EvaluationContextExtension> extensions = beanFactory
133 .getBeansOfType(EvaluationContextExtension.class, true, false).values().stream();
134
135 return extensions.collect(Collectors.toList());
136 }
137
138
144 @RequiredArgsConstructor
145 static class DelegatingMethodInterceptor implements MethodInterceptor {
146
147 private static final Map<Method, Method> METHOD_CACHE = new ConcurrentReferenceHashMap<Method, Method>();
148
149 private final Object target;
150 private final Map<String, java.util.function.Function<Object, Object>> directMappings = new HashMap<>();
151
152
159 public void registerResultMapping(String methodName, java.util.function.Function<Object, Object> mapping) {
160 this.directMappings.put(methodName, mapping);
161 }
162
163
167 @Nullable
168 @Override
169 public Object invoke(@Nullable MethodInvocation invocation) throws Throwable {
170
171 if (invocation == null) {
172 throw new IllegalArgumentException("Invocation must not be null!");
173 }
174
175 Method method = invocation.getMethod();
176 Method targetMethod = METHOD_CACHE.computeIfAbsent(method,
177 it -> Optional.ofNullable(findTargetMethod(it)).orElse(it));
178
179 Object result = method.equals(targetMethod) ? invocation.proceed()
180 : ReflectionUtils.invokeMethod(targetMethod, target, invocation.getArguments());
181
182 if (result == null) {
183 return result;
184 }
185
186 java.util.function.Function<Object, Object> mapper = directMappings.get(targetMethod.getName());
187
188 return mapper != null ? mapper.apply(result) : result;
189 }
190
191 @Nullable
192 private Method findTargetMethod(Method method) {
193
194 try {
195 return target.getClass().getMethod(method.getName(), method.getParameterTypes());
196 } catch (Exception e) {
197 return null;
198 }
199 }
200 }
201 }
202