1 /*
2  * JBoss, Home of Professional Open Source.
3  * Copyright 2017 Red Hat, Inc., and individual contributors
4  * as indicated by the @author tags.
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */

18
19 package org.wildfly.common.net;
20
21 import java.net.Inet4Address;
22 import java.net.InetAddress;
23 import java.util.Arrays;
24 import java.util.Iterator;
25 import java.util.NoSuchElementException;
26 import java.util.Objects;
27 import java.util.Spliterator;
28 import java.util.Spliterators;
29 import java.util.concurrent.atomic.AtomicReference;
30
31 import org.wildfly.common.Assert;
32
33 /**
34  * A table for mapping IP addresses to objects using {@link CidrAddress} instances for matching.
35  *
36  * @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a>
37  */

38 public final class CidrAddressTable<T> implements Iterable<CidrAddressTable.Mapping<T>> {
39
40     @SuppressWarnings("rawtypes")
41     private static final Mapping[] NO_MAPPINGS = new Mapping[0];
42
43     private final AtomicReference<Mapping<T>[]> mappingsRef;
44
45     public CidrAddressTable() {
46         mappingsRef = new AtomicReference<>(empty());
47     }
48
49     private CidrAddressTable(Mapping<T>[] mappings) {
50         mappingsRef = new AtomicReference<>(mappings);
51     }
52
53     public T getOrDefault(InetAddress address, T defVal) {
54         Assert.checkNotNullParam("address", address);
55         final Mapping<T> mapping = doGet(mappingsRef.get(), address.getAddress(), address instanceof Inet4Address ? 32 : 128, Inet.getScopeId(address));
56         return mapping == null ? defVal : mapping.value;
57     }
58
59     public T get(InetAddress address) {
60         return getOrDefault(address, null);
61     }
62
63     public T put(CidrAddress block, T value) {
64         Assert.checkNotNullParam("block", block);
65         Assert.checkNotNullParam("value", value);
66         return doPut(block, null, value, truetrue);
67     }
68
69     public T putIfAbsent(CidrAddress block, T value) {
70         Assert.checkNotNullParam("block", block);
71         Assert.checkNotNullParam("value", value);
72         return doPut(block, null, value, truefalse);
73     }
74
75     public T replaceExact(CidrAddress block, T value) {
76         Assert.checkNotNullParam("block", block);
77         Assert.checkNotNullParam("value", value);
78         return doPut(block, null, value, falsetrue);
79     }
80
81     public boolean replaceExact(CidrAddress block, T expect, T update) {
82         Assert.checkNotNullParam("block", block);
83         Assert.checkNotNullParam("expect", expect);
84         Assert.checkNotNullParam("update", update);
85         return doPut(block, expect, update, falsetrue) == expect;
86     }
87
88     public T removeExact(CidrAddress block) {
89         Assert.checkNotNullParam("block", block);
90         return doPut(block, nullnullfalsetrue);
91     }
92
93     public boolean removeExact(CidrAddress block, T expect) {
94         Assert.checkNotNullParam("block", block);
95         return doPut(block, expect, nullfalsetrue) == expect;
96     }
97
98     private T doPut(final CidrAddress block, final T expect, final T update, final boolean putIfAbsent, final boolean putIfPresent) {
99         assert putIfAbsent || putIfPresent;
100         final AtomicReference<Mapping<T>[]> mappingsRef = this.mappingsRef;
101         final byte[] bytes = block.getNetworkAddress().getAddress();
102         Mapping<T>[] oldVal, newVal;
103         int idx;
104         T existing;
105         boolean matchesExpected;
106         do {
107             oldVal = mappingsRef.get();
108             idx = doFind(oldVal, bytes, block.getNetmaskBits(), block.getScopeId());
109             if (idx < 0) {
110                 if (! putIfAbsent) {
111                     return null;
112                 }
113                 existing = null;
114             } else {
115                 existing = oldVal[idx].value;
116             }
117             if (expect != null) {
118                 matchesExpected = Objects.equals(expect, existing);
119                 if (! matchesExpected) {
120                     return existing;
121                 }
122             } else {
123                 matchesExpected = false;
124             }
125             if (idx >= 0 && ! putIfPresent) {
126                 return existing;
127             }
128             // now construct the new mapping
129             final int oldLen = oldVal.length;
130             if (update == null) {
131                 assert idx >= 0;
132                 // removal
133                 if (oldLen == 1) {
134                     newVal = empty();
135                 } else {
136                     final Mapping<T> removing = oldVal[idx];
137                     newVal = Arrays.copyOf(oldVal, oldLen - 1);
138                     System.arraycopy(oldVal, idx + 1, newVal, idx, oldLen - idx - 1);
139                     // now reparent any children that I was a parent of with my old parent
140                     for (int i = 0; i < oldLen - 1; i ++) {
141                         if (newVal[i].parent == removing) {
142                             newVal[i] = newVal[i].withNewParent(removing.parent);
143                         }
144                     }
145                 }
146             } else if (idx >= 0) {
147                 // replace
148                 newVal = oldVal.clone();
149                 final Mapping<T> oldMapping = oldVal[idx];
150                 final Mapping<T> newMapping = new Mapping<>(block, update, oldVal[idx].parent);
151                 newVal[idx] = newMapping;
152                 // now reparent any child to me
153                 for (int i = 0; i < oldLen; i ++) {
154                     if (i != idx && newVal[i].parent == oldMapping) {
155                         newVal[i] = newVal[i].withNewParent(newMapping);
156                     }
157                 }
158             } else {
159                 // add
160                 newVal = Arrays.copyOf(oldVal, oldLen + 1);
161                 final Mapping<T> newMappingParent = doGet(oldVal, bytes, block.getNetmaskBits(), block.getScopeId());
162                 final Mapping<T> newMapping = new Mapping<>(block, update, newMappingParent);
163                 newVal[-idx - 1] = newMapping;
164                 System.arraycopy(oldVal, -idx - 1, newVal, -idx, oldLen + idx + 1);
165                 // now reparent any children who have a parent of my (possibly null) parent but match me
166                 for (int i = 0; i <= oldLen; i++) {
167                     if (newVal[i] != newMapping && newVal[i].parent == newMappingParent && block.matches(newVal[i].range)) {
168                         newVal[i] = newVal[i].withNewParent(newMapping);
169                     }
170                 }
171             }
172         } while (! mappingsRef.compareAndSet(oldVal, newVal));
173         return matchesExpected ? expect : existing;
174     }
175
176     @SuppressWarnings("unchecked")
177     private static <T> Mapping<T>[] empty() {
178         return NO_MAPPINGS;
179     }
180
181     public void clear() {
182         mappingsRef.set(empty());
183     }
184
185     public int size() {
186         return mappingsRef.get().length;
187     }
188
189     public boolean isEmpty() {
190         return size() == 0;
191     }
192
193     public CidrAddressTable<T> clone() {
194         return new CidrAddressTable<>(mappingsRef.get());
195     }
196
197     public Iterator<Mapping<T>> iterator() {
198         final Mapping<T>[] mappings = mappingsRef.get();
199         return new Iterator<Mapping<T>>() {
200             int idx;
201
202             public boolean hasNext() {
203                 return idx < mappings.length;
204             }
205
206             public Mapping<T> next() {
207                 if (! hasNext()) throw new NoSuchElementException();
208                 return mappings[idx++];
209             }
210         };
211     }
212
213     public Spliterator<Mapping<T>> spliterator() {
214         final Mapping<T>[] mappings = mappingsRef.get();
215         return Spliterators.spliterator(mappings, Spliterator.IMMUTABLE | Spliterator.ORDERED);
216     }
217
218     public String toString() {
219         StringBuilder b = new StringBuilder();
220         final Mapping<T>[] mappings = mappingsRef.get();
221         b.append(mappings.length).append(" mappings");
222         for (final Mapping<T> mapping : mappings) {
223             b.append(System.lineSeparator()).append('\t').append(mapping.range);
224             if (mapping.parent != null) {
225                 b.append(" (parent ").append(mapping.parent.range).append(')');
226             }
227             b.append(" -> ").append(mapping.value);
228         }
229         return b.toString();
230     }
231
232     private int doFind(Mapping<T>[] table, byte[] bytes, int maskBits, final int scopeId) {
233         int low = 0;
234         int high = table.length - 1;
235
236         while (low <= high) {
237             // bisect the range
238             int mid = low + high >>> 1;
239
240             // compare the mapping at this location
241             Mapping<T> mapping = table[mid];
242             int cmp = mapping.range.compareAddressBytesTo(bytes, maskBits, scopeId);
243
244             if (cmp < 0) {
245                 // move to the latter half
246                 low = mid + 1;
247             } else if (cmp > 0) {
248                 // move to the former half
249                 high = mid - 1;
250             } else {
251                 // exact match is the best case
252                 return mid;
253             }
254         }
255         // return the point we would insert at (plus one, negated)
256         return -(low + 1);
257     }
258
259     private Mapping<T> doGet(Mapping<T>[] table, byte[] bytes, final int netmaskBits, final int scopeId) {
260         int idx = doFind(table, bytes, netmaskBits, scopeId);
261         if (idx >= 0) {
262             // exact match
263             assert table[idx].range.matches(bytes, scopeId);
264             return table[idx];
265         }
266         // check immediate predecessor if there is one
267         int pre = -idx - 2;
268         if (pre >= 0) {
269             if (table[pre].range.matches(bytes, scopeId)) {
270                 return table[pre];
271             }
272             // try parent
273             Mapping<T> parent = table[pre].parent;
274             while (parent != null) {
275                 if (parent.range.matches(bytes, scopeId)) {
276                     return parent;
277                 }
278                 parent = parent.parent;
279             }
280         }
281         return null;
282     }
283
284     /**
285      * A single mapping in the table.
286      *
287      * @param <T> the value type
288      */

289     public static final class Mapping<T> {
290         final CidrAddress range;
291         final T value;
292         final Mapping<T> parent;
293
294         Mapping(final CidrAddress range, final T value, final Mapping<T> parent) {
295             this.range = range;
296             this.value = value;
297             this.parent = parent;
298         }
299
300         Mapping<T> withNewParent(Mapping<T> newParent) {
301             return new Mapping<T>(range, value, newParent);
302         }
303
304         /**
305          * Get the address range of this entry.
306          *
307          * @return the address range of this entry (not {@code null})
308          */

309         public CidrAddress getRange() {
310             return range;
311         }
312
313         /**
314          * Get the stored value of this entry.
315          *
316          * @return the stored value of this entry
317          */

318         public T getValue() {
319             return value;
320         }
321
322         /**
323          * Get the parent of this entry, if any.
324          *
325          * @return the parent of this entry, or {@code nullif there is no parent
326          */

327         public Mapping<T> getParent() {
328             return parent;
329         }
330     }
331 }
332