1
18
19 package org.wildfly.common.net;
20
import java.net.Inet4Address;
import java.net.InetAddress;
import java.util.Arrays;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.atomic.AtomicReference;

import org.wildfly.common.Assert;
32
33
38 public final class CidrAddressTable<T> implements Iterable<CidrAddressTable.Mapping<T>> {
39
40 @SuppressWarnings("rawtypes")
41 private static final Mapping[] NO_MAPPINGS = new Mapping[0];
42
43 private final AtomicReference<Mapping<T>[]> mappingsRef;
44
45 public CidrAddressTable() {
46 mappingsRef = new AtomicReference<>(empty());
47 }
48
49 private CidrAddressTable(Mapping<T>[] mappings) {
50 mappingsRef = new AtomicReference<>(mappings);
51 }
52
53 public T getOrDefault(InetAddress address, T defVal) {
54 Assert.checkNotNullParam("address", address);
55 final Mapping<T> mapping = doGet(mappingsRef.get(), address.getAddress(), address instanceof Inet4Address ? 32 : 128, Inet.getScopeId(address));
56 return mapping == null ? defVal : mapping.value;
57 }
58
59 public T get(InetAddress address) {
60 return getOrDefault(address, null);
61 }
62
63 public T put(CidrAddress block, T value) {
64 Assert.checkNotNullParam("block", block);
65 Assert.checkNotNullParam("value", value);
66 return doPut(block, null, value, true, true);
67 }
68
69 public T putIfAbsent(CidrAddress block, T value) {
70 Assert.checkNotNullParam("block", block);
71 Assert.checkNotNullParam("value", value);
72 return doPut(block, null, value, true, false);
73 }
74
75 public T replaceExact(CidrAddress block, T value) {
76 Assert.checkNotNullParam("block", block);
77 Assert.checkNotNullParam("value", value);
78 return doPut(block, null, value, false, true);
79 }
80
81 public boolean replaceExact(CidrAddress block, T expect, T update) {
82 Assert.checkNotNullParam("block", block);
83 Assert.checkNotNullParam("expect", expect);
84 Assert.checkNotNullParam("update", update);
85 return doPut(block, expect, update, false, true) == expect;
86 }
87
88 public T removeExact(CidrAddress block) {
89 Assert.checkNotNullParam("block", block);
90 return doPut(block, null, null, false, true);
91 }
92
93 public boolean removeExact(CidrAddress block, T expect) {
94 Assert.checkNotNullParam("block", block);
95 return doPut(block, expect, null, false, true) == expect;
96 }
97
98 private T doPut(final CidrAddress block, final T expect, final T update, final boolean putIfAbsent, final boolean putIfPresent) {
99 assert putIfAbsent || putIfPresent;
100 final AtomicReference<Mapping<T>[]> mappingsRef = this.mappingsRef;
101 final byte[] bytes = block.getNetworkAddress().getAddress();
102 Mapping<T>[] oldVal, newVal;
103 int idx;
104 T existing;
105 boolean matchesExpected;
106 do {
107 oldVal = mappingsRef.get();
108 idx = doFind(oldVal, bytes, block.getNetmaskBits(), block.getScopeId());
109 if (idx < 0) {
110 if (! putIfAbsent) {
111 return null;
112 }
113 existing = null;
114 } else {
115 existing = oldVal[idx].value;
116 }
117 if (expect != null) {
118 matchesExpected = Objects.equals(expect, existing);
119 if (! matchesExpected) {
120 return existing;
121 }
122 } else {
123 matchesExpected = false;
124 }
125 if (idx >= 0 && ! putIfPresent) {
126 return existing;
127 }
128
129 final int oldLen = oldVal.length;
130 if (update == null) {
131 assert idx >= 0;
132
133 if (oldLen == 1) {
134 newVal = empty();
135 } else {
136 final Mapping<T> removing = oldVal[idx];
137 newVal = Arrays.copyOf(oldVal, oldLen - 1);
138 System.arraycopy(oldVal, idx + 1, newVal, idx, oldLen - idx - 1);
139
140 for (int i = 0; i < oldLen - 1; i ++) {
141 if (newVal[i].parent == removing) {
142 newVal[i] = newVal[i].withNewParent(removing.parent);
143 }
144 }
145 }
146 } else if (idx >= 0) {
147
148 newVal = oldVal.clone();
149 final Mapping<T> oldMapping = oldVal[idx];
150 final Mapping<T> newMapping = new Mapping<>(block, update, oldVal[idx].parent);
151 newVal[idx] = newMapping;
152
153 for (int i = 0; i < oldLen; i ++) {
154 if (i != idx && newVal[i].parent == oldMapping) {
155 newVal[i] = newVal[i].withNewParent(newMapping);
156 }
157 }
158 } else {
159
160 newVal = Arrays.copyOf(oldVal, oldLen + 1);
161 final Mapping<T> newMappingParent = doGet(oldVal, bytes, block.getNetmaskBits(), block.getScopeId());
162 final Mapping<T> newMapping = new Mapping<>(block, update, newMappingParent);
163 newVal[-idx - 1] = newMapping;
164 System.arraycopy(oldVal, -idx - 1, newVal, -idx, oldLen + idx + 1);
165
166 for (int i = 0; i <= oldLen; i++) {
167 if (newVal[i] != newMapping && newVal[i].parent == newMappingParent && block.matches(newVal[i].range)) {
168 newVal[i] = newVal[i].withNewParent(newMapping);
169 }
170 }
171 }
172 } while (! mappingsRef.compareAndSet(oldVal, newVal));
173 return matchesExpected ? expect : existing;
174 }
175
176 @SuppressWarnings("unchecked")
177 private static <T> Mapping<T>[] empty() {
178 return NO_MAPPINGS;
179 }
180
181 public void clear() {
182 mappingsRef.set(empty());
183 }
184
185 public int size() {
186 return mappingsRef.get().length;
187 }
188
189 public boolean isEmpty() {
190 return size() == 0;
191 }
192
193 public CidrAddressTable<T> clone() {
194 return new CidrAddressTable<>(mappingsRef.get());
195 }
196
197 public Iterator<Mapping<T>> iterator() {
198 final Mapping<T>[] mappings = mappingsRef.get();
199 return new Iterator<Mapping<T>>() {
200 int idx;
201
202 public boolean hasNext() {
203 return idx < mappings.length;
204 }
205
206 public Mapping<T> next() {
207 if (! hasNext()) throw new NoSuchElementException();
208 return mappings[idx++];
209 }
210 };
211 }
212
213 public Spliterator<Mapping<T>> spliterator() {
214 final Mapping<T>[] mappings = mappingsRef.get();
215 return Spliterators.spliterator(mappings, Spliterator.IMMUTABLE | Spliterator.ORDERED);
216 }
217
218 public String toString() {
219 StringBuilder b = new StringBuilder();
220 final Mapping<T>[] mappings = mappingsRef.get();
221 b.append(mappings.length).append(" mappings");
222 for (final Mapping<T> mapping : mappings) {
223 b.append(System.lineSeparator()).append('\t').append(mapping.range);
224 if (mapping.parent != null) {
225 b.append(" (parent ").append(mapping.parent.range).append(')');
226 }
227 b.append(" -> ").append(mapping.value);
228 }
229 return b.toString();
230 }
231
232 private int doFind(Mapping<T>[] table, byte[] bytes, int maskBits, final int scopeId) {
233 int low = 0;
234 int high = table.length - 1;
235
236 while (low <= high) {
237
238 int mid = low + high >>> 1;
239
240
241 Mapping<T> mapping = table[mid];
242 int cmp = mapping.range.compareAddressBytesTo(bytes, maskBits, scopeId);
243
244 if (cmp < 0) {
245
246 low = mid + 1;
247 } else if (cmp > 0) {
248
249 high = mid - 1;
250 } else {
251
252 return mid;
253 }
254 }
255
256 return -(low + 1);
257 }
258
259 private Mapping<T> doGet(Mapping<T>[] table, byte[] bytes, final int netmaskBits, final int scopeId) {
260 int idx = doFind(table, bytes, netmaskBits, scopeId);
261 if (idx >= 0) {
262
263 assert table[idx].range.matches(bytes, scopeId);
264 return table[idx];
265 }
266
267 int pre = -idx - 2;
268 if (pre >= 0) {
269 if (table[pre].range.matches(bytes, scopeId)) {
270 return table[pre];
271 }
272
273 Mapping<T> parent = table[pre].parent;
274 while (parent != null) {
275 if (parent.range.matches(bytes, scopeId)) {
276 return parent;
277 }
278 parent = parent.parent;
279 }
280 }
281 return null;
282 }
283
284
289 public static final class Mapping<T> {
290 final CidrAddress range;
291 final T value;
292 final Mapping<T> parent;
293
294 Mapping(final CidrAddress range, final T value, final Mapping<T> parent) {
295 this.range = range;
296 this.value = value;
297 this.parent = parent;
298 }
299
300 Mapping<T> withNewParent(Mapping<T> newParent) {
301 return new Mapping<T>(range, value, newParent);
302 }
303
304
309 public CidrAddress getRange() {
310 return range;
311 }
312
313
318 public T getValue() {
319 return value;
320 }
321
322
327 public Mapping<T> getParent() {
328 return parent;
329 }
330 }
331 }
332