Scalaでパーサーを作ってみる〜3:式の評価

Scalaの勉強をはじめたので、とりあえず簡単な数式パーサーを作ってみてます。
http://d.hatena.ne.jp/nowokay/20111101#1320102262


前回は、構文木のオブジェクトを生成しました。
http://d.hatena.ne.jp/nowokay/20111104#1320365981


これで、構文木が扱えるようになったので、あとはその構文木を処理して式を評価するだけです。
式の計算を進めて値を出すことを「評価」といいます。


その前に、前回のエントリで水島さんに教えてもらったことを取り込みます。
構文木作成に使わない構文要素は「~>」「<~」を使うと左右の項を捨てれるということでした。ということでfactorの項を書き換えてみます。

  def factor: Parser[AST] = intLiteral | "("~>expr<~")"^^{
    x=>x}


あと、左結合する演算子はchainl1を使うとすっきり書けるということで、addを書き換えてみます。

  def add:  Parser[AST] = chainl1(term,
    "+"^^{op => (left:AST, right:AST) => AddOp(left, right)}|
    "-"^^{op => (left:AST, right:AST) => SubOp(left, right)})

おー、めちゃくちゃすっきりしました!
Scalaは、書く人の習熟度によって記述がまったく変わっていく気がします。


あと、Scalaは、省略できるものを省略して言語の推論にまかせることで、変更に強いコードになっていくと思います。ということは、その記述方法によっても変更への強さが変わるということで、同じ処理を記述しても記述する人によって変更への強さがかわるということにもなります。
Javaの場合は、「interfaceを使う」などデータ構造や処理内容で変更に強いコードを実現していて、基本的には使うデータ構造や処理が同じなら同じコードになり、記述のしかたによって変更への強さが変わるということはほとんどありませんでした。
ローカル変数を使わずにそのまま引数にしたりメソッドで呼び出すことで変更への強さを実現できますが、これは微妙に処理も変わってます。


で、まあ、これで構文木が扱えるようになったので、これを処理していきます。

class ExprVisitor{
  def visit(ast:AST):Any = {
    ast match{
      case AddOp(left, right) =>{
          (visit(left), visit(right)) match{
            case (lval:Int, rval:Int) => lval + rval
          }
      }
      case SubOp(left, right) =>{
          (visit(left), visit(right)) match{
            case (lval:Int, rval:Int) => lval - rval
          }
      }
      case MulOp(left, right) =>{
          (visit(left), visit(right)) match{
            case (lval:Int, rval:Int) => lval * rval
          }
      }
      case IntVal(value) => value
    }
  }
}


これ、Javaで書くときは、パターンマッチ相当のことを、オーバーロードを使って引数の型によって呼び出されるメソッドを振り分けることで実現していました。
そうすると、構文要素を追加するたびに基底になるVisitorクラスと実処理をするXxxVisitorクラスにvisitメソッドを定義していかないといけなくてめんどくさかったのですが、Scalaでは単純な再帰で書けるので楽ですね。


また、このときのC++Javaでの記述方法をパターン化したものがGoFのVisitorパターンなのですが、Scalaではオーバーロードではなくパターンマッチを使うことでそのような設計パターンは不要になっています。
結局GoFデザインパターンC++/Javaのようなクラス指向言語でいかに柔軟性を実現するかというもので、プログラム一般の話ではないので、「アルゴリズム勉強しましょう」っていったときに「いやそれよりデザインパターンを」っていうレスが付くたびにその意図を疑問に思ってたりします。


ともかく、構文の処理もできるようになったので式の評価を実行してみます。

  def main(args: Array[String]): Unit = {
    val expr = "12-5*3+7*(2+5)*1+0"
    
    val parser = new MiniParser
    val ast = parser.parse(expr).get

    val visitor = new ExprVisitor
    var result = visitor.visit(ast);
    
    println(result)
  }


実行すると、次のように計算結果が得られました。やりました!

46


ということで、ソースはこんな感じになりました。

package miniparser

import scala.util.parsing.combinator.RegexParsers

object Main {

  def main(args: Array[String]): Unit = {
    val expr = "12-5*3+7*(2+5)*1+0"
    
    val parser = new MiniParser
    val ast = parser.parse(expr).get

    val visitor = new ExprVisitor
    var result = visitor.visit(ast);
    
    println(result)
  }
}

class ExprVisitor{
  def visit(ast:AST):Any = {
    ast match{
      case AddOp(left, right) =>{
          (visit(left), visit(right)) match{
            case (lval:Int, rval:Int) => lval + rval
          }
      }
      case SubOp(left, right) =>{
          (visit(left), visit(right)) match{
            case (lval:Int, rval:Int) => lval - rval
          }
      }
      case MulOp(left, right) =>{
          (visit(left), visit(right)) match{
            case (lval:Int, rval:Int) => lval * rval
          }
      }
      case IntVal(value) => value
    }
  }
}

trait AST
case class AddOp(left: AST, right:AST) extends AST
case class SubOp(left: AST, right:AST) extends AST
case class MulOp(left: AST, right:AST) extends AST
case class IntVal(value: Int) extends AST

class MiniParser extends RegexParsers{
  //expr ::= add
  def expr: Parser[AST] = add
  //expr ::= term {"+" term | "-" term}.
  def add:  Parser[AST] = chainl1(term,
    "+"^^{op => (left:AST, right:AST) => AddOp(left, right)}|
    "-"^^{op => (left:AST, right:AST) => SubOp(left, right)})
  //term ::= factor {"*" factor}
  def term : Parser[AST] = chainl1(factor, 
    "*"^^{op => (left:AST, right:AST) => MulOp(left, right)})
  //factor ::= intLiteral | "(" expr ")"
  def factor: Parser[AST] = intLiteral | "("~>expr<~")"^^{
    x=>x}
  //intLiteral ::= ["1"-"9"] {"0"-"9"}
  def intLiteral : Parser[AST] = """[1-9][0-9]*|0""".r^^{
    value => IntVal(value.toInt)}

  def parse(str:String) = parseAll(expr, str)
}