I have recently been working on 4-bit PQ, the technique from the paper Accelerated Nearest Neighbor Search with Quick ADC. Intuitively, 4-bit PQ should be much faster than 8-bit PQ, but implementing it end to end revealed quite a few traps.
Product Quantization
Start with ordinary PQ. Product quantization splits a D-dimensional vector into M equal-sized sub-vectors (blocks), each with D/M dimensions. For each block position, it runs k-means over that block of every vector in the dataset and produces 2^B centroids. Each original vector is then represented by the id of the nearest centroid in every block.
Here B is the number of bits. It is usually 8, for a simple reason: one sub-vector code can be stored in a u8, which gives a good compression ratio. When people talk about PQ, they usually mean B = 8, which compresses D f32 values into M u8 values.
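The encoding step can be sketched as follows. This is a minimal illustration that assumes the per-block codebooks have already been trained with k-means (training is omitted) and uses squared L2 as the distance. The name `pq_encode` and the `codebooks[m][k]` layout are hypothetical choices made for this sketch.

```rust
/// Encode one vector with PQ, assuming already-trained codebooks.
/// `codebooks[m]` holds the 2^B centroids of block m, each of length D / M.
/// (Hypothetical helper; the k-means training step is not shown.)
fn pq_encode(x: &[f32], codebooks: &[Vec<Vec<f32>>]) -> Vec<u8> {
    let m = codebooks.len();
    let sub_dim = x.len() / m;
    assert_eq!(sub_dim * m, x.len(), "D must be divisible by M");
    (0..m)
        .map(|block| {
            let sub = &x[block * sub_dim..(block + 1) * sub_dim];
            // Find the nearest centroid (squared L2) in this block's codebook.
            let mut best = 0usize;
            let mut best_dist = f32::INFINITY;
            for (id, c) in codebooks[block].iter().enumerate() {
                let d: f32 = sub.iter().zip(c).map(|(a, b)| (a - b) * (a - b)).sum();
                if d < best_dist {
                    best_dist = d;
                    best = id;
                }
            }
            // With B = 8 there are at most 256 centroids, so the id fits in a u8.
            best as u8
        })
        .collect()
}
```

With D = 4 and M = 2, a vector becomes two one-byte codes. For example, with toy codebooks `[[0,0],[1,1]]` and `[[0,0],[5,5]]`, the vector `[0.9, 1.1, 4.0, 6.0]` encodes to `[1, 1]`.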
At query time, to compute the distance from a query vector q to every vector in the dataset, we first build a distance table. Since every quantized sub-vector is represented by a centroid id, the table precomputes the distance from each sub-vector of q to every centroid of the matching block:
distance_table[i][j] = distance from the i-th sub-vector of q to the j-th centroid of the i-th block
For a compressed vector code, the distance is just M table lookups and additions:
sum over j in [0, M): distance_table[j][code[j]]
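Here is a small sketch of both steps, building the table and then summing lookups, assuming squared L2 distances. With squared L2, the sum equals the squared distance from q to the vector reconstructed from its centroids. The function names `build_distance_table` and `adc_distance` are hypothetical.

```rust
/// distance_table[m][k] = squared L2 distance between the m-th sub-vector
/// of q and the k-th centroid of block m (hypothetical helper).
fn build_distance_table(q: &[f32], codebooks: &[Vec<Vec<f32>>]) -> Vec<Vec<f32>> {
    let m = codebooks.len();
    let sub_dim = q.len() / m;
    codebooks
        .iter()
        .enumerate()
        .map(|(block, centroids)| {
            let sub = &q[block * sub_dim..(block + 1) * sub_dim];
            centroids
                .iter()
                .map(|c| sub.iter().zip(c).map(|(a, b)| (a - b) * (a - b)).sum::<f32>())
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// Asymmetric distance for one code: M lookups and M additions,
/// without touching the original vector.
fn adc_distance(distance_table: &[Vec<f32>], code: &[u8]) -> f32 {
    code.iter()
        .enumerate()
        .map(|(j, &c)| distance_table[j][c as usize])
        .sum()
}
```

For instance, with codebooks `[[0,0],[1,1]]` and `[[0,0],[5,5]]` and q = `[1, 0, 4, 4]`, the code `[1, 1]` reconstructs to `[1, 1, 5, 5]`. Its distance is 1 + 2 = 3.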
Transposing
I first saw this idea in Accelerated Nearest Neighbor Search with Quick ADC, but the same optimization also applies to 8-bit PQ and works very well.
Consider the straightforward way to compute distances for compressed vectors:
// row-major codes: code[i][j] is block j of vector i
let mut dists = vec![0.0f32; n];
for i in 0..n {
    for j in 0..M {
        dists[i] += distance_table[j][code[i][j] as usize];
    }
}
The problem is that each vector's lookups jump across all M sub-tables, M KiB of distance table in total for B = 8. The accesses are effectively random, so the computation becomes memory-bound. If we swap the loop order, things get much better:
let mut dists = vec![0.0f32; n];
for j in 0..M {
    for i in 0..n {
        dists[i] += distance_table[j][code[i][j] as usize];
    }
}
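Now the inner loop only touches one sub-table, distance_table[j]. It holds 2^8 = 256 f32 values, or 1 KiB, so locality is good. The new problem is that access to code becomes strided: code[i][j] jumps M bytes between consecutive values of i. That is easy to fix by transposing the code matrix so that code[j][i] holds block j of vector i: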
// column-major codes: code[j][i] is block j of vector i
let mut dists = vec![0.0f32; n];
for j in 0..M {
    for i in 0..n {
        dists[i] += distance_table[j][code[j][i] as usize];
    }
}
This improves locality and also gives the compiler a much better chance to generate efficient code. The performance gain is significant.
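A self-contained sketch of the transposed layout and the block-major scan is below. It assumes codes arrive row-major and are transposed once at build time. The names `transpose_codes` and `scan_transposed` are hypothetical.

```rust
/// Transpose row-major codes (codes[i][j]: block j of vector i) into
/// column-major storage (columns[j][i]) so the scan reads each column sequentially.
fn transpose_codes(codes: &[Vec<u8>], m: usize) -> Vec<Vec<u8>> {
    (0..m)
        .map(|j| codes.iter().map(|row| row[j]).collect::<Vec<u8>>())
        .collect()
}

/// Block-major scan: for block j only distance_table[j] is touched,
/// while column j is streamed from start to end.
fn scan_transposed(distance_table: &[Vec<f32>], columns: &[Vec<u8>], n: usize) -> Vec<f32> {
    let mut dists = vec![0.0f32; n];
    for (table_j, column_j) in distance_table.iter().zip(columns) {
        for (d, &c) in dists.iter_mut().zip(column_j) {
            *d += table_j[c as usize];
        }
    }
    dists
}
```

The transpose is a one-time cost paid when the index is built. Every query afterwards gets sequential reads over each column.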
4-bit PQ
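Back to 4-bit PQ. Building on the optimization above, we can go one step further. Registers are faster than cache, so ideally a whole sub-table lives in a SIMD register and the per-element lookups become a single shuffle instruction. Ignoring register size for a moment, the computation is equivalent to this pseudo-code: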
let mut dists = vec![0.0f32; n];
let num_centroids = 2usize.pow(B);
for j in 0..M {
    for i in (0..n).step_by(num_centroids) {
        // lane k of `shuffled` is distance_table[j][code[j][i + k]]
        let shuffled = shuffle(&distance_table[j], &code[j][i..i + num_centroids]);
        dists[i..i + num_centroids] += shuffled; // element-wise add
    }
}
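The pseudo-code can be made concrete with a scalar model of the shuffle for B = 4. This is not SIMD; it only mirrors what a 16-lane byte shuffle computes. For simplicity it stores one 4-bit code per byte and assumes n is padded to a multiple of 16. `shuffle16` and `scan_4bit` are hypothetical names.

```rust
/// Scalar model of a 16-lane shuffle: lane k receives table[idx[k]].
/// This mirrors a byte shuffle such as x86 `pshufb` for indices below 16.
fn shuffle16<T: Copy>(table: &[T; 16], idx: &[u8; 16]) -> [T; 16] {
    std::array::from_fn(|k| table[(idx[k] & 0x0F) as usize])
}

/// Block-major scan with 16 centroids per block. `columns[j][i]` is the 4-bit
/// code of vector i for block j, stored one code per byte in this sketch.
fn scan_4bit(distance_table: &[[f32; 16]], columns: &[Vec<u8>], n: usize) -> Vec<f32> {
    assert_eq!(n % 16, 0, "this sketch assumes n is padded to a multiple of 16");
    let mut dists = vec![0.0f32; n];
    for (table_j, column_j) in distance_table.iter().zip(columns) {
        // One shuffle plus one 16-wide add per group of 16 vectors.
        for (codes, acc) in column_j.chunks_exact(16).zip(dists.chunks_exact_mut(16)) {
            let idx: [u8; 16] = std::array::from_fn(|k| codes[k]);
            let shuffled = shuffle16(table_j, &idx);
            for (a, s) in acc.iter_mut().zip(shuffled) {
                *a += s;
            }
        }
    }
    dists
}
```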
In other words, the core computation can be implemented with a shuffle plus an add. For 8-bit PQ, distance_table[j] has 256 f32 values, or 1 KiB (8192 bits). Current SIMD registers are at most 512 bits on x86 (AVX-512) and 128 bits on ARM (NEON). To make the table fit, the paper uses two tricks:
- Use 4 bits instead of 8 bits, so there are only 2^4 = 16 centroids per block. Sixteen f32 values are still 512 bits. That fills an entire AVX-512 register and does not fit in a 128-bit register at all, so this is not enough by itself.
- Quantize the distances to u8 with scalar quantization. Sixteen u8 values fit exactly in 128 bits, i.e. one SSE or NEON register.
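The two tricks together reduce one block's lookup to a single instruction. The sketch below quantizes a 16-entry f32 table to u8 with a shared scale and then looks up 16 codes with one SSSE3 `_mm_shuffle_epi8` (x86 `pshufb`) from `std::arch`. It falls back to a scalar loop elsewhere; on ARM the corresponding NEON intrinsic would be `vqtbl1q_u8`. The scale choice (`max` maps to 255, with floor rounding) and the names `quantize_table` and `lookup16` are assumptions for illustration.

```rust
/// Scalar-quantize a 16-entry table to u8: d maps to floor(d * 255 / max).
/// Assumes non-negative distances and a caller-chosen `max` (hypothetical scheme).
fn quantize_table(table: &[f32; 16], max: f32) -> [u8; 16] {
    let scale = 255.0 / max;
    std::array::from_fn(|k| (table[k] * scale).floor().clamp(0.0, 255.0) as u8)
}

/// Look up 16 4-bit codes in a 16-entry u8 table: lane k gets table[codes[k]].
fn lookup16(table: &[u8; 16], codes: &[u8; 16]) -> [u8; 16] {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("ssse3") {
            // SAFETY: SSSE3 support was verified at runtime just above.
            return unsafe { lookup16_ssse3(table, codes) };
        }
    }
    // Scalar fallback with the same result for 4-bit codes.
    std::array::from_fn(|k| table[(codes[k] & 0x0F) as usize])
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "ssse3")]
unsafe fn lookup16_ssse3(table: &[u8; 16], codes: &[u8; 16]) -> [u8; 16] {
    use std::arch::x86_64::{__m128i, _mm_loadu_si128, _mm_shuffle_epi8, _mm_storeu_si128};
    let mut out = [0u8; 16];
    unsafe {
        let t = _mm_loadu_si128(table.as_ptr() as *const __m128i);
        let c = _mm_loadu_si128(codes.as_ptr() as *const __m128i);
        // pshufb: lane k = t[c[k] & 0x0F], or 0 if bit 7 of c[k] is set
        // (never the case for 4-bit codes).
        let r = _mm_shuffle_epi8(t, c);
        _mm_storeu_si128(out.as_mut_ptr() as *mut __m128i, r);
    }
    out
}
```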
Implementation
The idea is straightforward up to this point, but distance quantization confused me a lot in the actual implementation. In the paper, dists[i..i+num_centroids] is also a u8x16, so the entire accumulation over all M blocks must avoid u8 overflow. How is that possible?
The simplest solution is to set the scalar quantization max value to the sum of the whole distance table, so that no accumulated sum can overflow. Unsurprisingly, this destroys recall: the 256 u8 levels are spread over a huge range, so each block's distances collapse into just a few levels.
I went back to the paper to see what the authors did, and the answer is a good reminder of the gap between papers and production implementations. They first run a brute-force search, for example top-200, and use the farthest distance found as the max value. During accumulation they use saturating adds. If the query asks for the top-100, any vector that can actually make it into the final answer should have a distance below that bound, so its sum does not overflow. Anything that saturates at 255 is too far away to matter. This requires an extra search per query, which is impractical for my use case, so I did not use it.
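Below is a minimal sketch of the bound-based scheme as described above. It is my reading, not the paper's code. The caller supplies `bound`, for example the distance of the 200th neighbor from a preliminary search. Table entries are quantized so that `bound` maps to 255, and sums are accumulated in u8 with `u8::saturating_add`. The names `quantize_with_bound` and `saturating_distance` are hypothetical.

```rust
/// Quantize every table entry so that `bound` maps to 255 (floor rounding);
/// entries beyond the bound clamp to 255.
fn quantize_with_bound(table: &[Vec<f32>], bound: f32) -> Vec<Vec<u8>> {
    let scale = 255.0 / bound;
    table
        .iter()
        .map(|row| {
            row.iter()
                .map(|&d| (d * scale).floor().min(255.0) as u8)
                .collect::<Vec<u8>>()
        })
        .collect()
}

/// Accumulate in u8 with saturating adds; 255 means "beyond the bound".
fn saturating_distance(qtable: &[Vec<u8>], code: &[u8]) -> u8 {
    code.iter()
        .zip(qtable)
        .fold(0u8, |acc, (&c, row)| acc.saturating_add(row[c as usize]))
}
```

With floor rounding, any vector whose true distance is at most `bound` sums to at most 255. Vectors beyond the bound may saturate, which is harmless because they cannot be in the final top-k.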
In the end, I gave up on keeping the result in u8x16 all the way through and instead wrote intermediate results back to memory. Because 4-bit PQ packs two blocks into one byte, I accumulate the sum of those two blocks in a u8x16 before writing it out. That means the scalar quantization max value only needs to be the largest possible sum of two neighboring blocks in the distance table, not of the whole table. Writing once per pair of blocks instead of once per block halves the number of memory writes, so performance is still good. More importantly, the recall loss is small.
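A scalar sketch of this pairwise scheme follows. It interprets "neighboring blocks" as the two blocks sharing one byte: block 2p in the low nibble and block 2p + 1 in the high nibble. All tables share one scale, chosen so that the largest pair sum maps to 255. Each pair is summed in u8 and then widened into a u32 accumulator in memory; the u32 width and all names here are my assumptions. In SIMD form, the two lookups would be two byte shuffles plus one u8 add per 16 vectors.

```rust
/// Pack the 4-bit codes of blocks 2p (low nibble) and 2p + 1 (high nibble).
fn pack_pair(lo: u8, hi: u8) -> u8 {
    (lo & 0x0F) | (hi << 4)
}

/// Quantize all M tables with one shared scale so that the sum of any two
/// neighboring blocks fits in a u8. Assumes M is even and distances are >= 0.
fn quantize_pairwise(table: &[[f32; 16]]) -> Vec<[u8; 16]> {
    assert!(table.len() % 2 == 0, "M must be even to pack two codes per byte");
    let block_max = |t: &[f32; 16]| t.iter().copied().fold(0.0f32, f32::max);
    let max_pair = table
        .chunks_exact(2)
        .map(|pair| block_max(&pair[0]) + block_max(&pair[1]))
        .fold(0.0f32, f32::max);
    let scale = if max_pair > 0.0 { 255.0 / max_pair } else { 0.0 };
    // floor() keeps q_a + q_b <= floor((d_a + d_b) * scale) <= 255.
    table
        .iter()
        .map(|t| -> [u8; 16] { std::array::from_fn(|k| (t[k] * scale).floor() as u8) })
        .collect()
}

/// Scan packed codes: `packed[p][i]` is vector i's byte for blocks 2p and 2p + 1.
fn scan_packed(qtable: &[[u8; 16]], packed: &[Vec<u8>], n: usize) -> Vec<u32> {
    assert_eq!(qtable.len(), 2 * packed.len());
    let mut dists = vec![0u32; n];
    for (p, column) in packed.iter().enumerate() {
        let (lo_t, hi_t) = (&qtable[2 * p], &qtable[2 * p + 1]);
        for (d, &byte) in dists.iter_mut().zip(column) {
            // The u8 add cannot overflow thanks to the pairwise scale.
            let pair: u8 = lo_t[(byte & 0x0F) as usize] + hi_t[(byte >> 4) as usize];
            // One widened write per pair of blocks.
            *d += u32::from(pair);
        }
    }
    dists
}
```

Note that the total distance can exceed 255 (266 in the test data), while each pair sum stays within a u8. That is exactly the property this scheme relies on.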