ウィキソートを使用する

ウィキソート (WikiSort) は、Pok-Son Kim と Arne Kutzner の比率に基づく安定インプレースマージを土台にした ブロックマージソートblock merge sort)である。

ボトムアップのマージソートの枠組みのなかで、大きな区間同士を併合するときに配列内の小さなバッファと回転(rotate)を組み合わせ、追加記憶領域を O(1)(固定サイズのキャッシュ程度)に抑えつつ安定に整列する。

  1. 初期整列: 長さ 4〜8 程度の小さな区間を、安定なソートネットワーク(または挿入ソート)で整える。
  2. レベルごとのマージ: 隣接する整列済み部分列 (A, B) のペアを、区間長が配列全体になるまで段階的に倍化していく。
  3. キャッシュ併合: 部分列がキャッシュ(例: 512 要素)に収まるレベルでは、通常のマージと同様にキャッシュへ退避して併合する。
  4. ブロック併合: それより大きいレベルでは、区間長の平方根程度のブロックに分割し、内部バッファから取り出した一意な値で A ブロックにタグを付け、B ブロック群のなかへ回転させながら位置を決める。その後、各 A ブロックと続く B 値をマージする。
  5. バッファの復元: 一時的に退避した内部バッファを、挿入ソートと再配置で元の位置へ戻す。
procedure wiki_sort(A)
  if length(A) < 4 then
    insertion_sort(A)
    return
  sort_runs_of_size_4_to_8(A)
  level = 4
  while level < length(A)
    for each adjacent pair (A_part, B_part) of sorted runs of length level
      if level fits in fixed_cache then
        merge_with_cache(A_part, B_part, fixed_cache)
      else
        block_merge_in_place(A_part, B_part, fixed_cache)
    level = next_merge_level(level)

最悪時間計算量は O(n log n)。安定ソートであり、等しいキーの相対順序を保つ。

マージソートのように O(n) の補助配列を常に要求しない点、クイックソートのように最悪 O(n²) に落ちない点が、インプレースかつ安定な整列を求める場面での利点になる。

実装は回転とブロック操作が多く、マージソート単体よりコード量は増える。

デモは、ウィキソートの小さなラン整列とボトムアップ併合の流れを追いやすくするため、キャッシュに収まるレベルでのマージ(マージソートと同型の併合)までを可視化している。

配列が大きくなりキャッシュを超えるレベルでは、本番のウィキソートはブロック分割・回転・内部バッファによるインプレース併合へ切り替わる。

マージソートと比べ追加配列を抑えられる一方、クイックソートと比べ安定である点が、ライブラリ実装の stable_sort に近い設計思想と言える。

計算時間量および空間計算量を計測する

Size Average time Maximum time Average memory Maximum memory
256 0.000008 0.000505 1663 1668
512 0.000017 0.000277 1666 1672
1024 0.000043 0.003383 1682 1688
2048 0.000085 0.000585 1706 1712
4096 0.000167 0.001429 1755 1760
8192 0.000347 0.000722 1850 1856
16384 0.000795 0.004058 2046 2052
32768 0.001752 0.008141 2308 2312
65536 0.003497 0.007235 3314 3316
131072 0.007856 0.041720 5103 5108
262144 0.016452 0.045890 8687 8692
計測に使用したコードを表示する

set -euo pipefail

WORKDIR="$(mktemp -d)"
trap 'rm -rf "$WORKDIR"' EXIT

cat > "$WORKDIR/Dockerfile" <<'EOF'
FROM rust:1.95.0

WORKDIR /app

RUN mkdir -p src

RUN cat > Cargo.toml <<'CARGO'
[package]
name = "rust-benchmark"
version = "0.1.0"
edition = "2021"

[profile.release]
lto = true
codegen-units = 1
panic = "abort"
CARGO

RUN cat > src/main.rs <<'RUST'
use std::{
    env,
    process::Command,
    time::{Duration, Instant},
};
const MIN_POWER: u32 = 8;
const MAX_POWER: u32 = 18;
const RUNS: usize = 8192;
const CACHE_SIZE: usize = 512;

#[derive(Clone, Copy)]
struct Range {
    start: usize,
    end: usize,
}

impl Range {
    fn new(start: usize, end: usize) -> Self {
        Self { start, end }
    }

    fn len(self) -> usize {
        self.end - self.start
    }
}

struct WikiIterator {
    size: usize,
    power_of_two: usize,
    numerator: usize,
    decimal: usize,
    denominator: usize,
    decimal_step: usize,
    numerator_step: usize,
}

impl WikiIterator {
    fn new(size: usize, min_level: usize) -> Self {
        let power_of_two = floor_power_of_two(size);
        let denominator = power_of_two / min_level;
        Self {
            size,
            power_of_two,
            numerator: 0,
            decimal: 0,
            denominator,
            decimal_step: size / denominator,
            numerator_step: size % denominator,
        }
    }

