1 /*
2  *
3  *  * Copyright 2019-2020 the original author or authors.
4  *  *
5  *  * Licensed under the Apache License, Version 2.0 (the "License");
6  *  * you may not use this file except in compliance with the License.
7  *  * You may obtain a copy of the License at
8  *  *
9  *  *      https://www.apache.org/licenses/LICENSE-2.0
10  *  *
11  *  * Unless required by applicable law or agreed to in writing, software
12  *  * distributed under the License is distributed on an "AS IS" BASIS,
13  *  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  * See the License for the specific language governing permissions and
15  *  * limitations under the License.
16  *
17  */

18
19 package org.springdoc.core;
20
21 import java.util.ArrayList;
22 import java.util.Collections;
23 import java.util.HashMap;
24 import java.util.HashSet;
25 import java.util.List;
26 import java.util.Locale;
27 import java.util.Map;
28 import java.util.Optional;
29 import java.util.Set;
30 import java.util.function.Consumer;
31 import java.util.function.Supplier;
32 import java.util.stream.Collectors;
33 import java.util.stream.Stream;
34
35 import io.swagger.v3.core.util.AnnotationsUtils;
36 import io.swagger.v3.oas.annotations.Hidden;
37 import io.swagger.v3.oas.annotations.OpenAPIDefinition;
38 import io.swagger.v3.oas.annotations.tags.Tag;
39 import io.swagger.v3.oas.annotations.tags.Tags;
40 import io.swagger.v3.oas.models.Components;
41 import io.swagger.v3.oas.models.OpenAPI;
42 import io.swagger.v3.oas.models.Operation;
43 import io.swagger.v3.oas.models.Paths;
44 import io.swagger.v3.oas.models.info.Contact;
45 import io.swagger.v3.oas.models.info.Info;
46 import io.swagger.v3.oas.models.info.License;
47 import io.swagger.v3.oas.models.media.Schema;
48 import io.swagger.v3.oas.models.security.SecurityScheme;
49 import io.swagger.v3.oas.models.servers.Server;
50 import org.apache.commons.lang3.StringUtils;
51 import org.slf4j.Logger;
52 import org.slf4j.LoggerFactory;
53 import org.springdoc.core.customizers.OpenApiBuilderCustomiser;
54
55 import org.springframework.beans.factory.config.BeanDefinition;
56 import org.springframework.boot.autoconfigure.AutoConfigurationPackages;
57 import org.springframework.context.ApplicationContext;
58 import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider;
59 import org.springframework.core.annotation.AnnotatedElementUtils;
60 import org.springframework.core.annotation.AnnotationUtils;
61 import org.springframework.core.type.filter.AnnotationTypeFilter;
62 import org.springframework.stereotype.Controller;
63 import org.springframework.util.CollectionUtils;
64 import org.springframework.web.bind.annotation.ControllerAdvice;
65 import org.springframework.web.bind.annotation.RequestMapping;
66 import org.springframework.web.bind.annotation.RestController;
67 import org.springframework.web.method.HandlerMethod;
68
69 import static org.springdoc.core.Constants.DEFAULT_SERVER_DESCRIPTION;
70 import static org.springdoc.core.Constants.DEFAULT_TITLE;
71 import static org.springdoc.core.Constants.DEFAULT_VERSION;
72
73 public class OpenAPIBuilder {
74
75     private static final Logger LOGGER = LoggerFactory.getLogger(OpenAPIBuilder.class);
76
77     private final ApplicationContext context;
78
79     private final SecurityParser securityParser;
80
81     private final Map<String, Object> mappingsMap = new HashMap<>();
82
83     private final Map<HandlerMethod, io.swagger.v3.oas.models.tags.Tag> springdocTags = new HashMap<>();
84
85     private final Optional<List<OpenApiBuilderCustomiser>> openApiBuilderCustomisers;
86
87     private final SpringDocConfigProperties springDocConfigProperties;
88
89     private OpenAPI openAPI;
90
91     private OpenAPI cachedOpenAPI;
92
93     private OpenAPI calculatedOpenAPI;
94
95     private boolean isServersPresent;
96
97     private String serverBaseUrl;
98
99     OpenAPIBuilder(Optional<OpenAPI> openAPI, ApplicationContext context, SecurityParser securityParser,
100             SpringDocConfigProperties springDocConfigProperties,
101             Optional<List<OpenApiBuilderCustomiser>> openApiBuilderCustomisers) {
102         if (openAPI.isPresent()) {
103             this.openAPI = openAPI.get();
104             if (this.openAPI.getComponents() == null)
105                 this.openAPI.setComponents(new Components());
106             if (this.openAPI.getPaths() == null)
107                 this.openAPI.setPaths(new Paths());
108             if (!CollectionUtils.isEmpty(this.openAPI.getServers()))
109                 this.isServersPresent = true;
110         }
111         this.context = context;
112         this.securityParser = securityParser;
113         this.springDocConfigProperties = springDocConfigProperties;
114         this.openApiBuilderCustomisers = openApiBuilderCustomisers;
115     }
116
117     public static String splitCamelCase(String str) {
118         return str.replaceAll(
119                 String.format(
120                         "%s|%s|%s",
121                         "(?<=[A-Z])(?=[A-Z][a-z])",
122                         "(?<=[^A-Z])(?=[A-Z])",
123                         "(?<=[A-Za-z])(?=[^A-Za-z])"),
124                 "-")
125                 .toLowerCase(Locale.ROOT);
126     }
127
128     public void build() {
129         Optional<OpenAPIDefinition> apiDef = getOpenAPIDefinition();
130
131         if (openAPI == null) {
132             this.calculatedOpenAPI = new OpenAPI();
133             this.calculatedOpenAPI.setComponents(new Components());
134             this.calculatedOpenAPI.setPaths(new Paths());
135         }
136         else
137             this.calculatedOpenAPI = openAPI;
138
139         if (apiDef.isPresent()) {
140             buildOpenAPIWithOpenAPIDefinition(calculatedOpenAPI, apiDef.get());
141         }
142         // Set default info
143         else if (calculatedOpenAPI.getInfo() == null) {
144             Info infos = new Info().title(DEFAULT_TITLE).version(DEFAULT_VERSION);
145             calculatedOpenAPI.setInfo(infos);
146         }
147         // Set default mappings
148         this.mappingsMap.putAll(context.getBeansWithAnnotation(RestController.class));
149         this.mappingsMap.putAll(context.getBeansWithAnnotation(RequestMapping.class));
150         this.mappingsMap.putAll(context.getBeansWithAnnotation(Controller.class));
151
152         // default server value
153         if (CollectionUtils.isEmpty(calculatedOpenAPI.getServers()) || !isServersPresent) {
154             this.updateServers(calculatedOpenAPI);
155         }
156         // add security schemes
157         this.calculateSecuritySchemes(calculatedOpenAPI.getComponents());
158         openApiBuilderCustomisers.ifPresent(customisers -> customisers.forEach(customiser -> customiser.customise(this)));
159     }
160
161     public void updateServers(OpenAPI openAPI) {
162         Server server = new Server().url(serverBaseUrl).description(DEFAULT_SERVER_DESCRIPTION);
163         List<Server> servers = new ArrayList();
164         servers.add(server);
165         openAPI.setServers(servers);
166     }
167
168     public boolean isServersPresent() {
169         return isServersPresent;
170     }
171
172     public Operation buildTags(HandlerMethod handlerMethod, Operation operation, OpenAPI openAPI) {
173
174         // class tags
175         Set<Tags> tagsSet = AnnotatedElementUtils
176                 .findAllMergedAnnotations(handlerMethod.getBeanType(), Tags.class);
177         Set<Tag> classTags = tagsSet.stream()
178                 .flatMap(x -> Stream.of(x.value())).collect(Collectors.toSet());
179         classTags.addAll(AnnotatedElementUtils.findAllMergedAnnotations(handlerMethod.getBeanType(), Tag.class));
180
181         // method tags
182         tagsSet = AnnotatedElementUtils
183                 .findAllMergedAnnotations(handlerMethod.getMethod(), Tags.class);
184         Set<Tag> methodTags = tagsSet.stream()
185                 .flatMap(x -> Stream.of(x.value())).collect(Collectors.toSet());
186         methodTags.addAll(AnnotatedElementUtils.findAllMergedAnnotations(handlerMethod.getMethod(), Tag.class));
187
188
189         List<Tag> allTags = new ArrayList<>();
190         Set<String> tagsStr = new HashSet<>();
191
192         if (!CollectionUtils.isEmpty(methodTags)) {
193             tagsStr.addAll(methodTags.stream().map(Tag::name).collect(Collectors.toSet()));
194             allTags.addAll(methodTags);
195         }
196
197         if (!CollectionUtils.isEmpty(classTags)) {
198             tagsStr.addAll(classTags.stream().map(Tag::name).collect(Collectors.toSet()));
199             allTags.addAll(classTags);
200         }
201
202         if (springdocTags.containsKey(handlerMethod)) {
203             io.swagger.v3.oas.models.tags.Tag tag = springdocTags.get(handlerMethod);
204             tagsStr.add(tag.getName());
205             if (openAPI.getTags() == null || !openAPI.getTags().contains(tag)) {
206                 openAPI.addTagsItem(tag);
207             }
208         }
209
210         Optional<Set<io.swagger.v3.oas.models.tags.Tag>> tags = AnnotationsUtils
211                 .getTags(allTags.toArray(new Tag[0]), true);
212
213         if (tags.isPresent()) {
214             Set<io.swagger.v3.oas.models.tags.Tag> tagSet = tags.get();
215             // Existing tags
216             List<io.swagger.v3.oas.models.tags.Tag> openApiTags = openAPI.getTags();
217             if (!CollectionUtils.isEmpty(openApiTags))
218                 tagSet.addAll(openApiTags);
219             openAPI.setTags(new ArrayList<>(tagSet));
220         }
221
222         // Handle SecurityRequirement at operation level
223         io.swagger.v3.oas.annotations.security.SecurityRequirement[] securityRequirements = securityParser
224                 .getSecurityRequirements(handlerMethod);
225         if (securityRequirements != null) {
226             if (securityRequirements.length == 0)
227                 operation.setSecurity(Collections.emptyList());
228             else
229                 securityParser.buildSecurityRequirement(securityRequirements, operation);
230         }
231         if (!CollectionUtils.isEmpty(tagsStr))
232             operation.setTags(new ArrayList<>(tagsStr));
233
234
235         if (isAutoTagClasses(operation))
236             operation.addTagsItem(splitCamelCase(handlerMethod.getBeanType().getSimpleName()));
237
238         return operation;
239     }
240
241     public Schema resolveProperties(Schema schema, PropertyResolverUtils propertyResolverUtils) {
242         resolveProperty(schema::getName, schema::name, propertyResolverUtils);
243         resolveProperty(schema::getTitle, schema::title, propertyResolverUtils);
244         resolveProperty(schema::getDescription, schema::description, propertyResolverUtils);
245
246         Map<String, Schema> properties = schema.getProperties();
247         if (!CollectionUtils.isEmpty(properties)) {
248             Map<String, Schema> resolvedSchemas = properties.entrySet().stream().map(es -> {
249                 es.setValue(resolveProperties(es.getValue(), propertyResolverUtils));
250                 return es;
251             }).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
252             schema.setProperties(resolvedSchemas);
253         }
254
255         return schema;
256     }
257
258     public void setServerBaseUrl(String serverBaseUrl) {
259         this.serverBaseUrl = serverBaseUrl;
260     }
261
262     private Optional<OpenAPIDefinition> getOpenAPIDefinition() {
263         // Look for OpenAPIDefinition in a spring managed bean
264         Map<String, Object> openAPIDefinitionMap = context.getBeansWithAnnotation(OpenAPIDefinition.class);
265         OpenAPIDefinition apiDef = null;
266         if (openAPIDefinitionMap.size() > 1)
267             LOGGER.warn(
268                     "found more than one OpenAPIDefinition class. springdoc-openapi will be using the first one found.");
269         if (openAPIDefinitionMap.size() > 0) {
270             Map.Entry<String, Object> entry = openAPIDefinitionMap.entrySet().iterator().next();
271             Class<?> objClz = entry.getValue().getClass();
272             apiDef = AnnotatedElementUtils.findMergedAnnotation(objClz, OpenAPIDefinition.class);
273         }
274
275         // Look for OpenAPIDefinition in the spring classpath
276         else {
277             ClassPathScanningCandidateComponentProvider scanner = new ClassPathScanningCandidateComponentProvider(
278                     false);
279             scanner.addIncludeFilter(new AnnotationTypeFilter(OpenAPIDefinition.class));
280             if (AutoConfigurationPackages.has(context)) {
281                 List<String> packagesToScan = AutoConfigurationPackages.get(context);
282                 apiDef = getApiDefClass(scanner, packagesToScan);
283             }
284
285         }
286         return Optional.ofNullable(apiDef);
287     }
288
289     private void buildOpenAPIWithOpenAPIDefinition(OpenAPI openAPI, OpenAPIDefinition apiDef) {
290         // info
291         AnnotationsUtils.getInfo(apiDef.info()).map(this::resolveProperties).ifPresent(openAPI::setInfo);
292         // OpenApiDefinition security requirements
293         securityParser.getSecurityRequirements(apiDef.security()).ifPresent(openAPI::setSecurity);
294         // OpenApiDefinition external docs
295         AnnotationsUtils.getExternalDocumentation(apiDef.externalDocs()).ifPresent(openAPI::setExternalDocs);
296         // OpenApiDefinition tags
297         AnnotationsUtils.getTags(apiDef.tags(), false).ifPresent(tags -> openAPI.setTags(new ArrayList<>(tags)));
298         // OpenApiDefinition servers
299         Optional<List<Server>> optionalServers = AnnotationsUtils.getServers(apiDef.servers());
300         if (optionalServers.isPresent()) {
301             openAPI.setServers(optionalServers.get());
302             this.isServersPresent = true;
303         }
304         // OpenApiDefinition extensions
305         if (apiDef.extensions().length > 0) {
306             openAPI.setExtensions(AnnotationsUtils.getExtensions(apiDef.extensions()));
307         }
308     }
309
310     private Info resolveProperties(Info info) {
311         PropertyResolverUtils propertyResolverUtils = context.getBean(PropertyResolverUtils.class);
312         resolveProperty(info::getTitle, info::title, propertyResolverUtils);
313         resolveProperty(info::getDescription, info::description, propertyResolverUtils);
314         resolveProperty(info::getVersion, info::version, propertyResolverUtils);
315         resolveProperty(info::getTermsOfService, info::termsOfService, propertyResolverUtils);
316
317         License license = info.getLicense();
318         if (license != null) {
319             resolveProperty(license::getName, license::name, propertyResolverUtils);
320             resolveProperty(license::getUrl, license::url, propertyResolverUtils);
321         }
322
323         Contact contact = info.getContact();
324         if (contact != null) {
325             resolveProperty(contact::getName, contact::name, propertyResolverUtils);
326             resolveProperty(contact::getEmail, contact::email, propertyResolverUtils);
327             resolveProperty(contact::getUrl, contact::url, propertyResolverUtils);
328         }
329         return info;
330     }
331
332     private void resolveProperty(Supplier<String> getProperty, Consumer<String> setProperty,
333             PropertyResolverUtils propertyResolverUtils) {
334         String value = getProperty.get();
335         if (StringUtils.isNotBlank(value)) {
336             setProperty.accept(propertyResolverUtils.resolve(value));
337         }
338     }
339
340     private void calculateSecuritySchemes(Components components) {
341         // Look for SecurityScheme in a spring managed bean
342         Map<String, Object> securitySchemeBeans = context
343                 .getBeansWithAnnotation(io.swagger.v3.oas.annotations.security.SecurityScheme.class);
344         if (securitySchemeBeans.size() > 0) {
345             for (Map.Entry<String, Object> entry : securitySchemeBeans.entrySet()) {
346                 Class<?> objClz = entry.getValue().getClass();
347                 Set<io.swagger.v3.oas.annotations.security.SecurityScheme> apiSecurityScheme = AnnotatedElementUtils.findMergedRepeatableAnnotations(objClz, io.swagger.v3.oas.annotations.security.SecurityScheme.class);
348                 this.addSecurityScheme(apiSecurityScheme, components);
349             }
350         }
351
352         // Look for SecurityScheme in the spring classpath
353         else {
354             ClassPathScanningCandidateComponentProvider scanner = new ClassPathScanningCandidateComponentProvider(
355                     false);
356             scanner.addIncludeFilter(
357                     new AnnotationTypeFilter(io.swagger.v3.oas.annotations.security.SecurityScheme.class));
358             if (AutoConfigurationPackages.has(context)) {
359                 List<String> packagesToScan = AutoConfigurationPackages.get(context);
360                 Set<io.swagger.v3.oas.annotations.security.SecurityScheme> apiSecurityScheme = getSecuritySchemesClasses(
361                         scanner, packagesToScan);
362                 this.addSecurityScheme(apiSecurityScheme, components);
363             }
364
365         }
366     }
367
368     private void addSecurityScheme(Set<io.swagger.v3.oas.annotations.security.SecurityScheme> apiSecurityScheme,
369             Components components) {
370         for (io.swagger.v3.oas.annotations.security.SecurityScheme securitySchemeAnnotation : apiSecurityScheme) {
371             Optional<SecuritySchemePair> securityScheme = securityParser.getSecurityScheme(securitySchemeAnnotation);
372             if (securityScheme.isPresent()) {
373                 Map<String, SecurityScheme> securitySchemeMap = new HashMap<>();
374                 if (StringUtils.isNotBlank(securityScheme.get().getKey())) {
375                     securitySchemeMap.put(securityScheme.get().getKey(), securityScheme.get().getSecurityScheme());
376                     if (!CollectionUtils.isEmpty(components.getSecuritySchemes())) {
377                         components.getSecuritySchemes().putAll(securitySchemeMap);
378                     }
379                     else {
380                         components.setSecuritySchemes(securitySchemeMap);
381                     }
382                 }
383             }
384         }
385     }
386
387     private OpenAPIDefinition getApiDefClass(ClassPathScanningCandidateComponentProvider scanner,
388             List<String> packagesToScan) {
389         for (String pack : packagesToScan) {
390             for (BeanDefinition bd : scanner.findCandidateComponents(pack)) {
391                 // first one found is ok
392                 try {
393                     return AnnotationUtils.findAnnotation(Class.forName(bd.getBeanClassName()),
394                             OpenAPIDefinition.class);
395                 }
396                 catch (ClassNotFoundException e) {
397                     LOGGER.error("Class Not Found in classpath : {}", e.getMessage());
398                 }
399             }
400         }
401         return null;
402     }
403
404     private boolean isAutoTagClasses(Operation operation) {
405         return CollectionUtils.isEmpty(operation.getTags()) && springDocConfigProperties.isAutoTagClasses();
406     }
407
408     private Set<io.swagger.v3.oas.annotations.security.SecurityScheme> getSecuritySchemesClasses(
409             ClassPathScanningCandidateComponentProvider scanner, List<String> packagesToScan) {
410         Set<io.swagger.v3.oas.annotations.security.SecurityScheme> apiSecurityScheme = new HashSet<>();
411         for (String pack : packagesToScan) {
412             for (BeanDefinition bd : scanner.findCandidateComponents(pack)) {
413                 try {
414                     apiSecurityScheme.add(AnnotationUtils.findAnnotation(Class.forName(bd.getBeanClassName()),
415                             io.swagger.v3.oas.annotations.security.SecurityScheme.class));
416                 }
417                 catch (ClassNotFoundException e) {
418                     LOGGER.error("Class Not Found in classpath : {}", e.getMessage());
419                 }
420             }
421         }
422         return apiSecurityScheme;
423     }
424
425     public void addTag(Set<HandlerMethod> handlerMethods, io.swagger.v3.oas.models.tags.Tag tag) {
426         handlerMethods.forEach(handlerMethod -> springdocTags.put(handlerMethod, tag));
427     }
428
429     public Map<String, Object> getMappingsMap() {
430         return this.mappingsMap;
431     }
432
433     public void addMappings(Map<String, Object> mappings) {
434         this.mappingsMap.putAll(mappings);
435     }
436
437     public Map<String, Object> getControllerAdviceMap() {
438         Map<String, Object> controllerAdviceMap = context.getBeansWithAnnotation(ControllerAdvice.class);
439         return Stream.of(controllerAdviceMap).flatMap(mapEl -> mapEl.entrySet().stream()).filter(
440                 controller -> (AnnotationUtils.findAnnotation(controller.getValue().getClass(), Hidden.class) == null))
441                 .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (a1, a2) -> a1));
442     }
443
444     public OpenAPI calculateCachedOpenAPI() {
445         if (!this.isServersPresent())
446             this.updateServers(cachedOpenAPI);
447         return cachedOpenAPI;
448     }
449
450     public void setCachedOpenAPI(OpenAPI cachedOpenAPI) {
451         this.cachedOpenAPI = cachedOpenAPI;
452     }
453
454     public OpenAPI getCachedOpenAPI() {
455         return cachedOpenAPI;
456     }
457
458     public OpenAPI getCalculatedOpenAPI() {
459         return calculatedOpenAPI;
460     }
461
462     public void resetCalculatedOpenAPI() {
463         this.calculatedOpenAPI = null;
464     }
465 }
466