関数の自動メモ化

groovy 1.8 では、memoize によってクロージャのメモ化が出来るようになったけれど、scala だってできるもん、という負け惜しみエントリ。

普通の自動メモ化

サクっと作ったものを(1〜5引数対応)をGistに上げたので簡単に紹介。

実装のポイントは単純で、下記の通り。

  // 1変数関数をメモ化する(2変数以上は tupled/untupledで対応)
  def memoize[ArgT, RetT](f: ArgT => RetT): ArgT => RetT = {
    val memo = scala.collection.mutable.Map.empty[ArgT, RetT]
    arg => memo.getOrElseUpdate(arg, f(arg))
  }

こんな風にして使う。

object Main extends Application {
  val fib: Int => Long = memoize {
    case 0 => 0
    case 1 => 1
    case n => fib(n - 1) + fib(n - 2)
  }
  println(fib(40))
}

再帰関数をメモ化する場合には、要注意な点がある。もとの関数で再帰呼び出しをする際には、メモ化された後の関数を呼び出さないとダメな点。そうしないと、再帰の呼出しがメモ化されない。上記コード例のように、あくまで、メモ化された後の関数実体が束縛される変数名をメモ化される前の関数から参照しなくてはならない。

下記の例は、ダメな例。

  // これはダメ
  val fib: Int => Long = {
    case 0 => 0
    case 1 => 1
    case n => fib(n - 1) + fib(n - 2)
  }
  val fibMemo = memoized(fib)

そう、まさに Groovy の memoize と同じ問題が起きる。

2011-05-08 追記

Groovy での再帰対応については、id:dev68 さんのエントリが参考になります。
エントリ全体としては、Groovy の Meta Object Protocol (MOP) の紹介と 1.8 での変更点に関するもので、非常に参考になりました。

再帰対応などの方法を検討(失敗談など)

普通は、この先って、いつキャシュを破棄するのかとかそういうところを詰めていくんじゃないのかなー、(例えば、外からキャッシュ解放操作を受け付けるとか、MRUキャッシュに置き換えるとか、WeakReferenceでドツボにはまるとか etc.)と思いつつ、あえて、「厳密に言えば再帰に対応していない」点を突き詰めてみる。前述のように、メモ化後のシンボルを参照すれば問題ないとは言え、おいらが許さない。

そこで、厳密に再帰に対応、かつ、動的に自動メモするにはどうすれば良いのか、検討してみた。なお(1)動的とは、関数を値として受け取るという意味で、(2)厳密に再帰対応とは、再帰を含んでいてもメモ化後に束縛されるシンボルを参照せずにメモ化できるという意味。

(1) だけならば前述の方法で、(2)だけならば後述の trait との mixin で対応可能だし、そもそも関数の形を変えていいのならば、不動点コンビネータを用いるのが定石。だけれど、関数の形を変えずに(1)と(2)を融合するのは、それなりの黒魔術が必要になる。

この先は失敗談がてら、簡単に何を検討したのかメモしておく。

不動点コンビネータ

まず、λ再帰(自分の名前を参照せずに再帰する方法)では、不動点コンビネータを使うのが定石。しかし、当然関数の形は変わってしまうので、実行中に関数を受け取ってメモ化するという点では、今回の目的にはそぐわない。

そぐわないとは分かっているけれど、あえて泥沼に足を踏み入れてみる。エイヤー!

まず 不動点コンビネータの基本である Y combinator を scala で書くと、下記のような感じ。(既に scala 実装はあるのだけれど、そこはあえて我流で)

  type F[A, B] = ((A => B), A) => B
  type Y[A, B] = (F[A, B], A) => B
  def Y[A, B](f: F[A, B], x: A): B = f((x1: A) => Y(f, x1), x)

書いてみて分かる 型つきのよさ。型がどう変化するかを type で記述できるという scala の良さを見なおした瞬間でもある。まあ、私は計算機科学に関する高等教育を受けていないので Y combinator に馴染みが薄く、先に型を書かないと飲み込めないというだけのことかもしれないけれどね。

以下、ねっとり(not サクっと)実装してみる。

