Develop with pleasure!

福岡でCloudとかBlockchainとか。

数論変換(NTT)

最近の主要なゼロ知識証明システムで、多数の多項式の演算が必要になる。多項式の加算であれば、単に同じ次数の係数同士を加算するだけなので、必要な演算回数は係数の個数分。n-1次の多項式であればO(n)の加算で済む。一方、多項式の乗算の場合、その結果の次数が同じ項をまとめる必要があり、単純に計算するとO(n2)の乗算と同程度の加算が必要になる。

たとえば、2つの3次多項式を

  •  {p(x) = 2x^{3} + 3x^{2} + x + 5}
  •  {q(x) = x^{3} + 4x^{2} + 2x + 1}

とした場合、その積は、

 {f(x) = 2x^{6} + 11x^{5} + 17x^{4} + 17x^{3} + 25x^{2} + 11x + 5}

となる。

畳み込み演算

2つの多項式の積 {f(x) = p(x) \cdot q(x)}の各係数は、各多項式の係数列の畳み込みで求めることができる。具体的には、f(x)のk次の係数 {r_k}は、

 {\displaystyle r_k = \sum_{i=0}^{k} p_i \cdot q_{k - i}}

で計算できる( {p_i, q_i}が存在しない場合は0とみなす)。

たとえば、上記の2つの多項式の係数列は、低次数項順に並べ替えると、

  •  {p = \lbrack 5, 1, 3, 2 \rbrack}
  •  {q = \lbrack 1, 2, 4, 1 \rbrack}

で、各係数は、

  •  {r_{0} = 5 \times 1 = 5}
  •  {r_{1} = 5 \times2 + 1 \times 1 = 10 + 1 = 11}
  •  {r_{2} = 5 \times 4 + 1 \times 2 + 3 \times 1 = 20 + 2 + 3 = 25}
  •  {r_{3} = 5 \times 1 + 1 \times 4 + 3 \times 2 + 2 \times 1 = 5 + 4 + 6 + 2 = 17}
  •  {r_{4} = 1 \times 1 + 3 \times 4 + 2 \times 2 = 1 + 12 + 4 = 17}
  •  {r_{5} = 3 \times 1 + 2 \times 4 = 3 + 8 = 11}
  •  {r_{6} = 2 \times 1 = 2}

n-1次の多項式の場合、この単純な畳込み演算の計算量はO(n2)。

離散フーリエ変換

上記のような畳み込み演算のように多項式の各係数を直接使って計算する以外に、任意のx値に対して評価値を計算しその積を使って計算する方法もある。単純な方法だと、n-1次の多項式の場合、n個の任意のx値を選択し、 {(x_i, f(x_i))}を求めてラグランジュ補間を使ってf(x)を求めるとか。

