January 2018

Adding tail call optimization to a Lisp interpreter in Ruby

I spent the past week doing a programming retreat at the Recurse Center in New York City. One project I worked on was writing a simple Lisp interpreter in Ruby, following the excellent make-a-lisp tutorial.

One of the more educational steps was adding support for tail call optimization. I had used tail call optimization before and had a vague sense of how it worked, but it always sounded complicated. Implementing it in a toy interpreter was a great way to understand it better, and it turns out that the core idea is actually very simple!

In this post, I’ll show you why we need tail call optimization, and exactly how I implemented it in an interpreter with a surprisingly small set of changes to the code.

Blowing out the stack

To start off, let’s take a quick tour of a few constructs the Lisp interpreter supports:

; basic math
> (+ (* 2 3) (* 4 5))
26

; variables
> (def! n 1)
> n
1

; conditionals
> (if (= n 1) "yep" "nope")
"yep"

; user-defined functions
> (def! double (fn* (x) (* x 2)))
> (double 5)
10

Using these building blocks, we can define a recursive function which sums all the integers from 0 up to a given number. On each recursive call we decrement n and add to our accumulator variable acc, until we hit the base case of n = 0.

(def! sum-to
  (fn* (n acc)
    (if (= n 0)
      acc
      (sum-to (- n 1) (+ n acc)))))

; test it out:
(sum-to 3 0) ;=> 6

On small inputs, it works great. But with a large input, it doesn’t work so well – we get a Ruby error “stack level too deep”.

(sum-to 10000 0)
;=> /Users/glitt/personal-dev/mal/glitt/env.rb:16: stack level too deep (SystemStackError)

The problem

To understand why that happened, we need to look inside our interpreter.

The relevant part of the code is the EVAL function, which is the heart of the interpreter. It takes as input an abstract syntax tree (AST) of tokens and an environment of symbol definitions, and returns a new AST representing the fully-evaluated expression.

def EVAL(ast, env)
  # evaluate the AST passed in, return a new AST
end

The body of our sum-to function is a conditional expression, so let’s zoom in on the part of EVAL that’s responsible for evaluating conditionals.

def EVAL(ast, env)
  # ...
  case ast.first

  # If the first symbol in the AST is :if,
  # we're evaluating a conditional
  when :if
    # Destructure the arguments out of the AST
    conditional, true_branch, false_branch = ast[1..3]

    # Evaluate the conditional,
    # then evaluate the appropriate branch
    if truthy?(EVAL(conditional, env))
      return EVAL(true_branch, env)
    else
      return EVAL(false_branch, env)
    end
  # ...
end

Notice that we make a function call to EVAL as part of evaluating a conditional. This allocates a new stack frame every time we evaluate a conditional…hence the stack overflow when we recursively evaluate hundreds of them. To solve this we need to find a way to avoid allocating a new stack frame every time we evaluate a conditional.

The idea

Here’s one potential solution. Imagine for a second that Ruby had GOTO statements. First, put a LABEL at the top of our EVAL function. Then, when evaluating a branch of the conditional, instead of making a function call, we could just set the ast variable to point at the new code we want to evaluate, and then GOTO the top of the EVAL function. Then EVAL would continue executing, this time with the new code in the ast variable.

(If you’re paying close attention you may have noticed EVAL also has a second argument, but in this case it doesn’t change between executions so we can just leave it unchanged.)

def EVAL(ast, env)
  LABEL top_of_eval
  # ...
  case ast.first

  when :if
    conditional, true_branch, false_branch = ast[1..3]

    if truthy?(EVAL(conditional, env))
      ast = true_branch
      GOTO top_of_eval
    else
      ast = false_branch
      GOTO top_of_eval
    end
  # ...
end

It turns out Ruby doesn’t actually have gotos (well…mostly). But there’s a simple hack to get equivalent behavior: we can wrap all of EVAL in an infinite loop and then call next when we want to go back to the top of the loop.

def EVAL(ast, env)
  loop do
    # ...
    case ast.first

    when :if
      conditional, true_branch, false_branch = ast[1..3]

      if truthy?(EVAL(conditional, env))
        ast = true_branch
        next
      else
        ast = false_branch
        next
      end
    # ...
  end
end

Now when our interpreter evaluates a conditional it no longer makes a function call, and consequently no longer allocates a new stack frame.

Extending to function calls

In order to make our sum-to function run successfully, we’ll need to also make function calls tail-optimized in our interpreter.

The principle is exactly the same as with conditionals, but the implementation is a bit more involved because function calls also involve binding variables to arguments and creating a new environment for evaluation.

Let’s first look at how function definition is implemented in our interpreter. When a function is defined by the user, we simply return a Ruby Proc as our internal representation of the function. (If you’re not familiar with Ruby, this is the standard Ruby object for representing a function.)

def EVAL(ast, env)
  loop do
    case ast.first
    # ...
    when :"fn*"
      params = ast[1]
      fn_body = ast[2]

      return -> (*args) do
        # Create a new environment:
        # inherits from current environment and
        # binds variables to passed-in arguments
        new_env = Env.new(
          outer: env,
          binds: params,
          exprs: args
        )

        # Evaluate function body in context of new environment
        EVAL(fn_body, new_env)
      end
    # ...
  end
end

Later on, when the user calls the function in their code, we just call the Proc: function.call(*args)

When that happens, our Proc object makes a call to EVAL, which allocates a stack frame. By now you know that this makes it unsafe to deeply recurse using this function, so we’ll need to apply a fix similar to what we did with conditionals.

First, instead of returning a Proc to internally represent the user-defined function, we just return a hash which remembers the function’s parameters and its unevaluated body.

def EVAL(ast, env)
  loop do
    case ast.first
    # ...
    when :"fn*"
      params = ast[1]
      fn_body = ast[2]

      return {
        ast: fn_body,
        params: params,
        env: env
      }

    # ...
  end
end

We then need to change how we execute function calls. Just like we did with conditionals, we reassign the ast variable in place.

In addition, we also have to create a new environment for execution. We handle that similarly, by replacing env in place with a newly defined environment with function inputs bound.

Then we call next to go back to the top of EVAL, and the interpreter starts evaluating the function body.

def EVAL(ast, env)
  loop do
    # ...

    # ---
    # function evaluation
    # ---

    ast = function[:ast]
    env = Env.new(
      outer: function[:env],
      binds: function[:params],
      exprs: args
    )

    next
  end
end

Now that conditionals and function calls are tail-optimized, our sum-to function should be able to run an arbitrary number of times without running out of stack space. Let’s try it out:

(sum-to 10000 0)
50005000

Voila, we now have a tail call optimized interpreter! To learn more, check out the full code diff on Github or the make-a-lisp guide.

Discuss on Hacker News