object TestMemoRecFP {
  // memoization using Y-combinator
  type F[A, B] = ((A => B), A) => B
  type Y[A, B] = (F[A, B], A) => B
  def Y[A, B](f: F[A, B], x: A): B = f((x1: A) => Y(f, x1), x)
  def memoize[A, B](f: F[A, B]): F[A, B] = {
    val cache = collection.mutable.Map.empty[A, B]
    (fx, arg) => cache.getOrElseUpdate(arg, f(fx, arg))
  }
  
  // Factorial 
  val factorial: F[Int, Long] = (f, arg) => arg match {
    case 0 => 1
    case x if x > 0 => x.toLong * f(x - 1)
    case x if x < 0 => throw new IllegalArgumentException(
        "requirement failed: argument of factorial must be > 0")
  }
  val factoricalWithMemo = memoize(factorial)
  
  // Fibonacci sequence extended to negative index
  // What is the definition of fibonacci number with negative indices?
  // @see http://en.wikipedia.org/wiki/Fibonacci_number
  val fibonacci: F[Int, BigInt] = (f, arg) => arg match {
    case 0 => 0
    case 1 => 1
    case x if (x > 0) => f(x - 1) + f(x - 2)
    case x if (x % 2 == 0) => -f(-x)  // x < 0 && x is even
    case x => f(-x)                   // x < 0 && x is odd
  }
  val fibonacciWithMemo = memoize(fibonacci)
  
  def exampleOfMemoizationWithYcomb = {
    println(Y(fibonacciWithMemo, 10))
    println(Y(fibonacci, 10))
    
    println(Y(fibonacciWithMemo, 34))
    println(Y(fibonacci, 34))
    
    println(Y(fibonacciWithMemo, 100))
    //println(Y(fibonacci, 100))  // it takes too long as if it never ends
  }
  
  def main(args : Array[String]) {
    exampleOfMemoizationWithYcomb
  }
}
メモ化する trait と mixin

コップ本の Queue をデコレートする mixin の解説部分(Doubling とか、そんな trait をつくるところ)を見ながら、必死で作ってみた。(具体的な参照箇所は、原著第2版 12.5 The Doubling stackable modification trait を見たんだけれど、日本語のコップ本も、無料の原著第1版 12.5も同内容)

やる前から分かっていることだが、再帰には対応するんだけれど、関数を値として受けとってメモ化することはできない。(関数のクラスを受け取って mixin するか、匿名クラスをその場で定義しなくてはならない)。つまり、匿名関数リテラルの値をメモ化したりできないので、本エントリの最初に掲げたやりかたよりも、実用上すごく不便。

trait Memoized1[T1, R] extends Function1[T1, R] {
  val cache = collection.mutable.Map.empty[T1, R]
  abstract override def apply (v1: T1): R =
    cache.getOrElseUpdate(v1, super.apply(v1))
}

object TestMemoWithTrait {
  def testWithTraits {

    class FibonacciClass extends Function1[Int, BigInt] {
      override def apply(n: Int) = {
        println("Function called: "+this.getClass.toString+"; arg n = "+n)
        n match {
          case x if (x < 0) =>
            throw new IllegalArgumentException("The given argument must be >= 0")
          case 0 => 0
          case 1 => 1
          case x if (x >= 2) => apply(x - 1) + apply(x - 2)
        }        
      }
    }
    object fib extends FibonacciClass
    object fibMemo extends FibonacciClass with Memoized1[Int, BigInt]
    
    val f = new Memoized1[Int, Int] {
      def apply(n: Int): Int = {
        println("Function called: "+this.getClass.toString+"; arg n = "+n)
        n match {
          case x if x < 0 => 
            throw new IllegalArgumentException("The given argument must be >= 0")
          case 0 => 0
          case 1 => 1
          case x if (x >= 2) => apply(x - 1) + apply(x - 2)
        }
      }
    }
    
    println(fib(5))
    println(fibMemo(5))
    println(f(5))
    
  }
  def main(args : Array[String]) {
    testWithTraits
  }
}
関数リテラルを受け取りつつ、その匿名クラスを Manifest でゴニョ

