配列のシャッフルと落とし穴

風邪で寝てるあいだ、頭痛な頭でなぜか配列のシャッフルするアルゴリズムのことを思い出し、乱数と使い方の落とし穴について考えてました。

シャッフルは、たとえば輪読の順番を決めるときにつかいます。名簿を最初あいうえお順に並べて配列に入れておいてシャッフルすることで公平に順番を決めようという感じです。

シャッフル時に配列の要素を乱数をベースに入れ替えるのですが、乱数の使い方を誤ると最初の順番に依存して順番が偏ってしまうことが、落とし穴です。

結論から言うと配列シャッフルのアルゴリズムJavaで書くと以下のようになります。

class Shuffle {
  private static final String names[] = {
    "taro",
    "jiro",
    "saburo",
    "shiro",
    "goro",
  };

  public static void main(String[] args) {
    shuffle(names);
    for (String name : names) System.out.println(name);
  }

  private static <T> void shuffle(T[] array) {
    for (int i = 0; i < array.length; i++) {
      int dst = (int) Math.floor(Math.random() * (i + 1));
      swap(array, i, dst);
    }
  }

  private static <T> void swap(T[] array, int i, int j) {
    T tmp = array[i];
    array[i] = array[j];
    array[j] = tmp;
  }
}

メインのコードを切り出すと以下のとおりです:

  // GOOD
  private static <T> void shuffle(T[] array) {
    for (int i = 0; i < array.length; i++) {
      int dst = (int) Math.floor(Math.random() * (i + 1));
      swap(array, i, dst);
    }
  }

配列を頭からたどり、乱数で指定した位置と入れ替えていきます。
公平になるようシャッフルするキーポイントは、「乱数の範囲も広げていく」ことです。

良く間違える例としては以下のようなものがあります:

  // BAD1
  private static <T> void shuffle(T[] array) {
    for (int i = 0; i < array.length; i++) {
      int dst = (int) Math.floor(Math.random() * (array.length + 1));
      swap(array, i, dst);
    }
  }
  // BAD2
  private static <T> void shuffle(T[] array) {
    for (int i = 0; i < array.length; i++) {
      int src = (int) Math.floor(Math.random() * (array.length + 1));
      int dst = (int) Math.floor(Math.random() * (array.length + 1));
      swap(array, src, dst);
    }
  }

なぜ上の二つがいけないか。

アルゴリズム中で「乱数でどれかになる」のを「全パターンともできる」と見ることです。たとえば[a,b,c]で、cと3つのどれかと入れ替えると[c, b, a]、[a, c, b]、[a, b, c]のどれかになるのですが、これが全部生成されるものとみなします。つまり、0〜2の間で乱数を取ってそれを入れ替える場合、3パターンできます。繰り返しごとにこの各パターンからそれぞれさらに分裂したパターンのセット増えていきます。その中のパターンの重複数を全最終的に出てきたパターンで割った数が、シャッフルでそのパターンになる確率になります。

たとえばサイズが4の場合、このパターン分裂は以下のようになります

  • GOOD: 1 * 2 * 3 * 4
  • BAD1: 4 * 4 * 4 * 4
  • BAD2: 4*4 * 4*4 * 4*4 * 4*4

組み合わせの数はn!であり、GOODはちょうど全組み合わせが1つづつ出るようになっています。一方BAD1やBAD2はどれかが偏って出るようになってしまいます(割り切れないため、パターンごとに偏りがあるのは明確)。

つまり、最初の配列での位置が順番に影響するというわけです。一様な乱数を使ってるのに(からこそ)偏ってしまいます。

「シャッフル、乱数」等でググると悪い例で記述されたコードがいっぱい出てきます。とくにVBJavaScriptで。単に適当に入れ替えたいだけで、偏りがあってもよい状況ならいいのですが、ゲームなどではよろしくないと思います。

シャッフル分布確率の例: 4要素["a", "b", "c", "d"]の場合

good:
{"a"=>[0.25, 0.25, 0.25, 0.25],
 "b"=>[0.25, 0.25, 0.25, 0.25],
 "c"=>[0.25, 0.25, 0.25, 0.25],
 "d"=>[0.25, 0.25, 0.25, 0.25]}

bad1:
{"a"=>[0.25, 0.25, 0.25, 0.25],
 "b"=>[0.29296875, 0.22265625, 0.234375, 0.25],
 "c"=>[0.24609375, 0.28125, 0.22265625, 0.25],
 "d"=>[0.2109375, 0.24609375, 0.29296875, 0.25]}

bad2:
{"a"=>[0.296875, 0.234375, 0.234375, 0.234375],
 "b"=>[0.234375, 0.296875, 0.234375, 0.234375],
 "c"=>[0.234375, 0.234375, 0.296875, 0.234375],
 "d"=>[0.234375, 0.234375, 0.234375, 0.296875]}

bad2でのシャッフルの場合"a"が0番目に来る確率が29.6875%という意味です。

確率の計算コード

def shuffle_good(array)
  all = [array]
  (0...array.size).each do |i|
    new_all = []
    all.each do |prev|
      (0...(i+1)).each do |dst|
        current = prev.clone
        current[dst], current[i] = current[i], current[dst]
        new_all << current
      end
    end
    all = new_all
  end
  all.sort
end

def shuffle_bad1(array)
  all = [array]
  (0...array.size).each do |i|
    new_all = []
    all.each do |prev|
      (0...array.size).each do |dst|
        current = prev.clone
        current[dst], current[i] = current[i], current[dst]
        new_all << current
      end
    end
    all = new_all
  end
  all.sort
end

def shuffle_bad2(array)
  all = [array]
  (0...array.size).each do |i|
    new_all = []
    all.each do |prev|
      (0...(array.size)).each do |src|
        (0...(array.size)).each do |dst|
          current = prev.clone
          current[dst], current[src] = current[src], current[dst]
          new_all << current
        end
      end
    end
    all = new_all
  end
  all.sort
end

def dist(all)
  map = {}
  all.each do |pattern|
    pattern.each_index do |i|
      map[pattern[i]] ||= [0] * pattern.size
      map[pattern[i]][i] += 1
    end
  end
  map.each_value do |counts|
    counts.each_index do |i|
      counts[i] /= all.size.to_f
    end
  end
  map
end

names = [
"a",
"b",
"c",
"d",
]

require "pp"
all_good = shuffle_good(names)
puts "good: "
pp dist(all_good)
puts
all_bad1 = shuffle_bad1(names)
puts "bad1: "
pp dist(all_bad1)
puts
all_bad2 = shuffle_bad2(names)
puts "bad2: "
pp dist(all_bad2)
puts