1 /*
2  * Copyright 2013-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License").
5  * You may not use this file except in compliance with the License.
6  * A copy of the License is located at
7  *
8  *  http://aws.amazon.com/apache2.0
9  *
10  * or in the "license" file accompanying this file. This file is distributed
11  * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12  * express or implied. See the License for the specific language governing
13  * permissions and limitations under the License.
14  */

15 package com.amazonaws.util;
16 import static com.amazonaws.util.CodecUtils.sanityCheckLastPos;
17
18 /**
19  * A Base 64 codec implementation.
20  * 
21  * @author Hanson Char
22  */

23 class Base64Codec implements Codec {
24     private static final int OFFSET_OF_a = 'a' - 26;
25     private static final int OFFSET_OF_0 = '0' - 52;
26     private static final int OFFSET_OF_PLUS = '+' - 62;
27     private static final int OFFSET_OF_SLASH = '/' - 63;
28     
29     private static final int MASK_2BITS = (1 << 2) - 1;
30     private static final int MASK_4BITS = (1 << 4) - 1;
31     private static final int MASK_6BITS = (1 << 6) - 1;
32     // Alphabet as defined at http://www.ietf.org/rfc/rfc4648.txt
33     private static final byte PAD = '=';
34     
35     private static class LazyHolder {
36         private static final byte[] DECODED = decodeTable();
37         
38         private static byte[] decodeTable() {
39             final byte[] dest = new byte['z'+1];
40             
41             for (int i=0; i <= 'z'; i++) 
42             {
43                 if (i >= 'A' && i <= 'Z')
44                     dest[i] = (byte)(i - 'A');
45                 else if (i >= '0' && i <= '9')
46                     dest[i] = (byte)(i - OFFSET_OF_0);
47                 else if (i == '+')
48                     dest[i] = (byte)(i - OFFSET_OF_PLUS);
49                 else if (i == '/')
50                     dest[i] = (byte)(i - OFFSET_OF_SLASH);
51                 else if (i >= 'a' && i <= 'z')
52                     dest[i] = (byte)(i - OFFSET_OF_a);
53                 else 
54                     dest[i] = -1;
55             }
56             return dest;
57         }
58     }
59
60     private final byte[] alphabets;
61
62     Base64Codec() {
63         this.alphabets = CodecUtils.toBytesDirect("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/");
64     }
65     
66     protected Base64Codec(byte[] alphabets) {
67         this.alphabets = alphabets;
68     }
69
70     @Override
71     public byte[] encode(byte[] src) {
72         final int num3bytes = src.length / 3;
73         final int remainder = src.length % 3;
74         
75         if (remainder == 0)
76         {
77             byte[] dest = new byte[num3bytes * 4];
78     
79             for (int s=0,d=0; s < src.length; s+=3, d+=4)
80                 encode3bytes(src, s, dest, d);
81             return dest;
82         }
83         
84         byte[] dest = new byte[(num3bytes+1) * 4];
85         int s=0, d=0;
86         
87         for (; s < src.length-remainder; s+=3, d+=4)
88             encode3bytes(src, s, dest, d);
89         
90         switch(remainder) {
91             case 1:
92                 encode1byte(src, s, dest, d);
93                 break;
94             case 2:
95                 encode2bytes(src, s, dest, d);
96                 break;
97             default:
98                 throw new IllegalStateException();
99         }
100         return dest;
101     }
102     
103     void encode3bytes(byte[] src, int s, byte[] dest, int d) {
104         // operator precedence in descending order: >>> or <<, &, |
105         byte p;
106         dest[d++] = (byte)alphabets[(p=src[s++]) >>> 2 & MASK_6BITS];                         // 6 
107         dest[d++] = (byte)alphabets[(p & MASK_2BITS) << 4 | (p=src[s++]) >>> 4 & MASK_4BITS]; // 2 4
108         dest[d++] = (byte)alphabets[(p & MASK_4BITS) << 2 | (p=src[s]) >>> 6 & MASK_2BITS];   //   4 2
109         dest[d] = (byte)alphabets[p & MASK_6BITS];                                            //     6
110         return;
111     }
112     
113     void encode2bytes(byte[] src, int s, byte[] dest, int d) {
114         // operator precedence in descending order: >>> or <<, &, |
115         byte p;
116         dest[d++] = (byte)alphabets[(p=src[s++]) >>> 2 & MASK_6BITS];                         // 6 
117         dest[d++] = (byte)alphabets[(p & MASK_2BITS) << 4 | (p=src[s]) >>> 4 & MASK_4BITS];   // 2 4
118         dest[d++] = (byte)alphabets[(p & MASK_4BITS) << 2];                                   //   4
119         dest[d] = PAD;
120         return;
121     }
122     
123     void encode1byte(byte[] src, int s, byte[] dest, int d) {
124         // operator precedence in descending order: >>> or <<, &, |
125         byte p;
126         dest[d++] = (byte)alphabets[(p=src[s]) >>> 2 & MASK_6BITS];                           // 6 
127         dest[d++] = (byte)alphabets[(p & MASK_2BITS) << 4];                                   // 2
128         dest[d++] = PAD;
129         dest[d] = PAD;
130         return;
131     }
132     
133     void decode4bytes(byte[] src, int s, byte[] dest, int d) {
134         int p=0;
135         // operator precedence in descending order: >>> or <<, &, |
136         dest[d++] = (byte)
137                     (
138                         pos(src[s++]) << 2
139                         | (p=pos(src[s++])) >>> 4 & MASK_2BITS
140                     )
141                     ;                                               // 6 2
142         dest[d++] = (byte)
143                     (
144                         (p & MASK_4BITS) << 4 
145                         | (p=pos(src[s++])) >>> 2 & MASK_4BITS
146                     )
147                     ;                                               //   4 4
148         dest[d] = (byte)
149                     (
150                         (p & MASK_2BITS) << 6
151                         | pos(src[s])
152                     )
153                     ;                                               //     2 6
154         return;
155     }
156     
157     /**
158      * @param n the number of final quantum in bytes to decode into.  Ranges from 1 to 3, inclusive.
159      */

160     void decode1to3bytes(int n, byte[] src, int s, byte[] dest, int d) {
161         int p=0;
162         // operator precedence in descending order: >>> or <<, &, |
163         dest[d++] = (byte)
164                     (
165                         pos(src[s++]) << 2
166                         | (p=pos(src[s++])) >>> 4 & MASK_2BITS
167                     )
168                     ;                                               // 6 2
169         if (n == 1) {
170             sanityCheckLastPos(p, MASK_4BITS);
171             return;
172         }
173         
174         dest[d++] = (byte)
175                     (
176                         (p & MASK_4BITS) << 4 
177                         | (p=pos(src[s++])) >>> 2 & MASK_4BITS
178                     )
179                     ;                                               //   4 4
180         if (n == 2) {
181             sanityCheckLastPos(p, MASK_2BITS);
182             return;
183         }
184         
185         dest[d] = (byte)
186                     (
187                         (p & MASK_2BITS) << 6
188                         | pos(src[s])
189                     )
190                     ;                                               //     2 6
191         return;
192     }
193
194     @Override
195     public byte[] decode(byte[] src, final int length) 
196     {
197         if (length % 4 != 0)
198             throw new IllegalArgumentException
199             ("Input is expected to be encoded in multiple of 4 bytes but found: " + length);
200
201         int pads=0;
202         int last = length-1;
203         
204         // max possible padding in b64 encoding is 2
205         for (; pads < 2 && last > -1; last--, pads++) {
206             if (src[last] != PAD)
207                 break;
208         }
209         
210         final int fq; // final quantum in unit of bytes
211         
212         switch(pads) {
213             case 0:
214                 fq=3;
215                 break// final quantum of encoding input is an integral multiple of 24 bits
216             case 1:
217                 fq=2;
218                 break// final quantum of encoding input is exactly 16 bits
219             case 2:
220                 fq=1;
221                 break// final quantum of encoding input is exactly 8 bits
222             default:
223                 throw new Error("Impossible");
224         }
225         final byte[] dest = new byte[length / 4 * 3 - (3-fq)]; 
226         int s=0, d=0;
227         
228         // % has a higher precedence than - than <
229         for (; d < dest.length - fq%3; s+=4,d+=3)
230             decode4bytes(src, s, dest, d);
231
232         if (fq < 3)
233             decode1to3bytes(fq, src, s, dest, d);
234         return dest;
235     }
236     
237     protected int pos(byte in) {
238         int pos = LazyHolder.DECODED[in];
239         
240         if (pos > -1)
241             return pos;
242         throw new IllegalArgumentException("Invalid base 64 character: \'" + (char)in + "\'");
243     }
244 }
245