先程の mixin の方法は、再帰に対応する正当な方法(他の方法としては、コンパイラプラグインか Byte code engineering みたいな黒魔術になるはず)なんだけれど、mixin はあくまでクラスに適用するものだから、クラスがわからないことにはどうにもならない。

そこで、関数の値を受け取りつつそいつ自体は完全に無視して、implicit なパラメータでその関数値のクラス(通常は関数リテラルを書いたときに自動生成される匿名クラス)を受け取って、そいつと mixin してあげようと。そんな無茶な考えが下記。ちなみに scala 2.9.0 RC1 です。

import scala.reflect.Manifest
import scala.tools.nsc.interpreter.IMain
import scala.tools.nsc.Settings
import scala.tools.nsc.settings._
import scala.tools.nsc.util.BatchSourceFile
import scala.tools.util.PathResolver

object DynamicMemoizer 
  extends java.lang.ClassLoader(getClass.getClassLoader) {
  private val id = Iterator.from(1)
  def createUniqueId = synchronized { "DynamicMemoizerKlass" + id.next }
  
  def apply[A, R](func: A => R)(implicit a: Manifest[A], r: Manifest[R]) = {
    val id = createUniqueId
    val classDef = "class %s extends %s with Memoizing[%s, %s]".
      format(id, func.getClass.getName, a.toString, r.toString)
      
    println(classDef)
    
    val settings = new Settings(println)
    settings.usejavacp.value = true
    val interpreter = new IMain(settings)
    interpreter.setContextClassLoader

    interpreter.compileSources(new BatchSourceFile("<anon>", classDef))

    val bytes = interpreter.classLoader.findBytesForClassName(id)

    val clazz = defineClass(id, bytes, 0, bytes.length).asInstanceOf[Class[(A => R) with Memoizing[A, R]]]
    
    clazz.newInstance
  }
}

trait Memoizing[T, R] extends Function1[T, R] {
  val memo = scala.collection.mutable.Map.empty[T, R]
  abstract override def apply(arg: T): R =
    memo.getOrElseUpdate(arg, super.apply(arg))
}

object DynamicMemoMain extends Application {
  
  // 比較用
  class fibClass extends (Int => Long) {
    def apply(n: Int): Long = n match {
      case 0 => 0
      case 1 => 1
      case _ => apply(n - 2) + apply(n - 1)
    }    
  }
  object fibStaticMixin extends fibClass with Memoizing[Int, Long]  
  
  val fib: Int => Long = {
    case 0 => 0
    case 1 => 1
    case n => fib(n - 1) + fib(n - 2)
  }
  // 全てはこれのために、動的自動メモ化発動!
  val fibDynamicMixin = DynamicMemoizer(fib)
  
  println(fibStaticMixin(40))
  println(fibDynamicMixin(40))
}

なんじゃこりゃーレベルOTL

AOP (メモ化とは、アスペクト横断な関心事なんですキリッ!)

なんじゃこりゃーついでに、spring aop とか使ってみる。
なんというか、メモ化って、関数呼び出しごとにキャッシュ参照するっていうアスペクト横断的関心事なんですよ、という発想。もちろん、厳密な意味では再帰呼び出しに対応しない。(最初に示した方法と同様、メモ化後のシンボルを参照する必要がある)

import org.aopalliance.intercept.{MethodInterceptor, MethodInvocation}
import org.springframework.aop.framework.ProxyFactory

// Memoization with `MethodInterceptor`
class Memoizer extends MethodInterceptor {
  import scala.collection.mutable.WrappedArray
  private val cacheMap = collection.mutable.Map.empty[WrappedArray[AnyRef], AnyRef]
  def invoke(invocation: MethodInvocation): AnyRef =
    cacheMap.getOrElseUpdate(invocation.getArguments, invocation.proceed)
}

object TestMemo {
  def memoize[T](funcObj: T): T = {
    val pf = new ProxyFactory
    pf.setTarget(funcObj)
    pf.addAdvice(new Memoizer)
    pf.getProxy.asInstanceOf[T]
  }
  
