1 /*
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */

15
16 package software.amazon.awssdk.protocols.xml.internal.marshall;
17
18 import java.math.BigDecimal;
19 import java.time.Instant;
20 import java.util.HashMap;
21 import java.util.LinkedHashMap;
22 import java.util.List;
23 import java.util.Map;
24 import software.amazon.awssdk.annotations.SdkInternalApi;
25 import software.amazon.awssdk.core.SdkBytes;
26 import software.amazon.awssdk.core.SdkField;
27 import software.amazon.awssdk.core.SdkPojo;
28 import software.amazon.awssdk.core.protocol.MarshallLocation;
29 import software.amazon.awssdk.core.traits.ListTrait;
30 import software.amazon.awssdk.core.traits.MapTrait;
31 import software.amazon.awssdk.core.traits.XmlAttributeTrait;
32 import software.amazon.awssdk.core.traits.XmlAttributesTrait;
33 import software.amazon.awssdk.core.util.SdkAutoConstructList;
34 import software.amazon.awssdk.core.util.SdkAutoConstructMap;
35 import software.amazon.awssdk.protocols.core.ValueToStringConverter;
36
37 @SdkInternalApi
38 public class XmlPayloadMarshaller {
39
40     public static final XmlMarshaller<String> STRING = new BasePayloadMarshaller<>(ValueToStringConverter.FROM_STRING);
41
42     public static final XmlMarshaller<Integer> INTEGER = new BasePayloadMarshaller<>(ValueToStringConverter.FROM_INTEGER);
43
44     public static final XmlMarshaller<Long> LONG = new BasePayloadMarshaller<>(ValueToStringConverter.FROM_LONG);
45
46     public static final XmlMarshaller<Float> FLOAT = new BasePayloadMarshaller<>(ValueToStringConverter.FROM_FLOAT);
47
48     public static final XmlMarshaller<Double> DOUBLE = new BasePayloadMarshaller<>(ValueToStringConverter.FROM_DOUBLE);
49
50     public static final XmlMarshaller<BigDecimal> BIG_DECIMAL =
51         new BasePayloadMarshaller<>(ValueToStringConverter.FROM_BIG_DECIMAL);
52
53     public static final XmlMarshaller<Boolean> BOOLEAN = new BasePayloadMarshaller<>(ValueToStringConverter.FROM_BOOLEAN);
54
55     public static final XmlMarshaller<Instant> INSTANT =
56         new BasePayloadMarshaller<>(XmlProtocolMarshaller.INSTANT_VALUE_TO_STRING);
57
58     public static final XmlMarshaller<SdkBytes> SDK_BYTES = new BasePayloadMarshaller<>(ValueToStringConverter.FROM_SDK_BYTES);
59
60     public static final XmlMarshaller<SdkPojo> SDK_POJO = new BasePayloadMarshaller<SdkPojo>(null) {
61         @Override
62         public void marshall(SdkPojo val, XmlMarshallerContext context, String paramName,
63                              SdkField<SdkPojo> sdkField, ValueToStringConverter.ValueToString<SdkPojo> converter) {
64             context.protocolMarshaller().doMarshall(val);
65         }
66     };
67
68     public static final XmlMarshaller<List<?>> LIST = new BasePayloadMarshaller<List<?>>(null) {
69
70         @Override
71         public void marshall(List<?> val, XmlMarshallerContext context, String paramName, SdkField<List<?>> sdkField) {
72             if (!shouldEmit(val, paramName)) {
73                 return;
74             }
75
76             marshall(val, context, paramName, sdkField, null);
77         }
78
79         @Override
80         public void marshall(List<?> list, XmlMarshallerContext context, String paramName,
81                              SdkField<List<?>> sdkField, ValueToStringConverter.ValueToString<List<?>> converter) {
82             ListTrait listTrait = sdkField
83                 .getOptionalTrait(ListTrait.class)
84                 .orElseThrow(() -> new IllegalStateException(paramName + " member is missing ListTrait"));
85
86             if (!listTrait.isFlattened()) {
87                 context.xmlGenerator().startElement(paramName);
88             }
89
90             SdkField memberField = listTrait.memberFieldInfo();
91             String memberLocationName = listMemberLocationName(listTrait, paramName);
92
93             for (Object listMember : list) {
94                 context.marshall(MarshallLocation.PAYLOAD, listMember, memberLocationName, memberField);
95             }
96
97             if (!listTrait.isFlattened()) {
98                 context.xmlGenerator().endElement();
99             }
100         }
101
102         private String listMemberLocationName(ListTrait listTrait, String listLocationName) {
103             String locationName = listTrait.memberLocationName();
104
105             if (locationName == null) {
106                 locationName = listTrait.isFlattened() ? listLocationName : "member";
107             }
108
109             return locationName;
110         }
111
112         @Override
113         protected boolean shouldEmit(List list, String paramName) {
114             return super.shouldEmit(list, paramName) &&
115                    (!list.isEmpty() || !(list instanceof SdkAutoConstructList));
116         }
117     };
118
119     // We ignore flattened trait for maps. For rest-xml, none of the services have flattened maps
120     public static final XmlMarshaller<Map<String, ?>> MAP = new BasePayloadMarshaller<Map<String, ?>>(null) {
121
122         @Override
123         public void marshall(Map<String, ?> map, XmlMarshallerContext context, String paramName,
124                              SdkField<Map<String, ?>> sdkField, ValueToStringConverter.ValueToString<Map<String, ?>> converter) {
125
126             MapTrait mapTrait = sdkField.getOptionalTrait(MapTrait.class)
127                                         .orElseThrow(() -> new IllegalStateException(paramName + " member is missing MapTrait"));
128
129             for (Map.Entry<String, ?> entry : map.entrySet()) {
130                 context.xmlGenerator().startElement("entry");
131                 context.marshall(MarshallLocation.PAYLOAD, entry.getKey(), mapTrait.keyLocationName(), null);
132                 context.marshall(MarshallLocation.PAYLOAD, entry.getValue(), mapTrait.valueLocationName(),
133                                  mapTrait.valueFieldInfo());
134                 context.xmlGenerator().endElement();
135             }
136         }
137
138         @Override
139         protected boolean shouldEmit(Map map, String paramName) {
140             return super.shouldEmit(map, paramName) &&
141                    (!map.isEmpty() || !(map instanceof SdkAutoConstructMap));
142         }
143     };
144
145     private XmlPayloadMarshaller() {
146     }
147
148     /**
149      * Base payload marshaller for xml protocol. Marshalling happens only when both element name and value are present.
150      *
151      * Marshalling for simple types is done in the base class.
152      * Complex types should override the
153      * {@link #marshall(Object, XmlMarshallerContext, String, SdkField, ValueToStringConverter.ValueToString)} method.
154      *
155      * @param <T> Type to marshall
156      */

157     private static class BasePayloadMarshaller<T> implements XmlMarshaller<T> {
158
159         private final ValueToStringConverter.ValueToString<T> converter;
160
161         private BasePayloadMarshaller(ValueToStringConverter.ValueToString<T> converter) {
162             this.converter = converter;
163         }
164
165         @Override
166         public void marshall(T val, XmlMarshallerContext context, String paramName, SdkField<T> sdkField) {
167             if (!shouldEmit(val, paramName)) {
168                 return;
169             }
170
171             // Should ignore marshalling for xml attribute
172             if (isXmlAttribute(sdkField)) {
173                 return;
174             }
175
176             if (sdkField != null && sdkField.getOptionalTrait(XmlAttributesTrait.class).isPresent()) {
177                 XmlAttributesTrait attributeTrait = sdkField.getTrait(XmlAttributesTrait.class);
178                 Map<String, String> attributes = attributeTrait.attributes()
179                                                                .entrySet()
180                                                                .stream()
181                                                                .collect(LinkedHashMap::new, (m, e) -> m.put(e.getKey(),
182                                                                                                             e.getValue()
183                                                                                                              .attributeGetter()
184                                                                                                              .apply(val)),
185                                                                         HashMap::putAll);
186                 context.xmlGenerator().startElement(paramName, attributes);
187             } else {
188                 context.xmlGenerator().startElement(paramName);
189             }
190
191             marshall(val, context, paramName, sdkField, converter);
192             context.xmlGenerator().endElement();
193         }
194
195         void marshall(T val, XmlMarshallerContext context, String paramName, SdkField<T> sdkField,
196                       ValueToStringConverter.ValueToString<T> converter) {
197             context.xmlGenerator().xmlWriter().value(converter.convert(val, sdkField));
198         }
199
200         protected boolean shouldEmit(T val, String paramName) {
201             return val != null && paramName != null;
202         }
203
204         private boolean isXmlAttribute(SdkField<T> sdkField) {
205             return sdkField != null && sdkField.getOptionalTrait(XmlAttributeTrait.class).isPresent();
206         }
207     }
208
209 }
210