tala-embed

SIMD-accelerated vector operations, quantization, and the HNSW approximate nearest neighbor index. This crate is the computational engine behind TALA's semantic search: it computes cosine similarity, dot products, and squared L2 distances using runtime-dispatched AVX2+FMA intrinsics on x86_64, with a scalar fallback that compiles everywhere. It also provides INT8 and FP16 quantization for storage compression and an HNSW index whose top-k search is benchmarked at sub-millisecond latency on a 10K-vector corpus (see Performance).

Key Types

| Type | Description |
|------|-------------|
| AlignedVec | 64-byte aligned f32 vector for SIMD operations |
| HnswIndex | Hierarchical Navigable Small World approximate nearest neighbor index |

Key Modules

| Module | Description |
|--------|-------------|
| scalar | Portable fallback: dot_product, cosine_similarity, l2_distance_sq, norm_sq |
| avx2 | AVX2+FMA implementations (x86_64 only, #[cfg(target_arch = "x86_64")]) |
| quantize | Quantization: f32_to_int8, int8_to_f32, f32_to_f16, f16_to_f32 |

Dispatch Functions

| Function | Description |
|----------|-------------|
| cosine_similarity(a, b) | Cosine similarity with runtime ISA dispatch |
| dot_product(a, b) | Dot product with runtime ISA dispatch |
| l2_distance_sq(a, b) | L2 distance squared with runtime ISA dispatch |
| batch_cosine(query, corpus, dim, results) | Single-threaded batch cosine similarity |
| batch_cosine_parallel(query, corpus, dim, results) | Rayon-parallel batch cosine similarity |

AlignedVec

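A heap-allocated f32 vector guaranteed to begin at a 64-byte aligned address. 64 bytes is the width of one AVX-512 register and of a typical x86_64 cache line, so this alignment is required for optimal AVX-512 loads and is also beneficial for AVX2, whose 32-byte loads then never straddle a cache line. AlignedVec manages its own memory through the global allocator with Layout::from_size_align(len * 4, 64).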

pub struct AlignedVec {
    ptr: *mut f32,
    len: usize,
    cap: usize,
}

AlignedVec implements Send, Sync, Clone, Deref<Target = [f32]>, DerefMut, From<Vec<f32>>, and From<&[f32]>. It can be used anywhere a &[f32] is expected.
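To make the allocation scheme concrete, here is a minimal sketch of a 64-byte aligned buffer built on std::alloc. It is not the crate's implementation: the type name AlignedBuf is hypothetical, it has a fixed length (no cap field), and its handling of zero-length buffers is an assumption.

```rust
use std::alloc::{alloc_zeroed, dealloc, handle_alloc_error, Layout};
use std::ops::Deref;
use std::ptr::NonNull;

/// Hypothetical stand-in for AlignedVec: a fixed-length, 64-byte aligned f32 buffer.
struct AlignedBuf {
    ptr: NonNull<f32>,
    len: usize,
}

impl AlignedBuf {
    const ALIGN: usize = 64;

    fn layout(len: usize) -> Layout {
        Layout::from_size_align(len * std::mem::size_of::<f32>(), Self::ALIGN)
            .expect("invalid layout")
    }

    /// Allocate `len` zeroed floats at a 64-byte aligned address.
    fn new(len: usize) -> Self {
        if len == 0 {
            // Assumption: an empty buffer owns no memory. The global allocator
            // must not be asked for zero bytes, so use a dangling pointer.
            return Self { ptr: NonNull::dangling(), len: 0 };
        }
        let layout = Self::layout(len);
        // SAFETY: `layout` has a non-zero size.
        let raw = unsafe { alloc_zeroed(layout) } as *mut f32;
        let ptr = NonNull::new(raw).unwrap_or_else(|| handle_alloc_error(layout));
        Self { ptr, len }
    }
}

impl Deref for AlignedBuf {
    type Target = [f32];

    fn deref(&self) -> &[f32] {
        // SAFETY: `ptr` points to `len` zero-initialized f32 values,
        // or is dangling (but aligned for f32) when `len` is 0.
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
    }
}

impl Drop for AlignedBuf {
    fn drop(&mut self) {
        if self.len != 0 {
            // SAFETY: the memory was allocated in `new` with this exact layout.
            unsafe { dealloc(self.ptr.as_ptr() as *mut u8, Self::layout(self.len)) };
        }
    }
}
```

Because the sketch implements Deref to a float slice, slice methods such as len(), iter(), and as_ptr() work on it directly, which is the same mechanism that lets AlignedVec stand in for a &[f32].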

Methods

impl AlignedVec {
    /// Allocate a zero-initialized vector of `len` elements at 64-byte alignment.
    pub fn new(len: usize) -> Self;

    /// View as a shared float slice.
    pub fn as_slice(&self) -> &[f32];

    /// View as a mutable float slice.
    pub fn as_mut_slice(&mut self) -> &mut [f32];

    /// Number of elements.
    pub fn len(&self) -> usize;

    /// True if the vector contains no elements.
    pub fn is_empty(&self) -> bool;
}

Example

use tala_embed::AlignedVec;

// A freshly allocated vector is zero-initialized and 64-byte aligned.
let v = AlignedVec::new(384);
assert_eq!(v.len(), 384);
assert!(v.as_slice().as_ptr() as usize % 64 == 0);

// Populate from a Vec<f32>:
let data: Vec<f32> = (0..384).map(|i| i as f32).collect();
let aligned = AlignedVec::from(data);
assert_eq!(aligned[0], 0.0);
assert_eq!(aligned[383], 383.0);

scalar Module

Portable implementations of the core vector operations. These are used on architectures without AVX2 support and serve as the reference implementation for correctness testing.

pub mod scalar {
    /// Dot product of two equal-length slices.
    #[inline]
    pub fn dot_product(a: &[f32], b: &[f32]) -> f32;

    /// Squared L2 norm of a vector: sum of x_i^2.
    #[inline]
    pub fn norm_sq(a: &[f32]) -> f32;

    /// Cosine similarity: dot(a,b) / (||a|| * ||b|| + epsilon).
    #[inline]
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32;

    /// Squared L2 distance: sum of (a_i - b_i)^2.
    #[inline]
    pub fn l2_distance_sq(a: &[f32], b: &[f32]) -> f32;
}
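The formulas in the doc comments above translate almost directly into iterator code. The following is a sketch of what such a scalar reference could look like, not the crate's source: the epsilon value and the panic on mismatched lengths are assumptions.

```rust
/// Dot product of two equal-length slices.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    // Assumption: mismatched lengths are treated as a caller bug.
    assert_eq!(a.len(), b.len(), "length mismatch");
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Squared L2 norm: sum of x_i^2.
fn norm_sq(a: &[f32]) -> f32 {
    a.iter().map(|x| x * x).sum()
}

/// Cosine similarity: dot(a,b) / (||a|| * ||b|| + epsilon).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Assumed value; the epsilon only guards against division by zero.
    const EPSILON: f32 = 1e-8;
    dot_product(a, b) / (norm_sq(a).sqrt() * norm_sq(b).sqrt() + EPSILON)
}

/// Squared L2 distance: sum of (a_i - b_i)^2.
fn l2_distance_sq(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "length mismatch");
    a.iter()
        .zip(b)
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum()
}
```

With the epsilon in the denominator, an all-zero input yields a similarity of 0.0 rather than NaN.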

avx2 Module

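AVX2+FMA implementations available on x86_64. All functions are marked #[target_feature(enable = "avx2,fma")] and are therefore unsafe: they must be called from an unsafe block, and calling them on a CPU without AVX2 and FMA is undefined behavior. The dispatch functions perform the feature check and the unsafe call for you.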

#[cfg(target_arch = "x86_64")]
pub mod avx2 {
    /// Dot product -- 4-way unrolled (32 floats per iteration) to saturate
    /// FMA throughput. Four independent accumulator chains reduce the critical
    /// path 4x vs a single accumulator.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn dot_product(a: &[f32], b: &[f32]) -> f32;

    /// Cosine similarity -- 2-way unrolled (6 accumulators: 2 dot + 2 norm_a
    /// + 2 norm_b). Fits within AVX2's 16 YMM register budget (10 live regs).
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn cosine_similarity(a: &[f32], b: &[f32]) -> f32;

    /// Squared L2 distance -- 4-way unrolled (sub then FMA).
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn l2_distance_sq(a: &[f32], b: &[f32]) -> f32;
}

Dispatch Functions

These top-level functions detect AVX2+FMA at runtime via is_x86_feature_detected! and dispatch to the fastest available implementation. On non-x86_64 architectures, they fall through directly to the scalar path. Call these from application code; never call the avx2 module functions directly.

/// Cosine similarity with automatic SIMD dispatch.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32;

/// Dot product with automatic SIMD dispatch.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32;

/// L2 distance squared with automatic SIMD dispatch.
pub fn l2_distance_sq(a: &[f32], b: &[f32]) -> f32;
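The dispatch pattern and the 4-way unrolled kernel described above can be sketched with the standard library alone. This is an illustration under assumptions, not the crate's code: the names dot_scalar and dot_avx2 are hypothetical, unaligned loads are used so any slice works, and the length assertion is an assumed precondition check.

```rust
/// Portable fallback, used when AVX2+FMA is not available.
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// AVX2+FMA kernel: four independent accumulators, 32 floats per iteration.
///
/// # Safety
/// The CPU must support AVX2 and FMA, and `a.len()` must equal `b.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn dot_avx2(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let n = a.len();
    let mut i = 0;
    let mut lanes = [0.0f32; 8];
    // SAFETY: every load reads 8 floats starting at an index <= n - 8,
    // and the caller guarantees both slices hold n elements.
    unsafe {
        let (pa, pb) = (a.as_ptr(), b.as_ptr());
        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();
        let mut acc2 = _mm256_setzero_ps();
        let mut acc3 = _mm256_setzero_ps();
        while i + 32 <= n {
            // Each accumulator forms its own dependency chain, so the four
            // FMAs of one iteration do not wait on each other.
            acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(pa.add(i)), _mm256_loadu_ps(pb.add(i)), acc0);
            acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(pa.add(i + 8)), _mm256_loadu_ps(pb.add(i + 8)), acc1);
            acc2 = _mm256_fmadd_ps(_mm256_loadu_ps(pa.add(i + 16)), _mm256_loadu_ps(pb.add(i + 16)), acc2);
            acc3 = _mm256_fmadd_ps(_mm256_loadu_ps(pa.add(i + 24)), _mm256_loadu_ps(pb.add(i + 24)), acc3);
            i += 32;
        }
        let acc = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
        _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    }
    // Horizontal sum of the 8 lanes, then a scalar loop for the tail.
    let mut sum: f32 = lanes.iter().sum();
    for k in i..n {
        sum += a[k] * b[k];
    }
    sum
}

/// Safe entry point: checks CPU features once per call and dispatches.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "length mismatch");
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: features were just detected and lengths are equal.
            return unsafe { dot_avx2(a, b) };
        }
    }
    dot_scalar(a, b)
}
```

On targets other than x86_64 the cfg block compiles away, so dot_product reduces to a direct call to the scalar path. Note that the SIMD and scalar paths sum in different orders, so their results can differ in the last bits for general floating-point inputs.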

Batch Operations

/// Compute cosine similarity of `query` against every vector in `corpus`.
/// `corpus` is a flat buffer: vector i = corpus[i*dim .. (i+1)*dim].
/// Results are written into `results[0..n]`.
pub fn batch_cosine(query: &[f32], corpus: &[f32], dim: usize, results: &mut [f32]);

/// Parallel batch cosine using Rayon. Processes corpus in chunks of 256
/// vectors per thread.
pub fn batch_cosine_parallel(query: &[f32], corpus: &[f32], dim: usize, results: &mut [f32]);

Example

use tala_embed::{cosine_similarity, batch_cosine};

let a = vec![1.0, 0.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0, 0.0];
let sim = cosine_similarity(&a, &b);
assert!((sim - 1.0).abs() < 1e-5);

// Batch: query against 3 corpus vectors of dim 4
let query = vec![1.0, 0.0, 0.0, 0.0];
let corpus = vec![
    1.0, 0.0, 0.0, 0.0,  // identical to query
    0.0, 1.0, 0.0, 0.0,  // orthogonal
    0.5, 0.5, 0.0, 0.0,  // partially similar
];
let mut results = vec![0.0f32; 3];
batch_cosine(&query, &corpus, 4, &mut results);
assert!(results[0] > results[2]); // identical > partial
assert!(results[2] > results[1]); // partial > orthogonal
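The flat corpus layout (vector i occupies corpus[i*dim .. (i+1)*dim]) maps naturally onto slice chunking. The sketch below shows one way the two batch functions could be structured. It is an assumption-laden illustration: the helper cosine, the epsilon, and the precondition asserts are hypothetical, and the parallel variant uses std::thread::scope with one thread per 256-vector chunk as a standard-library stand-in for the Rayon-based version, which would normally use a thread pool.

```rust
/// Scalar cosine similarity (hypothetical helper; epsilon is an assumed value).
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb + 1e-8)
}

/// Single-threaded batch: one result per `dim`-sized row of the flat corpus.
fn batch_cosine(query: &[f32], corpus: &[f32], dim: usize, results: &mut [f32]) {
    assert!(dim > 0 && query.len() == dim, "query must have `dim` elements");
    assert_eq!(corpus.len() % dim, 0, "corpus must hold whole vectors");
    assert!(results.len() >= corpus.len() / dim, "results too short");
    for (out, row) in results.iter_mut().zip(corpus.chunks_exact(dim)) {
        *out = cosine(query, row);
    }
}

/// Chunked variant: each block of 256 vectors is scored on its own scoped thread.
fn batch_cosine_threads(query: &[f32], corpus: &[f32], dim: usize, results: &mut [f32]) {
    const VECTORS_PER_CHUNK: usize = 256;
    assert!(dim > 0 && corpus.len() % dim == 0);
    let n = corpus.len() / dim;
    assert!(results.len() >= n, "results too short");
    std::thread::scope(|s| {
        // Corpus chunks and result chunks line up one-to-one, so every
        // thread writes to a disjoint part of `results`.
        for (rows, out) in corpus
            .chunks(VECTORS_PER_CHUNK * dim)
            .zip(results[..n].chunks_mut(VECTORS_PER_CHUNK))
        {
            s.spawn(move || batch_cosine(query, rows, dim, out));
        }
    });
}
```

Splitting the output with chunks_mut is what makes the parallel version safe without locks: the borrow checker can see that no two threads share a result slot.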

quantize Module

Quantization routines for compressing f32 embeddings to INT8 or FP16. INT8 stores one byte per element (4x smaller than f32) and FP16 stores two bytes (2x smaller), trading some precision for a reduced storage footprint.

pub mod quantize {
    /// Symmetric f32 -> INT8 quantization.
    /// Returns (quantized_bytes, scale_factor). Dequantize with: f32 = i8 * scale.
    pub fn f32_to_int8(src: &[f32]) -> (Vec<i8>, f32);

    /// INT8 -> f32 dequantization.
    pub fn int8_to_f32(src: &[i8], scale: f32) -> Vec<f32>;

    /// f32 -> IEEE 754 half-precision (FP16), stored as u16.
    pub fn f32_to_f16(src: &[f32]) -> Vec<u16>;

    /// FP16 (u16) -> f32 conversion.
    pub fn f16_to_f32(src: &[u16]) -> Vec<f32>;
}

Example

use tala_embed::quantize::{f32_to_int8, int8_to_f32};

let original = vec![0.5, -0.3, 0.9, 0.0];
let (quantized, scale) = f32_to_int8(&original);
let recovered = int8_to_f32(&quantized, scale);

for (o, r) in original.iter().zip(recovered.iter()) {
    assert!((o - r).abs() < 0.01);
}
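For readers unfamiliar with symmetric quantization, the round trip can be sketched as follows. This is a common formulation and only an assumption about what the crate does: the scale is taken as max|x| / 127, values are rounded to the nearest integer, and an all-zero input is given a scale of 1.0.

```rust
/// Symmetric f32 -> INT8: one shared scale, zero maps to zero.
fn f32_to_int8(src: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = src.iter().fold(0.0f32, |m, x| m.max(x.abs()));
    if max_abs == 0.0 {
        // Assumption: any non-zero scale works for an all-zero input.
        return (vec![0; src.len()], 1.0);
    }
    // The largest magnitude maps to +/-127; -128 is left unused so the
    // range stays symmetric around zero.
    let scale = max_abs / 127.0;
    let quantized = src
        .iter()
        .map(|x| (x / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (quantized, scale)
}

/// INT8 -> f32: multiply each code by the shared scale.
fn int8_to_f32(src: &[i8], scale: f32) -> Vec<f32> {
    src.iter().map(|&q| q as f32 * scale).collect()
}
```

Under this scheme the round-trip error per element is at most about scale / 2. For the example vector above, scale is 0.9 / 127, roughly 0.0071, which is why a tolerance of 0.01 holds.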

HnswIndex

A Hierarchical Navigable Small World index for approximate nearest neighbor search. Distances are computed internally from cached squared norms using the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b), so each comparison needs only one dot product, and results are returned sorted by L2 distance. All stored vectors are kept in AlignedVec for SIMD-friendly access.
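The cached-norm trick follows from expanding the squared distance: (a - b)·(a - b) = a·a + b·b - 2 a·b. A small sketch, with hypothetical function names and an assumed clamp at zero to absorb rounding error, shows how a distance can be obtained from one dot product once the norms are stored:

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Squared norm, computed once per vector and cached alongside it.
fn norm_sq(a: &[f32]) -> f32 {
    dot(a, a)
}

/// Squared L2 distance from cached squared norms plus a single dot product.
fn l2_sq_from_norms(a: &[f32], norm_a_sq: f32, b: &[f32], norm_b_sq: f32) -> f32 {
    // Assumption: clamp at zero, because floating-point cancellation can
    // produce a slightly negative value for nearly identical vectors.
    (norm_a_sq + norm_b_sq - 2.0 * dot(a, b)).max(0.0)
}
```

During a search the query norm is the same for every candidate, so only the dot product has to be recomputed per comparison.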

The implementation uses generation-based visited tracking (a Vec<u32> with a monotonic generation counter) to avoid per-search HashSet allocation. Visited state resets in O(1) by incrementing the generation.
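Generation-based visited tracking can be sketched in a few lines. The type and method names below are hypothetical, and the wrap-around handling is an assumption about how a long-lived counter would be kept correct:

```rust
/// Visited set that resets in O(1) by bumping a generation counter.
struct VisitedSet {
    stamps: Vec<u32>, // stamps[node] == generation means "visited in this search"
    generation: u32,
}

impl VisitedSet {
    fn new(capacity: usize) -> Self {
        // Stamps start at 0 and the generation at 1, so nothing is visited yet.
        Self { stamps: vec![0; capacity], generation: 1 }
    }

    /// Begin a new search. O(1), except once every 2^32 - 1 resets.
    fn reset(&mut self) {
        self.generation = self.generation.wrapping_add(1);
        if self.generation == 0 {
            // Counter wrapped: old stamps could collide, so clear them once.
            self.stamps.fill(0);
            self.generation = 1;
        }
    }

    /// Mark `node` as visited. Returns true if it was not visited before.
    fn insert(&mut self, node: usize) -> bool {
        if self.stamps[node] == self.generation {
            false
        } else {
            self.stamps[node] = self.generation;
            true
        }
    }

    fn contains(&self, node: usize) -> bool {
        self.stamps[node] == self.generation
    }
}
```

Compared with a fresh HashSet per query, this reuses one allocation for the lifetime of the index and replaces hashing with a single indexed load. Mutable scratch state of this kind would also be consistent with search taking &mut self.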

pub struct HnswIndex { /* private */ }

impl HnswIndex {
    /// Create a new index with the given dimensionality, connectivity `m`,
    /// and construction beam width `ef_construction`. Uses seed 42 for the RNG.
    pub fn new(dim: usize, m: usize, ef_construction: usize) -> Self;

    /// Create an index with an explicit RNG seed for deterministic builds.
    pub fn with_seed(dim: usize, m: usize, ef_construction: usize, seed: u64) -> Self;

    /// Insert a vector into the index. Returns the vector's internal index.
    pub fn insert(&mut self, vector: Vec<f32>) -> usize;

    /// Search for the `k` nearest neighbors of `query`.
    /// `ef` controls the search beam width (higher = more accurate, slower).
    /// Returns (internal_index, L2_distance) pairs sorted by distance ascending.
    pub fn search(&mut self, query: &[f32], k: usize, ef: usize) -> Vec<(usize, f32)>;

    /// Access a stored vector by its internal index.
    #[inline]
    pub fn get_vector(&self, idx: usize) -> &[f32];

    /// Number of vectors in the index.
    pub fn len(&self) -> usize;

    /// True if the index contains no vectors.
    pub fn is_empty(&self) -> bool;
}

Parameters

| Parameter | Typical Value | Effect |
|-----------|---------------|--------|
| m | 16 | Maximum connections per node per layer. Higher values improve recall at the cost of memory and insert time. |
| ef_construction | 200 | Beam width during insertion. Higher values build a better graph but slow down construction. |
| ef (search) | 50 | Beam width during search. Must be >= k. Higher values improve recall at the cost of latency. |

Example

use tala_embed::HnswIndex;

let dim = 4;
let mut index = HnswIndex::new(dim, 16, 200);

// Insert 100 vectors
for i in 0..100 {
    let v = vec![i as f32; dim];
    index.insert(v);
}

// Search for the 5 nearest neighbors of a query
let query = vec![50.0; dim];
let results = index.search(&query, 5, 50);

assert_eq!(results.len(), 5);
// Results are sorted by L2 distance ascending
for window in results.windows(2) {
    assert!(window[0].1 <= window[1].1);
}

Performance

The HNSW index delivers sub-millisecond search latency at 10K vectors with ef=50. Measured benchmarks:

| Operation | Corpus Size | Time |
|-----------|-------------|------|
| Search (top-10, ef=50) | 10K vectors, dim=384 | 139 µs |
| Semantic query (full pipeline) | 10K vectors, dim=384 | 151 µs |