  def exampleOfMemoizationWithAOP = {
    def sqr(x: Int): Int = {
      println("method invoked!")
      x * x
    }
    val sqrMemo = memoize(sqr _)
    
    println("sqr first call:")
    println(1 to 3 map sqr)
    println("sqr second call:")
    println(1 to 3 map sqr)
    
    println("sqrMemo first call")
    println(1 to 3 map sqrMemo)
    println("sqrMemo with memo second call")
    println(1 to 3 map sqrMemo)
  }
  
  def main(args : Array[String]) {
    exampleOfMemoizationWithAOP
  }
}
おわりに

まあ、せっかくなので、失敗談も含めて晒してみました。自分の無能さ加減(単に日本語が下手くそというだけでなく、実装力も低い)の記録ですね。後日に振り返って、「あー、おいらも成長したなー」を実感するための材料に化ける?

あと未検討な手法は、BCE(Byte code engineering) と コンパイラプラグインがあるのだけれど、BCEの場合は相互再帰とかもっと複雑な再帰をどうするかという問題があるのでとても大変。とすると、コンパイラプラグインが有望なのかな。

ハミング数の算法 3種 - infinite list, imperative queue, cyclical iteration

素因数分解 H = 2^i \cdot 3^j \cdot 5^k, \; \mathrm{where} \; i, j, k \geq 0 の形式になる数をハミング数という*1

この手の問題は、やはり Haskell で書くとスッキリする。

Hamming numbers - Rosetta Code から Haskell のコードを引用する。

hamming = 1 : map (2*) hamming `merge` map (3*) hamming `merge` map (5*) hamming
     where merge (x:xs) (y:ys)
            | x < y = x : xs `merge` (y:ys)
            | x > y = y : (x:xs) `merge` ys
            | otherwise = x : xs `merge` ys

main = do
    print $ take 20 hamming
    print $ hamming !! 1690
    print $ hamming !! 999999

これを scala に翻訳すると、下記のようになる。

ただし、この方法で第1000000項を求めようとすると、メモリを圧迫する。初項から第1000000項の全てを保持するためである。
実際、私の環境で実行するためには、JVMが最大で使用できるヒープメモリの制限を増やす必要があった。具体的には、-Xmx512m オプションを指定した。

このメモリ圧迫問題の解決法には、①明示的にキューに積んでポップするという imperative な方法と;② cyclical iteration を使う方法がある。

①の方法に関しては、scala 実装コードが、そのまんまの例。

②の方法に関しては、python 実装コードが、そのまんまの例。

cyclical iteration と言うと、難しく聞こえるが、本質は簡単。Stream(scala の無限リスト)の代わりに Iterator を用いることで、"忘れていく"無限リストを作ればよい。

python の cyclical iteration を scala に翻訳すると、下記のようになる。

cyclical iteration の implementation issues は、下記の2点:

  • 前方参照をどう解決するか
  • 後続項目の遅延評価をどうするか

以下、簡単な補足説明。

前方参照をどう解決したか。python は単に def すれば前方参照であることをごまかせるのに対して、scala は一工夫が必要。具体的には、class 定義のコンテキストにしてあげると、うまくごまかせる。object ForwardReferenceabl という ローカルな singleton を定義したのは、ひとえに前方参照をごまかすためのトリック。

では、後続項目の遅延評価をどうするか。scala において iterator を chain するのは Iterator#++() というメソッド。これは、引数を遅延評価してくれるので、一見すると、問題がないように思える。しかし、コードを書いてから分かったことだが、どうも scalaIterator インタフェースのいろいろな実装の中には、.next を呼び出した際に、さらに一個先読みする実装が混ざっている*2ようだ。

そのため、本来は初項の1だけを初期値にすればよいところを、1,2 および 3,4,5 を指定している。(先読み問題を解決するために 1,2 を指定するとともに、循環参照部は 2の倍数系列と3の倍数系列と5の倍数系列を揃えるために 3, 4, 5 と指定した)

なお、私が Hamming 数に関心を抱いたきっかけは id:dev68 さんのエントリhttp://d.hatena.ne.jp/deve68/20110423/1303521318を拝見したためです。いつも参考にさせていただいております。ありがとうございます。