    fn begin(&mut self) {
        self.numerator = 0;
        self.decimal = 0;
    }

    fn next_range(&mut self) -> Range {
        let start = self.decimal;
        self.decimal += self.decimal_step;
        self.numerator += self.numerator_step;
        if self.numerator >= self.denominator {
            self.numerator -= self.denominator;
            self.decimal += 1;
        }
        Range::new(start, self.decimal)
    }

    fn finished(&self) -> bool {
        self.decimal >= self.size
    }

    fn next_level(&mut self) -> bool {
        self.decimal_step += self.decimal_step;
        self.numerator_step += self.numerator_step;
        if self.numerator_step >= self.denominator {
            self.numerator_step -= self.denominator;
            self.decimal_step += 1;
        }
        self.decimal_step < self.size
    }

    fn length(&self) -> usize {
        self.decimal_step
    }
}

fn floor_power_of_two(value: usize) -> usize {
    let mut x = value;
    x |= x >> 1;
    x |= x >> 2;
    x |= x >> 4;
    x |= x >> 8;
    x |= x >> 16;
    #[cfg(target_pointer_width = "64")]
    {
        x |= x >> 32;
    }
    x - (x >> 1)
}

fn wiki_insertion_sort(a: &mut [usize], range: Range) {
    for i in range.start + 1..range.end {
        let temp = a[i];
        let mut j = i;
        while j > range.start && temp < a[j - 1] {
            a[j] = a[j - 1];
            j -= 1;
        }
        a[j] = temp;
    }
}

fn reverse(a: &mut [usize], range: Range) {
    let len = range.len();
    for index in (0..len / 2).rev() {
        a.swap(range.start + index, range.end - index - 1);
    }
}

fn rotate(a: &mut [usize], amount: usize, range: Range, cache: &mut [usize], cache_size: usize) {
    if range.len() == 0 {
        return;
    }
    let split = range.start + amount;
    let range1 = Range::new(range.start, split);
    let range2 = Range::new(split, range.end);
    if range1.len() <= range2.len() {
        if range1.len() <= cache_size {
            cache[..range1.len()].copy_from_slice(&a[range1.start..range1.end]);
            a.copy_within(range2.start..range2.end, range1.start);
            a[range1.start + range2.len()..range1.start + range2.len() + range1.len()]
                .copy_from_slice(&cache[..range1.len()]);
            return;
        }
    } else if range2.len() <= cache_size {
        cache[..range2.len()].copy_from_slice(&a[range2.start..range2.end]);
        a.copy_within(range1.start..range1.end, range2.end - range1.len());
        a[range1.start..range1.start + range2.len()].copy_from_slice(&cache[..range2.len()]);
        return;
    }
    reverse(a, range1);
    reverse(a, range2);
    reverse(a, range);
}

fn merge_into(from: &[usize], a: Range, b: Range, into: &mut [usize]) {
    let mut a_index = a.start;
    let mut b_index = b.start;
    let mut insert = 0;
    loop {
        if from[b_index] >= from[a_index] {
            into[insert] = from[a_index];
            a_index += 1;
            insert += 1;
            if a_index == a.end {
                into[insert..insert + b.end - b_index].copy_from_slice(&from[b_index..b.end]);
                break;
            }
        } else {
            into[insert] = from[b_index];
            b_index += 1;
            insert += 1;
            if b_index == b.end {
                into[insert..insert + a.end - a_index].copy_from_slice(&from[a_index..a.end]);
                break;
            }
        }
    }
}

fn merge_external(a: &mut [usize], a_range: Range, b: Range, cache: &mut [usize]) {
    cache[..a_range.len()].copy_from_slice(&a[a_range.start..a_range.end]);
    let mut a_index = 0;
    let mut b_index = b.start;
    let mut insert = a_range.start;
    let a_last = a_range.len();
    let b_last = b.end;
    if b.len() > 0 && a_range.len() > 0 {
        loop {
            if a[b_index] >= cache[a_index] {
                a[insert] = cache[a_index];
                a_index += 1;
                insert += 1;
                if a_index == a_last {
                    break;
                }
            } else {
                a[insert] = a[b_index];
                b_index += 1;
                insert += 1;
                if b_index == b_last {
                    break;
                }
            }
        }
    }
    a[insert..insert + a_last - a_index].copy_from_slice(&cache[a_index..a_last]);
}

