Into CPS, never to return

December 25, 2024

CPS, or continuation-passing style, is an intermediate representation for programs, particularly functional programs. It’s used in compilers for languages such as SML and Scheme.

In CPS, there are two rules: first, that function/operator arguments must always be trivial; second, that function calls do not return. From this, a lot falls out.

In this post, we’ll introduce CPS by building a simple (Plotkin[1]) CPS transform from a small Scheme-like language. We’ll sketch some optimizations on the IR. Then we’ll look at a couple of the common ways to actually compile the IR for execution.

Mini-Scheme

We have integers: 5

We have some operations on the integers: (+ 1 2), (< 3 4) (returns 1 or 0)

We can bind variables: (let ((x 1)) x)

We can create single-parameter functions[2]: (lambda (x) (+ x 1)) and they can close over variables

We can call functions: (f x)

We can branch: (if (< x y) x y) (where we have decided to use 0 and 1 as booleans)

How do I…?

We’re going to implement a recursive function called cps incrementally, starting with the easy forms of the language and working up from there. Many people like implementing the compiler both in Scheme and for Scheme but I find that all the quasiquoting makes everything fussier than it should be and obscures the lesson, so we’re doing it in Python.

This means we have a nice clear separation of code and data. Our Python code is the compiler and we’ll lean on Python lists for S-expressions. Here’s what some sample Scheme programs might look like as Python lists:

5

["+", 1, 2]

["let", [["x", 1]], "x"]

["lambda", ["x"], ["+", "x", 1]]

Our cps function will take two arguments. The first argument, exp, is the expression to compile. The second argument, k, is a continuation. We have to do something with our values, but CPS requires that functions never return. So what do we do? Call another function, of course.

This means that the top-level invocation of cps will be passed some useful top-level continuation like print-to-screen or write-to-file. All child invocations of cps will be passed either that continuation, a manufactured continuation, or a continuation variable.

cps(["+", 1, 2], "$print-to-screen")
# ...or...
cps(["+", 1, 2], ["cont", ["v"], ...])

So a continuation is just another function. Kind of.

While you totally can generate real first-class functions for use as continuations, it can often be useful to partition your CPS IR by separating them. All real (user) functions will take a continuation as their last parameter (for handing off their return values) and can arbitrarily escape, whereas all continuations are generated and allocated/freed in a stack-like manner. (We could even implement them using a native stack if we wanted. See “Partitioned CPS” and “Recovering the stack” from Might’s page.)

For this reason we syntactically distinguish IR function forms ["fun", ["x", "k"], ...] from IR continuation forms ["cont", ["x"], ...]. Similarly, we distinguish function calls ["f", "x"] from continuation calls ["$call-cont", "k", "x"] (where $call-cont is a special form known to the compiler).

Let’s look at how we compile integers into CPS:

def cps(exp, k):
    match exp:
        case int(_):
            return ["$call-cont", k, exp]
    raise NotImplementedError(type(exp))  # TODO

cps(5, "k")  # ["$call-cont", "k", 5]

Integers satisfy the triviality requirement. So do all constant data (if we had strings, floats, etc), variable references, and even lambda expressions. None of these require recursive evaluation, which is the core of the triviality requirement. Satisfying it everywhere means that our nested AST gets linearized into a sequence of small operations.

Variables are similarly straightforward to compile. We leave the variable names as-is for now in our IR, so we need not keep an environment parameter around.

def cps(exp, k):
    match exp:
        case int(_) | str(_):
            return ["$call-cont", k, exp]
    raise NotImplementedError(type(exp))  # TODO

cps("x", "k")  # ["$call-cont", "k", "x"]

Now let’s look at function calls. Function calls are the first type of expression that requires recursively evaluating subexpressions. To evaluate (f x), for example, we evaluate f, then x, then do a function call. The order of evaluation is not important to this post; it is a semantic property of the language being compiled.

To convert to CPS, we have to both do recursive compilation of the arguments and also synthesize our first continuations!

To evaluate a subexpression, which could be arbitrarily complex, we have to make a recursive call to cps. Unlike normal compilers, this doesn’t return a value. Instead, you pass it a continuation (does the word “callback” help here?) to do future work when that value has a name. To generate compiler-internal names, we have a gensym function that isn’t interesting and returns unique strings.

The only thing that differentiates a continuation from, for example, a JavaScript callback is that it’s not a Python function but instead a function in the generated code.
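For completeness, here is one way gensym could be written. This is a minimal sketch that assumes a single global counter shared by every name stem, which is consistent with names like v0 and k1 in the sample outputs. It makes no attempt to avoid clashing with a user variable that happens to be called v0.

```python
import itertools

# Assumption: one counter for all stems, so names come out as v0, v1, k2, ...
_counter = itertools.count()

def gensym(stem="v"):
    """Return a fresh compiler-internal name such as 'v0' or 'k1'."""
    return f"{stem}{next(_counter)}"
```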

def cps(exp, k):
    match exp:
        case [func, arg]:
            vfunc = gensym()
            varg = gensym()
            return cps(func, ["cont", [vfunc],
                              cps(arg, ["cont", [varg],
                                        [vfunc, varg, k]])])
    # ...

cps(["f", 1], "k")
# ["$call-cont", ["cont", ["v0"],
#            ["$call-cont", ["cont", ["v1"],
#                       ["v0", "v1", "k"]],
#                     1]],
#          "f"]

Note that our generated function call from (f x) now also has a continuation argument that was not there before. This is because (f x) does not return anything, but instead passes the value to the given continuation.

Calls to primitive operators like + are our other interesting case. Like function calls, we evaluate the operands recursively and pass in an additional continuation argument. Note that not all CPS implementations do this for simple math operators; some choose to allow simple arithmetic to actually “return” values.

def gensym(stem="v"): ...  # returns a fresh name such as "v0" or "k0"

def cps(exp, k):
    match exp:
        case [op, x, y] if op in ["+", "-"]:
            vx = gensym()
            vy = gensym()
            return cps(x, ["cont", [vx],
                           cps(y, ["cont", [vy],
                                   [f"${op}", vx, vy, k]])])
    # ...

cps(["+", 1, 2], "k")
# ["$call-cont", ["cont", ["v0"],
#            ["$call-cont", ["cont", ["v1"],
#                       ["$+", "v0", "v1", "k"]],
#                     2]],
#          1]

We also create a special form for the operator in our CPS IR that begins with $. So + gets turned into $+ and so on. This helps distinguish operator invocations from function calls.

Now let’s look at creating functions. Lambda expressions such as (lambda (x) (+ x 1)) need to create a function at run-time and that function body contains some code. To “return”, we use $call-cont as usual, but we have to also remember to create a new fun form with a continuation parameter (and then thread that through to the function body).

def cps(exp, k):
    match exp:
        case ["lambda", [arg], body]:
            vk = gensym("k")
            return ["$call-cont", k,
                        ["fun", [arg, vk], cps(body, vk)]]
    # ...

cps(["lambda", ["x"], "x"], "k")
# ["$call-cont", "k",
#                ["fun", ["x", "k0"],
#                        ["$call-cont", "k0", "x"]]]

Alright, last in this mini language is our if expression: (if cond iftrue iffalse) where all of cond, iftrue, and iffalse can be arbitrarily nested expressions. This just means we call cps recursively three times.

We also add this new compiler builtin called ($if cond iftrue iffalse) that takes one trivial expression—the condition—and decides which of the branches to execute. This is roughly equivalent to a machine code conditional jump.

The straightforward implementation works just fine, but can you see what might go wrong?

def cps(exp, k):
    match exp:
        case ["if", cond, iftrue, iffalse]:
            vcond = gensym()
            return cps(cond, ["cont", [vcond],
                                ["$if", vcond,
                                   cps(iftrue, k),
                                   cps(iffalse, k)]])
        # ...

cps(["if", 1, 2, 3], "k")
# ["$call-cont", ["cont", ["v0"],
#                  ["$if", "v0",
#                              ["$call-cont", "k", 2],
#                              ["$call-cont", "k", 3]]],
#                1]

The problem is that our continuation, k, need not be a continuation variable: it could be an arbitrarily complicated expression. Our implementation copies it into the compiled code twice, which in the worst case could lead to exponential program growth. Instead, let’s bind it to a name and then use the name twice.

def cps(exp, k):
    match exp:
        case ["if", cond, iftrue, iffalse]:
            vcond = gensym()
            vk = gensym("k")
            return ["$call-cont", ["cont", [vk],
                        cps(cond, ["cont", [vcond],
                                      ["$if", vcond,
                                        cps(iftrue, vk),
                                        cps(iffalse, vk)]])],
                        k]
        # ...

cps(["if", 1, 2, 3], "k")
# ["$call-cont", ["cont", ["k1"],
#                  ["$call-cont", ["cont", ["v0"],
#                                   ["$if", "v0",
#                                     ["$call-cont", "k1", 2],
#                                     ["$call-cont", "k1", 3]]],
#                                   1]],
#                  "k"]

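Last, let can be handled by using a continuation, just as we’ve bound temporary variables in previous examples. You could also handle it by desugaring it into ((lambda (x) body) value) but that will generate a lot more administrative overhead for your optimizer to get rid of later. Note that the code below accepts a single binding written flat, as ["let", ["x", 1], body], rather than the nested binding list shown in the earlier sample.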

def cps(exp, k):
    match exp:
        case ["let", [name, value], body]:
            return cps(value, ["cont", [name],
                               cps(body, k)])
        # ...

cps(["let", ["x", 1], ["+", "x", 2]], "k")
# ["$call-cont", ["cont", ["x"],
#                 ["$call-cont", ["cont", ["v0"],
#                                 ["$call-cont", ["cont", ["v1"],
#                                                 ["$+", "v0", "v1", "k"]],
#                                  2]],
#                  "x"]],
#  1]

There you have it. A working Mini-Scheme to CPS converter. My implementation is ~30 lines of Python code. It’s short and sweet but you might have noticed some shortcomings…

Meta-continuations

Now, you might have noticed that we’re giving names to a lot of trivial expressions—unnecessary cont forms used like let bindings. Why name the integer 3 if it’s trivial?

One approach people take to avoid this is meta-continuations, which I think many people call the “higher-order transform”. Instead of always generating conts, we can sometimes pass in a compiler-level (in this case, Python) function instead.

See Matt Might’s article and what I think is a working Python implementation.

This approach, though occasionally harder to reason about and more complex, cuts down on a significant amount of code before it is ever emitted. For few-pass compilers, for resource-constrained environments, for enormous programs, … this makes a lot of sense.

Another approach, potentially easier to reason about, is to lean on your optimizer. You’ll probably want an optimizer anyway, so you might as well use it to cut down your intermediate code too.

Optimizations

Just like any IR, it’s possible to optimize by doing recursive rewrites. We won’t implement any here (for now… maybe I’ll come back to this) but will sketch out a few common ones.

The simplest one is probably constant folding, like turning (+ 3 4) into 7. The CPS equivalent looks kind of like this:

["$+", 3, 4, "k"]  # => ["$call-cont", "k", 7]
["$if", 1, "t", "f"]  # => t
["$if", 0, "t", "f"]  # => f

An especially important optimization, particularly if using the simple CPS transformation that we’ve been using, is beta reduction. This is the process of turning expressions like ((lambda (x) (+ x 1)) 2) into (+ 2 1) by substituting the 2 for x. For example, in CPS:

["$call-cont", ["cont", ["k1"],
                 ["$call-cont", ["cont", ["v0"],
                                  ["$if", "v0",
                                    ["$call-cont", "k1", 2],
                                    ["$call-cont", "k1", 3]]],
                                  1]],
                 "k"]

# into
["$call-cont", ["cont", ["v0"],
                 ["$if", "v0",
                   ["$call-cont", "k", 2],
                   ["$call-cont", "k", 3]]],
                 1]

# into
["$if", 1,
  ["$call-cont", "k", 2],
  ["$call-cont", "k", 3]]

# into (via constant folding)
["$call-cont", "k", 2]

Substitution has to be scope-aware, and therefore requires threading an environment parameter through your optimizer.

As an aside: even if you “alphatise” your expressions to make them have unique variable bindings and names, you might accidentally create second bindings of the same names when substituting. For example:

# substitute(haystack, name, replacement)
substitute(["+", "x", "x"], "x", ["let", ["x0", 1], "x0"])

This would create two bindings of x0, which violates the global uniqueness property.

You may not always want to perform this rewrite. To avoid code blowup, you may only want to substitute if the function or continuation’s parameter name appears zero or one times in the body. Or, if it occurs more than once, substitute only if the expression being substituted is an integer/variable. This is a simple heuristic that will avoid some of the worst-case scenarios but may not be maximally beneficial—it’s a local optimum.

Another thing to be aware of is that substitution may change evaluation order. So even if you only have one parameter reference, you may not want to substitute:

((lambda (f) (begin (g) f))
 (do-a-side-effect))

As the program is right now, do-a-side-effect will be called before g and the result will become f. If you substitute do-a-side-effect for f in your optimizer, g will be called before do-a-side-effect. You can be more aggressive if your analyzer tells you what functions are side-effect free, but otherwise… be careful with function calls.

There are also more advanced optimizations but they go beyond an introduction to CPS and I don’t feel confident enough to sketch them out.

Alright, we’ve done a bunch of CPS→CPS transformations. Now we would like to execute the optimized code. To do that, we have to transform out of CPS into something designed to be executed.

To C, perchance to dream

In this section we’ll list a couple of approaches to generating executable code from CPS but we won’t implement any.

You can generate naive C code pretty directly from CPS. The fun and cont forms become top-level C functions. In order to support closures, you need to do free variable analysis and allocate a closure structure for each fun and cont that captures variables. (See also the approach in scrapscript in the section called “functions”.) Then you can use a very generic calling convention where you pass closures around. Unfortunately, this is not very efficient and doesn’t guarantee tail-call elimination.

To support tail-call elimination, you can use trampolines. This mostly involves allocating a frame-like closure on the heap with each tail-call. If you have a garbage collector, this isn’t too bad; the frames do not live very long. In fact, if you instrument the factorial example in Eli’s blog post, you can see that the trampoline frames live only until the next one gets allocated.
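The mechanism is easier to see in a high-level language than in generated C. Below is a minimal Python sketch, not the example from the linked post: every tail call returns a zero-argument closure (the heap-allocated "frame") and a driver loop keeps calling until it gets a non-callable value. It assumes final results are never callable.

```python
def trampoline(bounce):
    """Run thunks until something that is not callable comes back."""
    while callable(bounce):
        bounce = bounce()
    return bounce

def fact(n, k):
    """CPS factorial; every call is deferred by wrapping it in a thunk."""
    if n <= 1:
        return lambda: k(1)
    # The continuation also returns a thunk, so calling k never nests deeply.
    return lambda: fact(n - 1, lambda v: (lambda: k(n * v)))

result = trampoline(fact(5, lambda v: v))
```

Because each thunk returns to the loop before the next one runs, the native stack stays shallow even for inputs far beyond Python's recursion limit.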

This observation led to the development of the Cheney on the MTA algorithm, which uses the C call stack as a young generation for the garbage collector. It uses setjmp and longjmp to unwind the stack. This approach is used in CHICKEN Scheme and Cyclone Scheme. Take a look at Baker’s 1994 implementation.

If you don’t want to do any of this trampoline stuff, you can also do the One Big Switch approach which stuffs each of the funs and conts into a case in a massive switch statement. Calls become gotos. You can manage your stack roots pretty easily in one contiguous array. However, as you might imagine, larger Scheme programs might cause trouble with some C compilers.
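As a rough picture of the shape of such code, here is a hand-written Python stand-in for what would be a C switch inside a loop. The labels, the single arg register, and the stack layout are all made up for illustration. A factorial function and its continuation each become one case, a call is just an assignment to pc, and pending continuations live in one contiguous list.

```python
# Hypothetical labels: one per fun/cont in the program.
FACT, FACT_K, HALT = range(3)

def run_fact(n):
    stack = [HALT]        # contiguous array of saved values and return labels
    pc, arg = FACT, n     # pc selects the case; arg is the single argument
    while True:
        if pc == FACT:
            if arg <= 1:
                # "Return" 1 by jumping to the saved continuation label.
                pc, arg = stack.pop(), 1
            else:
                stack.append(arg)      # value the continuation needs later
                stack.append(FACT_K)   # continuation label
                arg -= 1               # "call" FACT again; pc is unchanged
        elif pc == FACT_K:
            saved_n = stack.pop()
            arg = saved_n * arg
            pc = stack.pop()
        elif pc == HALT:
            return arg
```

In C the if/elif chain would be a switch on pc and the assignments to pc would be gotos or fall back into the loop.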

Last, you need not generate C. You can also do your own lowering from CPS into a lower-level IR and then to some kind of assembly language.

Wrapping up

You have seen how to produce CPS, how to optimize it, and how to eliminate it. There’s much more to learn, if you are interested. Please send me material if you find it useful.

I had originally planned to write a CPS-based optimizer and code generator for scrapscript but I got stuck on the finer details of compiling pattern matching to CPS. Maybe I will return to this in the future by desugaring it to nested ifs or something.

Check out the code.

Acknowledgements

Thanks to Vaibhav Sagar and Kartik Agaram for giving feedback on this post. Thanks to Olin Shivers for an excellent course on compiling functional programming languages.

See also

  1. Earlier this year, my grandmother mentioned offhand that she was getting brunch with the Plotkins. I, midway through a course by Olin Shivers on compiling functional programming languages, did a double take. Surely she couldn’t mean… but yep, apparently my grandmother and Gordon Plotkin are friends! 

  2. It’s a several-line change to the compiler to handle multiple parameters but for this post it’s just noise so I leave it as an exercise.