*1:正確を期せば、自然数のうち7以上の素数を因数に持たないものを昇順でN個求めよという問題を「ハミングの問題」といい、「自然数のうち7以上の素数を因数に持たないもの」をハミング数と言う。

*2:私の予想であり、コードを特定したわけではない

フィボナッチ数の算法のベンチマーク

まずは、結果を下図に示します。

なおベンチマークに使用したコードは上記の説明コードとは違い、チューニング済みです。

逐次平方変換のコードは、dev68 さんの groovy 実装 をそのまま scala に移植しました。また、このエントリをきっかけとして、このエントリが生まれました。ありがとうございます。

Binet 公式のコードは、計算しなくて良い部分(wikipedia中の(-Φ)^(-n))を省いて高速化しています。この部分は、線形代数の言葉で言えば、「絶対値が1未満のほうの固有値」に一致します。ですから、線形代数的な観点からは、「絶対値が1未満のほうの固有値だから、冪乗を繰り返すとなくなっていくんだな」と、直感的に理解することも可能です。

A linear algebra view of Fibonacci sequence

はじめに

フィボナッチ数列の高速な算法である

  • 逐次平方変換
  • Binet公式

について、線形代数的な視点から説明を与えて、ベンチマークしてみました。

なおどちらの方法についても、詳しい説明は世の中に出回っているはずですが、私にとってそれらの説明は天才のひらめきのような印象が拭えませんでした。そこで、あえて線形代数観点から説明を与えることで、「誰もが思いつきそうな、ありきたりの手法じゃないか」と感じて頂けたら嬉しいです。

なお、以下ではフィボナッチ数列の初項を F_0 = 0, F_1 = 1 としました。

逐次平方変換

フィボナッチ数列の漸化式を下記のように表現します:
[\array{F_n \quad F_{n-1} \\ F_{n-1} \quad F_{n-2}}] = [\array{1 \quad 1 \\ 1 \quad 0}] [\array{F_{n-1} \quad F_{n-2} \\ F_{n-2} \quad F_{n-3}}]

初項と係数行列が等しいことに注意すれば、下記のようになります。

[\array{F_n \quad F_{n-1} \\ F_{n-1} \quad F_{n-2}}] = [\array{1 \quad 1 \\ 1 \quad 0}]^{n-1}

ここで行列の結合性をもちいて、冪乗を下記の擬似コードのように O(log N) で計算すると、いわゆる逐次平方変換と等価な式が得られます。

// O(log N) な冪乗計算
def pow(b: Matrix, e: Int): Long =
  if (e == 0) IdentityMatrix 
  else (e % 2) match {
    case 0 => pow(b * b, e / 2)
    case 1 => b * pow(b, e - 1)
  }
}

Binet 公式

フィボナッチ数列の一般項を示す Binet 公式を、線形代数の言葉で表現すれば、「対角化」の一言につきます。

漸化式を下記のように行列表記して、

[\array{F_n \\ F_{n-1}}] = [\array{1 \quad 1 \\ 1 \quad 0}] [\array{F_{n-1} \\ F_{n-2}}]

そのまま一般項を係数行列の冪で表現し、

[\array{F_n \\ F_{n-1}}] = [\array{1 \quad 1 \\ 1 \quad 0}]^n [\array{F_{1} \\ F_{0}}]

ここで、行列の冪を計算するに際して、対角化して対角成分の冪を計算すると Binet の公式が得られます。

Wolfram alpha先生による Jordan 分解

説明的コード

どうも日本語の説明が下手くそなので、いままで述べた内容を scala コードで示します。

RubyのEnumerableを遅延評価にしてみる

最近、無意識のうちに遅延評価を前提としたコードを書くようになってきました。趣味の scala コードばかりを書いている弊害でしょうか。

そんな遅延脳が失態をやらかしました。正格評価前提の言語(Ruby)で、遅延評価を期待したコードを書いてしまい、プログラムをハングアップさせてしまいました。

# 正の偶数を小さい順に5個表示したい
(1..(1.0/0.0)).select(&:even?).take(5).each { |x| puts x }
# しかし、このコードは意図したとおりには動かない