fn merge_pair(
    a: &mut [usize],
    a_range: Range,
    b: Range,
    cache: &mut [usize],
    cache_size: usize,
) {
    if a[b.end - 1] < a[a_range.start] {
        rotate(
            a,
            a_range.len(),
            Range::new(a_range.start, b.end),
            cache,
            cache_size,
        );
    } else if a[b.start] < a[a_range.end - 1] {
        if a_range.len() + b.len() <= cache_size {
            cache[..a_range.len()].copy_from_slice(&a[a_range.start..a_range.end]);
            merge_external(a, a_range, b, cache);
        } else {
            let mut merged = Vec::with_capacity(a_range.len() + b.len());
            let (mut i, mut j) = (a_range.start, b.start);
            while i < a_range.end && j < b.end {
                if a[i] <= a[j] {
                    merged.push(a[i]);
                    i += 1;
                } else {
                    merged.push(a[j]);
                    j += 1;
                }
            }
            merged.extend_from_slice(&a[i..a_range.end]);
            merged.extend_from_slice(&a[j..b.end]);
            a[a_range.start..b.end].copy_from_slice(&merged);
        }
    }
}

fn wiki_sort(a: &mut [usize]) {
    let size = a.len();
    let mut cache = [0usize; CACHE_SIZE];
    let cache_size = CACHE_SIZE;

    if size < 4 {
        if size == 3 {
            if a[1] < a[0] {
                a.swap(0, 1);
            }
            if a[2] < a[1] {
                a.swap(1, 2);
                if a[1] < a[0] {
                    a.swap(0, 1);
                }
            }
        } else if size == 2 && a[1] < a[0] {
            a.swap(0, 1);
        }
        return;
    }

    let mut iterator = WikiIterator::new(size, 4);
    iterator.begin();
    while !iterator.finished() {
        let range = iterator.next_range();
        wiki_insertion_sort(a, range);
    }
    if size < 8 {
        return;
    }

    loop {
        if iterator.length() < cache_size {
            if (iterator.length() + 1) * 4 <= cache_size && iterator.length() * 4 <= size {
                iterator.begin();
                while !iterator.finished() {
                    let a1 = iterator.next_range();
                    let b1 = iterator.next_range();
                    let a2 = iterator.next_range();
                    let b2 = iterator.next_range();
                    let mut merged1_len = 0usize;
                    let mut merged2_len = 0usize;
                    if a[b1.end - 1] < a[a1.start] {
                        cache[b1.len()..b1.len() + a1.len()].copy_from_slice(&a[a1.start..a1.end]);
                        cache[..b1.len()].copy_from_slice(&a[b1.start..b1.end]);
                        merged1_len = a1.len() + b1.len();
                    } else if a[b1.start] < a[a1.end - 1] {
                        merge_into(a, a1, b1, &mut cache);
                        merged1_len = a1.len() + b1.len();
                    } else if !(a[b2.start] < a[a2.end - 1]) && !(a[a2.start] < a[b1.end - 1]) {
                        continue;
                    } else {
                        cache[..a1.len()].copy_from_slice(&a[a1.start..a1.end]);
                        cache[a1.len()..a1.len() + b1.len()].copy_from_slice(&a[b1.start..b1.end]);
                        merged1_len = a1.len() + b1.len();
                    }
                    let a1 = Range::new(a1.start, b1.end);
                    if a[b2.end - 1] < a[a2.start] {
                        cache[merged1_len + b2.len()..merged1_len + b2.len() + a2.len()]
                            .copy_from_slice(&a[a2.start..a2.end]);
                        cache[merged1_len..merged1_len + b2.len()].copy_from_slice(&a[b2.start..b2.end]);
                        merged2_len = a2.len() + b2.len();
                    } else if a[b2.start] < a[a2.end - 1] {
                        merge_into(a, a2, b2, &mut cache[merged1_len..]);
                        merged2_len = a2.len() + b2.len();
                    } else {
                        cache[merged1_len..merged1_len + a2.len()].copy_from_slice(&a[a2.start..a2.end]);
                        cache[merged1_len + a2.len()..merged1_len + a2.len() + b2.len()]
                            .copy_from_slice(&a[b2.start..b2.end]);
                        merged2_len = a2.len() + b2.len();
                    }
                    let a2 = Range::new(a2.start, b2.end);
                    let a3 = Range::new(0, merged1_len);
                    let b3 = Range::new(merged1_len, merged1_len + merged2_len);
                    if cache[b3.end - 1] < cache[a3.start] {
                        a[a1.start + merged2_len..a1.start + merged2_len + merged1_len]
                            .copy_from_slice(&cache[a3.start..a3.end]);
                        a[a1.start..a1.start + merged2_len].copy_from_slice(&cache[b3.start..b3.end]);
                    } else if cache[b3.start] < cache[a3.end - 1] {
                        merge_into(&cache, a3, b3, &mut a[a1.start..a1.start + merged1_len + merged2_len]);
                    } else {
                        a[a1.start..a1.start + merged1_len].copy_from_slice(&cache[a3.start..a3.end]);
                        a[a1.start + merged1_len..a1.start + merged1_len + merged2_len]
                            .copy_from_slice(&cache[b3.start..b3.end]);
                    }
                }
                iterator.next_level();
            } else {
                iterator.begin();
                while !iterator.finished() {
                    let a_range = iterator.next_range();
                    let b = iterator.next_range();
                    merge_pair(a, a_range, b, &mut cache, cache_size);
                }
            }
        } else {
            iterator.begin();
            while !iterator.finished() {
                let a_range = iterator.next_range();
                let b = iterator.next_range();
                merge_pair(a, a_range, b, &mut cache, cache_size);
            }
        }
        if !iterator.next_level() {
            break;
        }
    }
}

