Distributed Sample Sort

Sort data across machines by sampling keys, choosing splitters, redistributing records into ordered buckets, and sorting buckets locally.

Distributed sample sort sorts a dataset spread across multiple machines. It sorts a sample of the keys to choose global splitters, which divide the key space into ordered buckets, one per worker. Each machine sends every record to the worker that owns its bucket, each worker sorts the records it receives, and the sorted buckets, read in order, form the final output.

The main purpose of sampling is load balance. Good splitters make every worker receive roughly the same amount of data.

Problem

Given $n$ records distributed across $p$ workers, sort all records by key in nondecreasing order.

The output may remain distributed as $p$ sorted partitions.

Algorithm

distributed_sample_sort(workers):
    each worker samples local keys
    gather all samples
    sort samples
    choose p - 1 splitters

    broadcast splitters to all workers

    each worker partitions local records into p buckets

    all_to_all exchange buckets

    each worker sorts received bucket

    return workers in splitter order
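The steps above can be simulated in a single Python process to see how they fit together. This is a minimal sketch under stated assumptions, not a distributed implementation: each "worker" is just a list of plain keys, the gather, broadcast, and all-to-all steps are ordinary list operations, and the function name simulate_sample_sort, its samples_per_worker and seed parameters, and the use of random.sample as the sampling scheme are all illustrative choices.

```python
import random
from bisect import bisect_right


def simulate_sample_sort(keys_by_worker, samples_per_worker=8, seed=0):
    """Single-process model of the algorithm; returns p sorted partitions."""
    rng = random.Random(seed)
    p = len(keys_by_worker)

    # Each worker draws up to samples_per_worker keys at random.
    samples = []
    for local in keys_by_worker:
        samples.extend(rng.sample(local, min(samples_per_worker, len(local))))

    # Gather and sort the samples, then take p - 1 evenly spaced splitters.
    # With no samples, splitters is empty and every key goes to worker 0.
    samples.sort()
    m = len(samples)
    splitters = [samples[i * m // p] for i in range(1, p)] if samples else []

    # Broadcast is implicit here; each worker partitions its keys into p buckets.
    outgoing = []
    for local in keys_by_worker:
        buckets = [[] for _ in range(p)]
        for key in local:
            buckets[bisect_right(splitters, key)].append(key)
        outgoing.append(buckets)

    # All-to-all: worker v receives bucket v from every source worker u.
    received = [
        [key for u in range(p) for key in outgoing[u][v]] for v in range(p)
    ]

    # Each worker sorts the keys it received.
    for part in received:
        part.sort()
    return received


parts = simulate_sample_sort([[5, 1, 9], [3, 7, 2], [8, 6, 4]])
print([key for part in parts for key in part])  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Whatever splitters the random sample produces, reading the partitions in worker order yields the sorted input; the sample only affects how evenly the keys are spread across workers.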

The splitters satisfy:

$$ s_1 \le s_2 \le \cdots \le s_{p-1} $$

Worker $0$ receives the smallest key range. Worker $p - 1$ receives the largest key range.

Bucket Assignment

Each key is assigned to a bucket by binary search over the sorted splitters, so assignment costs $O(\log p)$ per key.

bucket_id(key, splitters):
    return upper_bound(splitters, key)

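This gives a value from $0$ to $p - 1$. Because upper_bound returns the number of splitters that are less than or equal to the key, bucket $0$ holds keys $k < s_1$, bucket $i$ holds keys with $s_i \le k < s_{i+1}$, and bucket $p - 1$ holds keys $k \ge s_{p-1}$.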
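In Python, bisect.bisect_right from the standard library behaves like upper_bound: it returns the number of splitters less than or equal to the key. The small example below uses made-up splitters to show that a key equal to a splitter goes to the bucket on that splitter's right.

```python
from bisect import bisect_right


def bucket_id(key, splitters):
    """Bucket index for key: the number of splitters <= key (upper_bound)."""
    return bisect_right(splitters, key)


splitters = [10, 20, 30]  # p - 1 = 3 splitters, so p = 4 buckets
for key in [5, 10, 15, 20, 35]:
    print(key, "->", bucket_id(key, splitters))
# prints, one per line: 5 -> 0, 10 -> 1, 15 -> 1, 20 -> 2, 35 -> 3
```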

Communication

After local partitioning, workers perform an all-to-all exchange, in which every worker sends its bucket $v$ to worker $v$.

for each source worker u:
    for each target worker v:
        send bucket[u][v] to worker v

Each target worker receives records from all workers for one key interval.
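One way to picture the exchange: if bucket[u][v] is laid out as a $p \times p$ matrix of lists, the all-to-all step transposes it, so the result for worker $v$ gathers column $v$. The in-memory sketch below models only that data movement; the function name all_to_all is hypothetical, and a real cluster would typically use a collective such as MPI_Alltoallv or an equivalent shuffle service instead.

```python
def all_to_all(bucket):
    """Simulate the exchange: bucket[u][v] is what worker u sends to worker v.

    Returns inbox, where inbox[v] concatenates bucket[0][v], ..., bucket[p-1][v].
    """
    p = len(bucket)
    inbox = [[] for _ in range(p)]
    for u in range(p):
        for v in range(p):
            inbox[v].extend(bucket[u][v])
    return inbox


# Worker 0 holds [3, 8] and worker 1 holds [9, 1]; splitter 5 gives two buckets.
bucket = [
    [[3], [8]],  # worker 0 -> worker 0, worker 0 -> worker 1
    [[1], [9]],  # worker 1 -> worker 0, worker 1 -> worker 1
]
print(all_to_all(bucket))  # [[3, 1], [8, 9]]
```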

Complexity

| Measure | Cost |
|---|---|
| Local sampling | $O(n/p)$ per worker |
| Local partitioning | $O((n/p)\log p)$ per worker |
| Communication volume | $O(n)$ records in total |
| Local sorting | expected $O((n/p)\log(n/p))$ per worker |
| Output | $p$ sorted partitions |

The wall-clock time is often dominated by the network exchange and by skew: the job finishes only when the most heavily loaded worker finishes.
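Because completion time tracks the most loaded worker, one simple way to quantify skew is the ratio of the largest partition to the average partition size $n/p$, where $1.0$ means perfect balance. The helper below is a hypothetical illustration of that metric, not a standard API.

```python
def imbalance(partition_sizes):
    """Largest partition size divided by the mean size n / p (1.0 = balanced)."""
    n = sum(partition_sizes)
    p = len(partition_sizes)
    if n == 0:
        return 1.0  # no data: treat as balanced
    return max(partition_sizes) / (n / p)


print(imbalance([25, 25, 25, 25]))  # 1.0
print(imbalance([10, 10, 10, 70]))  # 2.8
```

In the second case one worker holds 2.8 times its fair share, so, under the simple assumption that cost grows with partition size, its local sort and receive phase would take roughly that much longer than a perfectly balanced run.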

Correctness

Splitters divide the global key space into ordered intervals, and each record is sent to exactly one worker, so no record is lost or duplicated. Because worker $i$ receives only keys in the $i$-th interval, every key held by worker $i$ is less than or equal to every key held by worker $j$ whenever $i < j$. Each worker then sorts its received records. Therefore, the workers hold sorted partitions in global key order.

Reading worker outputs from $0$ to $p - 1$ gives the complete sorted dataset.
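The argument above suggests three properties that can be checked on any run: each partition is sorted, the last key of each non-empty partition is at most the first key of the next non-empty one, and the partitions together contain exactly the input keys. A minimal checker, assuming records are plain comparable keys and using a hypothetical function name:

```python
from collections import Counter


def is_valid_sample_sort_output(input_by_worker, partitions):
    """Check the correctness conditions for p output partitions."""
    # Each worker's partition must be sorted.
    for part in partitions:
        if any(a > b for a, b in zip(part, part[1:])):
            return False
    # Partitions must be ordered; empty partitions are skipped at boundaries.
    nonempty = [part for part in partitions if part]
    for left, right in zip(nonempty, nonempty[1:]):
        if left[-1] > right[0]:
            return False
    # No record lost or duplicated: same multiset of keys in and out.
    inputs = Counter(k for local in input_by_worker for k in local)
    outputs = Counter(k for part in partitions for k in part)
    return inputs == outputs


print(is_valid_sample_sort_output([[3, 1], [2, 1]], [[1, 1], [2, 3]]))  # True
print(is_valid_sample_sort_output([[3, 1], [2, 1]], [[1, 2], [1, 3]]))  # False
```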

Practical Considerations

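  • Oversampling (drawing more samples than the $p - 1$ splitters strictly require) improves load balance, because the splitters then approximate the true quantiles more closely.
  • Skewed or duplicate-heavy data may overload one worker: with upper_bound assignment, all copies of one key land in the same bucket.
  • All-to-all exchange can stress the network.
  • Compression may reduce transfer cost.
  • Local sort can use radix sort, quicksort, merge sort, or external sort when a received partition does not fit in memory.
  • Output is usually left partitioned rather than physically concatenated.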
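The duplicate-heavy case is easy to reproduce: with upper_bound assignment, every copy of a repeated key lands in one bucket no matter how many splitters equal it. One possible mitigation, shown here as an assumption rather than part of the method described above, is to sort on a composite key such as (key, worker_id, position), which makes every record distinct so splitters can fall between copies of the same key. For determinism the example treats every key as a sample, and the helper names bucket_sizes and even_splitters are hypothetical.

```python
from bisect import bisect_right


def bucket_sizes(items, splitters, p):
    """Count how many items upper_bound assignment sends to each bucket."""
    sizes = [0] * p
    for item in items:
        sizes[bisect_right(splitters, item)] += 1
    return sizes


def even_splitters(sorted_items, p):
    """Pick p - 1 splitters at evenly spaced ranks of a sorted sample."""
    m = len(sorted_items)
    return [sorted_items[i * m // p] for i in range(1, p)]


p = 3
keys_by_worker = [[7] * 8 + [1], [7] * 8 + [2], [7] * 8 + [9]]

# Plain keys: both splitters equal 7, so all 24 copies of 7 share one bucket.
keys = sorted(k for local in keys_by_worker for k in local)
print(bucket_sizes(keys, even_splitters(keys, p), p))  # [2, 0, 25]

# Composite keys (key, worker_id, position) are distinct, so buckets balance.
tagged = sorted(
    (k, w, i) for w, local in enumerate(keys_by_worker) for i, k in enumerate(local)
)
print(bucket_sizes(tagged, even_splitters(tagged, p), p))  # [9, 9, 9]
```

Sorting by the composite key still orders records by their original key first, so the concatenated output remains sorted by key; the extra fields only decide how equal keys are split across workers.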

When to Use

Use distributed sample sort when:

  • data is too large for one machine
  • many workers are available
  • sorted distributed partitions are acceptable
  • sampling can approximate the key distribution

Avoid it when communication cost dominates or the key distribution causes severe bucket skew.

Implementation Sketch

local_sample(records, sample_count):
    return evenly_spaced_sample(records, sample_count)

choose_splitters(samples, p):
    sort samples
    splitters = []

    for i from 1 to p - 1:
        splitters.append(samples[floor(i * length(samples) / p)])

    return splitters

worker_sort(local_records, splitters):
    p = length(splitters) + 1
    buckets = array of p empty lists

    for record in local_records:
        b = upper_bound(splitters, key(record))
        buckets[b].append(record)

    for b from 0 to p - 1:
        send buckets[b] to worker b

    received = receive buckets from all workers
    sort received by key

    return received
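The sketch leaves evenly_spaced_sample undefined. One plausible reading, assumed here, is that it picks sample_count records at equally spaced positions of the local list. The Python version below implements that hypothetical helper and makes the floor division in choose_splitters explicit.

```python
def evenly_spaced_sample(records, sample_count):
    """Hypothetical helper: take sample_count records at equally spaced positions."""
    n = len(records)
    if n == 0 or sample_count <= 0:
        return []
    sample_count = min(sample_count, n)
    return [records[i * n // sample_count] for i in range(sample_count)]


def choose_splitters(samples, p):
    """Pick p - 1 splitters at evenly spaced ranks; assumes samples is non-empty."""
    samples = sorted(samples)
    m = len(samples)
    return [samples[i * m // p] for i in range(1, p)]


local = [42, 7, 19, 88, 3, 56, 71, 25]
print(evenly_spaced_sample(local, 4))  # [42, 19, 3, 71]
print(choose_splitters(local, 4))      # [19, 42, 71]
```

Sampling by position from an unsorted list is cheap, but if the local data happens to be ordered in a patterned way the sample can be biased; random sampling is a common alternative.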

Simplified Python Model

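from bisect import bisect_right
from operator import itemgetter


def distributed_sample_sort(records_by_worker, splitters):
    # Single-process model: records are tuples keyed by record[0], and the
    # p - 1 splitters are supplied by the caller (no sampling step here).
    p = len(records_by_worker)
    assert len(splitters) == p - 1, "expected exactly p - 1 splitters"

    # inbox[b] collects every record whose key falls in bucket b; filling it
    # directly simulates local partitioning plus the all-to-all exchange.
    inbox = [[] for _ in range(p)]

    for local_records in records_by_worker:
        for record in local_records:
            b = bisect_right(splitters, record[0])  # upper_bound
            inbox[b].append(record)

    # Each worker sorts the records it received.
    for bucket in inbox:
        bucket.sort(key=itemgetter(0))

    return inbox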