1 /*
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */

15
16 package software.amazon.awssdk.protocols.xml.internal.unmarshall;
17
18 import static java.util.Collections.singletonList;
19
20 import java.time.Instant;
21 import java.util.Collections;
22 import java.util.HashMap;
23 import java.util.List;
24 import java.util.Map;
25 import software.amazon.awssdk.annotations.SdkInternalApi;
26 import software.amazon.awssdk.core.SdkField;
27 import software.amazon.awssdk.core.SdkPojo;
28 import software.amazon.awssdk.core.protocol.MarshallLocation;
29 import software.amazon.awssdk.core.protocol.MarshallingType;
30 import software.amazon.awssdk.core.traits.PayloadTrait;
31 import software.amazon.awssdk.core.traits.TimestampFormatTrait;
32 import software.amazon.awssdk.core.traits.XmlAttributeTrait;
33 import software.amazon.awssdk.http.SdkHttpFullResponse;
34 import software.amazon.awssdk.protocols.core.StringToInstant;
35 import software.amazon.awssdk.protocols.core.StringToValueConverter;
36 import software.amazon.awssdk.protocols.query.unmarshall.XmlElement;
37 import software.amazon.awssdk.protocols.query.unmarshall.XmlErrorUnmarshaller;
38 import software.amazon.awssdk.utils.CollectionUtils;
39 import software.amazon.awssdk.utils.builder.Buildable;
40
41 @SdkInternalApi
42 public final class XmlProtocolUnmarshaller implements XmlErrorUnmarshaller {
43
44     public static final StringToValueConverter.StringToValue<Instant> INSTANT_STRING_TO_VALUE
45         = StringToInstant.create(getDefaultTimestampFormats());
46
47     private static final XmlUnmarshallerRegistry REGISTRY = createUnmarshallerRegistry();
48
49     private XmlProtocolUnmarshaller() {
50     }
51
52     public static XmlProtocolUnmarshaller create() {
53         return new XmlProtocolUnmarshaller();
54     }
55
56     public <TypeT extends SdkPojo> TypeT unmarshall(SdkPojo sdkPojo,
57                                                     SdkHttpFullResponse response) {
58
59         XmlElement document = XmlResponseParserUtils.parse(sdkPojo, response);
60         return unmarshall(sdkPojo, document, response);
61     }
62
63     /**
64      * This method is also used to unmarshall exceptions. We use this since we've already parsed the XML
65      * and the result root is in a different location depending on the protocol/service.
66      */

67     @Override
68     public <TypeT extends SdkPojo> TypeT unmarshall(SdkPojo sdkPojo,
69                                                     XmlElement resultRoot,
70                                                     SdkHttpFullResponse response) {
71         XmlUnmarshallerContext unmarshallerContext = XmlUnmarshallerContext.builder()
72                                                                            .response(response)
73                                                                            .registry(REGISTRY)
74                                                                            .protocolUnmarshaller(this)
75                                                                            .build();
76         return (TypeT) unmarshall(unmarshallerContext, sdkPojo, resultRoot);
77     }
78
79     SdkPojo unmarshall(XmlUnmarshallerContext context, SdkPojo sdkPojo, XmlElement root) {
80         for (SdkField<?> field : sdkPojo.sdkFields()) {
81             XmlUnmarshaller<Object> unmarshaller = REGISTRY.getUnmarshaller(field.location(), field.marshallingType());
82
83             if (root != null && field.location() == MarshallLocation.PAYLOAD) {
84                 if (isAttribute(field)) {
85                     root.getOptionalAttributeByName(field.unmarshallLocationName())
86                         .ifPresent(e -> field.set(sdkPojo, e));
87                 } else {
88                     List<XmlElement> element = isExplicitPayloadMember(field) ?
89                                                singletonList(root) :
90                                                root.getElementsByName(field.unmarshallLocationName());
91
92                     if (!CollectionUtils.isNullOrEmpty(element)) {
93                         Object unmarshalled = unmarshaller.unmarshall(context, element, (SdkField<Object>) field);
94                         field.set(sdkPojo, unmarshalled);
95                     }
96                 }
97             } else {
98                 Object unmarshalled = unmarshaller.unmarshall(context, null, (SdkField<Object>) field);
99                 field.set(sdkPojo, unmarshalled);
100             }
101         }
102
103         if (!(sdkPojo instanceof Buildable)) {
104             throw new RuntimeException("The sdkPojo passed to the unmarshaller is not buildable (must implement "
105                                        + "Buildable)");
106         }
107         return (SdkPojo) ((Buildable) sdkPojo).build();
108     }
109
110     private boolean isAttribute(SdkField<?> field) {
111         return field.containsTrait(XmlAttributeTrait.class);
112     }
113
114     private boolean isExplicitPayloadMember(SdkField<?> field) {
115         return field.containsTrait(PayloadTrait.class);
116     }
117
118     private static Map<MarshallLocation, TimestampFormatTrait.Format> getDefaultTimestampFormats() {
119         Map<MarshallLocation, TimestampFormatTrait.Format> formats = new HashMap<>();
120         formats.put(MarshallLocation.HEADER, TimestampFormatTrait.Format.RFC_822);
121         formats.put(MarshallLocation.PAYLOAD, TimestampFormatTrait.Format.ISO_8601);
122         return Collections.unmodifiableMap(formats);
123     }
124
125     private static XmlUnmarshallerRegistry createUnmarshallerRegistry() {
126         return XmlUnmarshallerRegistry
127             .builder()
128             .statusCodeUnmarshaller(MarshallingType.INTEGER, (context, content, field) -> context.response().statusCode())
129             .headerUnmarshaller(MarshallingType.STRING, HeaderUnmarshaller.STRING)
130             .headerUnmarshaller(MarshallingType.INTEGER, HeaderUnmarshaller.INTEGER)
131             .headerUnmarshaller(MarshallingType.LONG, HeaderUnmarshaller.LONG)
132             .headerUnmarshaller(MarshallingType.DOUBLE, HeaderUnmarshaller.DOUBLE)
133             .headerUnmarshaller(MarshallingType.BOOLEAN, HeaderUnmarshaller.BOOLEAN)
134             .headerUnmarshaller(MarshallingType.INSTANT, HeaderUnmarshaller.INSTANT)
135             .headerUnmarshaller(MarshallingType.FLOAT, HeaderUnmarshaller.FLOAT)
136             .headerUnmarshaller(MarshallingType.MAP, HeaderUnmarshaller.MAP)
137
138             .payloadUnmarshaller(MarshallingType.STRING, XmlPayloadUnmarshaller.STRING)
139             .payloadUnmarshaller(MarshallingType.INTEGER, XmlPayloadUnmarshaller.INTEGER)
140             .payloadUnmarshaller(MarshallingType.LONG, XmlPayloadUnmarshaller.LONG)
141             .payloadUnmarshaller(MarshallingType.FLOAT, XmlPayloadUnmarshaller.FLOAT)
142             .payloadUnmarshaller(MarshallingType.DOUBLE, XmlPayloadUnmarshaller.DOUBLE)
143             .payloadUnmarshaller(MarshallingType.BIG_DECIMAL, XmlPayloadUnmarshaller.BIG_DECIMAL)
144             .payloadUnmarshaller(MarshallingType.BOOLEAN, XmlPayloadUnmarshaller.BOOLEAN)
145             .payloadUnmarshaller(MarshallingType.INSTANT, XmlPayloadUnmarshaller.INSTANT)
146             .payloadUnmarshaller(MarshallingType.SDK_BYTES, XmlPayloadUnmarshaller.SDK_BYTES)
147             .payloadUnmarshaller(MarshallingType.SDK_POJO, XmlPayloadUnmarshaller::unmarshallSdkPojo)
148             .payloadUnmarshaller(MarshallingType.LIST, XmlPayloadUnmarshaller::unmarshallList)
149             .payloadUnmarshaller(MarshallingType.MAP, XmlPayloadUnmarshaller::unmarshallMap)
150             .build();
151     }
152 }
153