fn benchmark_sort(array: &mut [usize]) {

    wiki_sort(array);

}

fn shuffled(size: usize, seed: u64) -> Vec<usize> {
    let mut v: Vec<usize> = (1..=size).collect();

    let mut state = seed;

    for i in (1..size).rev() {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;

        let j = (state as usize) % (i + 1);

        v.swap(i, j);
    }

    v
}

fn memory_usage_kb() -> usize {
    let contents = std::fs::read_to_string("/proc/self/status")
        .unwrap_or_default();

    for line in contents.lines() {
        if let Some(rest) = line.strip_prefix("VmHWM:") {
            let kb = rest
                .split_whitespace()
                .next()
                .unwrap_or("0")
                .parse::<usize>()
                .unwrap_or(0);

            return kb;
        }
    }

    0
}

fn micros(d: Duration) -> u128 {
    d.as_micros()
}

fn run_once(size: usize, seed: usize) -> (u128, usize) {
    let expected: Vec<usize> = (1..=size).collect();
    let mut array = shuffled(size, seed as u64);

    let start = Instant::now();

    benchmark_sort(&mut array);

    let elapsed = start.elapsed();

    if array != expected {
        panic!(
            "sort failed with seed {} for size {}",
            seed,
            size
        );
    }

    (micros(elapsed), memory_usage_kb())
}

fn run_child(args: &[String]) {
    let size = args[2].parse::<usize>().expect("invalid size");
    let seed = args[3].parse::<usize>().expect("invalid seed");
    let (elapsed_us, mem) = run_once(size, seed);
    println!("{} {}", elapsed_us, mem);
}

fn main() {
    let args: Vec<String> = env::args().collect();
    if args.get(1).is_some_and(|arg| arg == "--run-once") {
        run_child(&args);
        return;
    }

    println!(
        "| {:>10} | {:>15} | {:>15} | {:>15} | {:>15} |",
        "Size",
        "Average time",
        "Maximum time",
        "Average memory",
        "Maximum memory"
    );

    println!(
        "|{:-<11}:|{:-<16}:|{:-<16}:|{:-<16}:|{:-<16}:|",
        "",
        "",
        "",
        "",
        ""
    );

    for power in MIN_POWER..=MAX_POWER {
        let size = 1usize << power;

        let mut total_time: u128 = 0;
        let mut max_time: u128 = 0;

        let mut total_mem: usize = 0;
        let mut max_mem: usize = 0;

        for seed in 1..=RUNS {
            let output = Command::new(env::current_exe().expect("failed to find current executable"))
                .arg("--run-once")
                .arg(size.to_string())
                .arg(seed.to_string())
                .output()
                .expect("failed to run benchmark child process");

            if !output.status.success() {
                panic!(
                    "benchmark child process failed: {}",
                    String::from_utf8_lossy(&output.stderr)
                );
            }

            let stdout = String::from_utf8(output.stdout)
                .expect("child process returned non-UTF-8 output");
            let mut fields = stdout.split_whitespace();
            let elapsed_us = fields
                .next()
                .expect("missing elapsed time")
                .parse::<u128>()
                .expect("invalid elapsed time");
            let mem = fields
                .next()
                .expect("missing memory usage")
                .parse::<usize>()
                .expect("invalid memory usage");

            total_time += elapsed_us;

            if elapsed_us > max_time {
                max_time = elapsed_us;
            }

            total_mem += mem;

            if mem > max_mem {
                max_mem = mem;
            }
        }

        let avg_time = total_time / RUNS as u128;
        let avg_mem = total_mem / RUNS;

        println!(
            "| {:>10} | {:>15} | {:>15} | {:>15} | {:>15} |",
            size,
            format!("{}.{:06}", avg_time / 1_000_000, avg_time % 1_000_000),
            format!("{}.{:06}", max_time / 1_000_000, max_time % 1_000_000),
            avg_mem,
            max_mem
        );
    }
}
RUST

RUN cargo build --release

CMD ["./target/release/rust-benchmark"]
EOF

docker build -t rust-benchmark "$WORKDIR"
docker run --rm --init rust-benchmark