|
| 1 | +/* |
| 2 | + * SPDX-License-Identifier: Apache-2.0 |
| 3 | + * |
| 4 | + * The OpenSearch Contributors require contributions made to |
| 5 | + * this file be licensed under the Apache-2.0 license or a |
| 6 | + * compatible open source license. |
| 7 | + */ |
| 8 | + |
| 9 | +package org.opensearch.common.round; |
| 10 | + |
| 11 | +import org.opensearch.common.annotation.InternalApi; |
| 12 | + |
| 13 | +import jdk.incubator.vector.LongVector; |
| 14 | +import jdk.incubator.vector.Vector; |
| 15 | +import jdk.incubator.vector.VectorOperators; |
| 16 | +import jdk.incubator.vector.VectorSpecies; |
| 17 | + |
| 18 | +/** |
| 19 | + * It uses vectorized B-tree search to find the round-down point. |
| 20 | + * |
| 21 | + * @opensearch.internal |
| 22 | + */ |
| 23 | +@InternalApi |
| 24 | +class BtreeSearcher implements Roundable { |
| 25 | + private static final VectorSpecies<Long> LONG_VECTOR_SPECIES = LongVector.SPECIES_PREFERRED; |
| 26 | + private static final int LANES = LONG_VECTOR_SPECIES.length(); |
| 27 | + private static final int SHIFT = log2(LANES); |
| 28 | + |
| 29 | + private final long[] values; |
| 30 | + private final long minValue; |
| 31 | + |
| 32 | + BtreeSearcher(long[] values, int size) { |
| 33 | + if (size <= 0) { |
| 34 | + throw new IllegalArgumentException("at least one value must be present"); |
| 35 | + } |
| 36 | + |
| 37 | + int blocks = (size + LANES - 1) / LANES; // number of blocks |
| 38 | + int length = 1 + blocks * LANES; // size of the backing array (1-indexed) |
| 39 | + |
| 40 | + this.minValue = values[0]; |
| 41 | + this.values = new long[length]; |
| 42 | + build(values, 0, size, this.values, 1); |
| 43 | + } |
| 44 | + |
| 45 | + /** |
| 46 | + * Builds the B-tree memory layout. |
| 47 | + * It builds the tree recursively, following an in-order traversal. |
| 48 | + * |
| 49 | + * <p> |
| 50 | + * Each block stores 'lanes' values at indices {@code i, i + 1, ..., i + lanes - 1} where {@code i} is the |
| 51 | + * starting offset. The starting offset of the root block is 1. The branching factor is (1 + lanes) so each |
| 52 | + * block can have these many children. Given the starting offset {@code i} of a block, the starting offset |
| 53 | + * of its k-th child (ranging from {@code 0, 1, ..., k}) can be computed as {@code i + ((i + k) << shift)}. |
| 54 | + * |
| 55 | + * @param src is the sorted input array |
| 56 | + * @param i is the index in the input array to read the value from |
| 57 | + * @param size the number of values in the input array |
| 58 | + * @param dst is the output array |
| 59 | + * @param j is the index in the output array to write the value to |
| 60 | + * @return the next index 'i' |
| 61 | + */ |
| 62 | + private static int build(long[] src, int i, int size, long[] dst, int j) { |
| 63 | + if (j < dst.length) { |
| 64 | + for (int k = 0; k < LANES; k++) { |
| 65 | + i = build(src, i, size, dst, j + ((j + k) << SHIFT)); |
| 66 | + |
| 67 | + // Fills the B-tree as a complete tree, i.e., all levels are completely filled, |
| 68 | + // except the last level which is filled from left to right. |
| 69 | + // The trick is to fill the destination array between indices 1...size (inclusive / 1-indexed) |
| 70 | + // and pad the remaining array with +infinity. |
| 71 | + dst[j + k] = (j + k <= size) ? src[i++] : Long.MAX_VALUE; |
| 72 | + } |
| 73 | + i = build(src, i, size, dst, j + ((j + LANES) << SHIFT)); |
| 74 | + } |
| 75 | + return i; |
| 76 | + } |
| 77 | + |
| 78 | + @Override |
| 79 | + public long floor(long key) { |
| 80 | + Vector<Long> keyVector = LongVector.broadcast(LONG_VECTOR_SPECIES, key); |
| 81 | + int i = 1, result = 1; |
| 82 | + |
| 83 | + while (i < values.length) { |
| 84 | + Vector<Long> valuesVector = LongVector.fromArray(LONG_VECTOR_SPECIES, values, i); |
| 85 | + int j = i + valuesVector.compare(VectorOperators.GT, keyVector).firstTrue(); |
| 86 | + result = (j > i) ? j : result; |
| 87 | + i += (j << SHIFT); |
| 88 | + } |
| 89 | + |
| 90 | + assert result > 1 : "key must be greater than or equal to " + minValue; |
| 91 | + return values[result - 1]; |
| 92 | + } |
| 93 | + |
| 94 | + private static int log2(int num) { |
| 95 | + if ((num <= 0) || ((num & (num - 1)) != 0)) { |
| 96 | + throw new IllegalArgumentException(num + " is not a positive power of 2"); |
| 97 | + } |
| 98 | + return 32 - Integer.numberOfLeadingZeros(num - 1); |
| 99 | + } |
| 100 | +} |
0 commit comments