演算を高速にする方法の1つが、高速フーリエ変換(FFT)を用いる方法。FFTは離散フーリエ変換(DFT)を効率的に計算するアルゴリズムで、DFTを用いた方式では、評価点として1のn乗根 {\omega_n = e^{2\pi i / n}}(iは虚数単位)という特殊な複素数のべき乗を使用する。

  1. p(x)についてDFTにより評価値を導出( {\displaystyle P_k = \sum_{j=0}^{n-1} p_j \cdot \omega^{jk}}
  2. q(x)についてDFTにより評価値を導出( {\displaystyle Q_k = \sum_{j=0}^{n-1} q_j \cdot \omega^{jk}}
  3. 点ごとに評価値を乗算( {\displaystyle F_k = P_k \cdot Q_k}
  4. 逆DFTでf(x)の係数を算出( {\displaystyle f_j = \frac{1}{n} \sum_{k=0}^{n-1} F_k \cdot \omega^{-jk}}

Rubyでコードを書くと↓

# DFT(単純な O(n²) 版)
def dft(coeffs)
  n = coeffs.size
  # e^(2πi/n) = cos(2π/n) + i*sin(2π/n)
  angle = 2 * Math::PI / n

  (0...n).map do |k|
    (0...n).sum do |j|
      theta = angle * j * k
      coeffs[j] * Complex(Math.cos(theta), Math.sin(theta))
    end
  end
end

# 逆DFT
def idft(values)
  n = values.size
  angle = -2 * Math::PI / n

  (0...n).map do |j|
    (0...n).sum do |k|
      theta = angle * j * k
      values[k] * Complex(Math.cos(theta), Math.sin(theta))
    end / n
  end
end

# DFTを使った多項式乗算
def poly_multiply_dft(p, q)
  # 積の次数に合わせてサイズを決定(2のべき乗にパディング)
  result_size = p.size + q.size - 1
  n = 1
  n *= 2 while n < result_size

  # ゼロパディング
  p_padded = p + [0] * (n - p.size)
  q_padded = q + [0] * (n - q.size)

  # Step 1, 2: DFTで点値表現に変換
  p_dft = dft(p_padded)
  q_dft = dft(q_padded)

  # Step 3: 点ごとに乗算
  r_dft = p_dft.zip(q_dft).map { |a, b| a * b }

  # Step 4: 逆DFTで係数表現に戻す
  r = idft(r_dft)

  # 実部を取り出して丸める(浮動小数点誤差の除去)
  r.map { |c| c.real.round }
end

# 実行
p_coeffs = [5, 1, 3, 2]  # p(x) = 5 + x + 3x² + 2x³
q_coeffs = [1, 2, 4, 1]  # q(x) = 1 + 2x + 4x² + x³

puts "p(x)の係数: #{p_coeffs}"
puts "q(x)の係数: #{q_coeffs}"
puts "DFTベース:   #{poly_multiply_dft(p_coeffs, q_coeffs)}"

高速フーリエ変換

上記のDFTベースだと計算量はO(n2)で、畳み込みと変わらないけど、FFTの分割統治法アプローチで高速化できる。具体的には、係数を偶数番目と奇数番目に分けて半分のサイズのDFTに分解し、それを再帰的に繰り返し、2つの小さなDFTの結果を組み合わせて最終的にサイズnのDFTの結果を得る(この結合処理をバタフライ演算と呼ぶ)。結果、計算量はO(n log n)になる。

Rubyで書くと↓

# FFT版
def fft(coeffs)
  n = coeffs.size
  return coeffs if n == 1

  # 偶数インデックスと奇数インデックスに分割
  even = (0...n).step(2).map { |i| coeffs[i] }
  odd  = (1...n).step(2).map { |i| coeffs[i] }

  # 再帰的にFFT
  e = fft(even)
  o = fft(odd)

  # バタフライ演算で結合
  angle = 2 * Math::PI / n
  result = Array.new(n)

  (0...n/2).each do |k|
    omega_k = Complex(Math.cos(angle * k), Math.sin(angle * k))
    t = omega_k * o[k]
    result[k]       = e[k] + t
    result[k + n/2] = e[k] - t
  end

  result
end

# 逆FFT
def ifft(values)
  n = values.size
  return values if n == 1

  even = (0...n).step(2).map { |i| values[i] }
  odd  = (1...n).step(2).map { |i| values[i] }

  e = ifft(even)
  o = ifft(odd)

  # 逆FFTはωの指数が負
  angle = -2 * Math::PI / n
  result = Array.new(n)

  (0...n/2).each do |k|
    omega_k = Complex(Math.cos(angle * k), Math.sin(angle * k))
    t = omega_k * o[k]
    result[k]       = e[k] + t
    result[k + n/2] = e[k] - t
  end

  result
end

# FFTを使った多項式乗算
def poly_multiply_fft(p, q)
  result_size = p.size + q.size - 1
  n = 1
  n *= 2 while n < result_size

  p_padded = p + [0] * (n - p.size)
  q_padded = q + [0] * (n - q.size)

  p_fft = fft(p_padded)
  q_fft = fft(q_padded)

  r_fft = p_fft.zip(q_fft).map { |a, b| a * b }

  r = ifft(r_fft)

  # 逆FFTでは最後にnで割る
  r.map { |c| (c.real / n).round }
end

# 実行
p_coeffs = [5, 1, 3, 2]
q_coeffs = [1, 2, 4, 1]

puts "p(x)の係数: #{p_coeffs}"
puts "q(x)の係数: #{q_coeffs}"
puts "FFTベース:  #{poly_multiply_fft(p_coeffs, q_coeffs)}"

数論変換

ただFFTは複素数の浮動小数点演算を行うため丸め誤差が発生する可能性があり、ゼロ知識証明システムでは有限体上での正確な演算が必要になるため、この誤差が問題になる。

そこで、複素数の1のn乗根を有限体上の原始n乗根に置き換えたのが数論変換(NTT=Number Theoretic Transform)。素数pに対して、p - 1がnで割り切れる場合、有限体 {\mathbb F_p}に、以下を満たす位数nの元 {\omega}が存在する。

  •  {\omega^{n} \equiv 1 \pmod{p}}
  •  {\omega^{k} \not\equiv 1 \pmod{p}}(1 ≤ k < nに対して)

この原始根を使うことで、FFTと同じアルゴリズムが有限体上で動作する。つまり計算量O(n log n)で有限体上の多項式の乗算が誤差なく行える。

NTTを利用する上で重要なのがpの選択に伴う位数nのサイズ。このサイズにより、扱える多項式の係数の数が制限される。

p - 1 = 2k × m(mは奇数)とした場合、 {\mathbb F_p}上では、最大2k個までの次数を持つ積多項式の計算が可能になる。そのため、主要なゼロ知識証明システムではNTTに適した素数の有限体を選ぶのが重要になっている。Bitcoinが使っているsecp256k1とかはこの数が少ないので不向きな有限体になる。

NTTを使った計算をRubyで書くと↓

# NTT用の定数
# p = 998244353 = 119 × 2^23 + 1 はNTT-friendlyな素数
# 原始根 g = 3
MOD = 998244353
PRIMITIVE_ROOT = 3

# モジュラ逆元
def mod_inverse(a, mod = MOD)
  mod_pow(a, mod - 2, mod)
end

# モジュラべき乗
def mod_pow(base, exp, mod = MOD)
  result = 1
  base %= mod
  while exp > 0
    result = result * base % mod if exp.odd?
    exp >>= 1
    base = base * base % mod
  end
  result
end

# NTT(FFTと同じ構造、複素数→有限体に置き換え)
def ntt(coeffs, inverse: false)
  n = coeffs.size
  return coeffs if n == 1

  # 偶数インデックスと奇数インデックスに分割
  even = (0...n).step(2).map { |i| coeffs[i] }
  odd  = (1...n).step(2).map { |i| coeffs[i] }

  # 再帰的にNTT
  e = ntt(even, inverse: inverse)
  o = ntt(odd, inverse: inverse)

  # 原始n乗根を計算
  # ω = g^((p-1)/n) mod p
  # 逆変換の場合は ω^(-1) を使う
  if inverse
    omega = mod_pow(PRIMITIVE_ROOT, (MOD - 1) - (MOD - 1) / n)
  else
    omega = mod_pow(PRIMITIVE_ROOT, (MOD - 1) / n)
  end

  # バタフライ演算
  result = Array.new(n)
  omega_k = 1

  (0...n/2).each do |k|
    t = omega_k * o[k] % MOD
    result[k]       = (e[k] + t) % MOD
    result[k + n/2] = (e[k] - t + MOD) % MOD 
    omega_k = omega_k * omega % MOD
  end

  result
end

# 逆NTT
def intt(values)
  n = values.size
  result = ntt(values, inverse: true)

  # 最後にnの逆元を掛ける
  n_inv = mod_inverse(n)
  result.map { |x| x * n_inv % MOD }
end

# NTTを使った多項式乗算
def poly_multiply_ntt(p, q)
  result_size = p.size + q.size - 1
  n = 1
  n *= 2 while n < result_size

  p_padded = p + [0] * (n - p.size)
  q_padded = q + [0] * (n - q.size)

  p_ntt = ntt(p_padded)
  q_ntt = ntt(q_padded)

  r_ntt = p_ntt.zip(q_ntt).map { |a, b| a * b % MOD }

  intt(r_ntt)
end

# 実行
p_coeffs = [5, 1, 3, 2]
q_coeffs = [1, 2, 4, 1]

puts "p(x)の係数: #{p_coeffs}"
puts "q(x)の係数: #{q_coeffs}"
puts "NTTベース:  #{poly_multiply_ntt(p_coeffs, q_coeffs)}"