Ruby の Enumerable#select は正格評価なので、select した時点で自然数から偶数全てを抽出した無限大の配列を作ろうとしてハングアップしてしまうんですね。。。

Rubyでも遅延評価できたらいいなあ。例えば 下記の scala コードのように。

// 正の偶数を小さい順に5個表示する
Iterator.from(1) filter {_ % 2 == 0} take 5 foreach println
// これは意図通りに動く

そこで、Ruby で遅延評価をするメソッド lazy_* を付け足してみました。

module Enumerable
  def self.make_lazy(*syms)
    syms.each do |sym|
      class_eval <<-"EOD"
        def lazy_#{sym}(*arg, &blk)
          Enumerator.new do |e|
            each do |x|
              [x].#{sym}(*arg, &blk).each { |y| e << y }
            end
          end
        end
      EOD
    end
  end
  
  #-- Enumeratorを返すメソッドを作成
  make_lazy :collect, :map, :select, :reject, :grep
  make_lazy :find_all, :flat_map, :concat_collector
end
 
#-- 遅延評価により無限リストでもOK
(1..(1.0/0.0)).lazy_map(&:even?).take(5).each { |x| puts x }
(1..(1.0/0.0)).lazy_select(&:even?).take(5).each { |x| puts x }
 
#-- 標準メソッドではアウト
#(1..(1.0/0.0)).map(&:even?).take(5).each { |x| puts x }
#(1..(1.0/0.0)).select(&:even?).take(5).each { |x| puts x }

これで多い日も安心。

基礎から始める覆面算

Scala の for 式で覆面算を解いているエントリを見つけた人が Haskell と Groovy で解きながらリスト内包表記について考えていた

このようなエントリを見かけると、いつも面白いなーと感心するけど、どうも私は実戦で使いこなせていない。おそらくは、パフォーマンスと柔軟性に欠けるからだと思う。例えば、覆面算の問題パターンを変えると List comprehension を使えないんじゃないとか、パフォーマンスはバックトラックに比べてどうなるのとか。

もうちょっと詳しく考えるため、実際にコードを書いてみた。

まずは、元ネタと同じように素直に総当りをしてみる。

確かに scala の for式 は便利だなと感じる反面、実行時間は5秒も掛かっている。遅い、遅すぎる。

そこで、for式の範囲内で枝刈りをしてみることにした。

枝刈りのアイディアは、下位から順番に数字を当てはめていき、筆算が成立しなくなった時点で其の当てはめを打ち切るもの。

実行時間は20ms未満となった。なるほど、この程度の問題であれば、for式の範疇でも上手く枝刈りが出来るのか。

次は覆面算の問題を引数で与えられるように一般化してみる。枝刈りの方針については、今までと全く同じだが、任意の数の筆算を許可しようと思うと、構造が全く変わってしまった。

もはや、for式は使えない。実際、私は for式を諦めて再帰で書いた。このあたりをスマートに書きたいのだが、うまい手はないものだろうか?

ベンチマーク結果一覧。各メソッドとも5回実行、毎回の所要時間を System.nanotime()で計測。

trial12345
method #167447234006797389219678783148966217006436632244455
method #21823689518029371174583371730269417080834
method #3120678891119130650118630953119162734118002919
method #47930882179116997790931048018192578304648
method #54062666940049150405003804107312040473074

完全に自分メモエントリ OTL

Re:今流行のお題を出してみた(一方通行を許可した迷路を作成)

こちらの問題を解いてみました。

グラフ連結に関する理論を全く使っていないという意味で、力技です。一方通行のドアが壁に存在する期待値が40%くらいまでなら、なんとかなりました。
35%を切ってくると、このままでは厳しいかもしれません。

再帰の仕方とかは、私よりも皆さんのほうがよくご存知だとおもいますので、答案のポイントだけ下記に記します:

  • 迷路を掘っていくのではない;迷路になるように埋めていく。

なお、SVGでの視覚化部分は、出題者 aya_eiya さんのルーチンを使わせていただきました。

一方通行を許可した迷路を作成 · GitHub

探索ルーチンの部分とか、トホホという感じですね。かっこいい書き方を勉強して出直します(泣)