free monad implemented with a type-aligned list

This document provides an example of how a free monad could be implemented using a type-aligned list. It introduces lower-level techniques that can be used with type-aligned lists, such as destructuring via the uncons1 method.

This document assumes that the reader is already familiar with the concept of free monads. If you want an introduction to free monads and their common use-cases, there are a number of resources available online such as the Cats free monad documentation.

There are many ways to implement free monads in Scala, and this particular implementation is mostly meant for demonstration and learning purposes. Having said that, it is a fairly reasonable way to implement a free monad in Scala. Unlike many naive Scala implementations of free monads, it is stack-safe: interpreting a deeply nested computation won’t throw a StackOverflowError. Its performance should also be reasonable, though benchmarks would be needed to confirm this.

the FreeMonad data structure

We start off by defining an abstract FreeMonad type for an effect type F[_] and result value of type A:

sealed abstract class FreeMonad[F[_], A]

This is a bare-bones FreeMonad type and at this point doesn’t define anything useful; we’ll add that later.

We’ll need some ways to construct a FreeMonad. The two most straightforward ways to create a FreeMonad[F, A] value are by wrapping an A value (the Pure constructor) or an F[A] value (the Suspend constructor). Neither Pure nor Suspend carries a “continuation” function to apply to the wrapped value later. We’ll make both extend a common NoCont (no continuation) supertype, which will prove useful later on.

/**
 * A free monad that wraps a value and does not provide any "continuation" functions to apply to
 * the value
 */
sealed abstract class NoCont[F[_], A] extends FreeMonad[F, A]

/** A pure `A` value */
final case class Pure[F[_], A](value: A) extends NoCont[F, A]

/** An `A` value in an `F` effect context */
final case class Suspend[F[_], A](value: F[A]) extends NoCont[F, A]

def pure[F[_], A](value: A): FreeMonad[F, A] = Pure(value)

def suspend[F[_], A](value: F[A]): FreeMonad[F, A] = Suspend(value)

The third and final way to construct a FreeMonad value is to call flatMap on a FreeMonad[F, A] value, providing a function of type A => FreeMonad[F, B] to produce a FreeMonad[F, B] value. We’ll refer to functions of this A => FreeMonad[F, B] form as continuations, as they describe a way to continue the computation once an A value has been produced. In fact, let’s go ahead and define a Cont type alias for continuation functions:

/** A continuation that takes an A value and returns a new FreeMonad to continue the computation */
type Cont[F[_], A, B] = A => FreeMonad[F, B]

Now we can define a constructor for free monads that are created via flatMap. It will have two fields:

  • The NoCont[F, A] value at the root of the flatMap chain (init). Calling flatMap on a FlatMapped value keeps its existing init, so this is always a Pure or Suspend value.
  • A non-empty list of continuation functions (conts). It holds a single continuation after the first call to flatMap and gains one more with each subsequent call.

final private case class FlatMapped[F[_], A, B](
  init: NoCont[F, A],
  conts: TANonEmptyList.Rev[Cont[F, *, *], A, B])
    extends FreeMonad[F, B]
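If type-aligned lists are new to you, the following dependency-free sketch may help. It is a simplified, hypothetical stand-in for a non-empty list of plain functions in which each function’s output type must match the next function’s input type. FnChain, Last, Link, cons, uncons1 and run are illustrative names, not the library’s API. The uncons1 here is only modeled on the way foldMapStep uses it later in this document: Left for the last remaining element, and Right for a list that exposes its head, its tail and the hidden intermediate type P.

```scala
// Hypothetical sketch of a type-aligned non-empty list of functions; not the library's TANonEmptyList.
object TypeAlignedSketch {

  /** A non-empty chain of functions A => X1, X1 => X2, ..., Xn => B. */
  sealed abstract class FnChain[A, B]

  /** A chain holding exactly one function. */
  final case class Last[A, B](f: A => B) extends FnChain[A, B]

  /** A head function into a hidden type P, followed by the rest of the chain starting at P. */
  sealed abstract class Link[A, B] extends FnChain[A, B] {
    type P
    def head: A => P
    def tail: FnChain[P, B]
  }

  def one[A, B](f: A => B): FnChain[A, B] = Last(f)

  /** Prepend `f`; the compiler checks that its output type X is the rest of the chain's input type. */
  def cons[A, X, B](f: A => X, rest: FnChain[X, B]): FnChain[A, B] =
    new Link[A, B] {
      type P = X
      val head: A => X = f
      val tail: FnChain[X, B] = rest
    }

  /** Left: the only remaining function. Right: a link exposing head, tail and the hidden P. */
  def uncons1[A, B](chain: FnChain[A, B]): Either[A => B, Link[A, B]] =
    chain match {
      case Last(f)       => Left(f)
      case l: Link[A, B] => Right(l)
    }

  /** Apply every function in order, destructuring with uncons1 (recursive for brevity). */
  def run[A, B](chain: FnChain[A, B], input: A): B =
    uncons1(chain) match {
      case Left(f)     => f(input)
      case Right(link) => run(link.tail, link.head(input))
    }

  /** String => Int, Int => Int, Int => String: each output type lines up with the next input. */
  val example: FnChain[String, String] =
    cons((s: String) => s.length, cons((n: Int) => n * 2, one((n: Int) => s"$n!")))
}
```

Because cons requires the function’s output type to equal the rest of the chain’s input type, an expression like cons((s: String) => s.length, one((b: Boolean) => !b)) is rejected at compile time.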

The astute reader may have noticed that we used TANonEmptyList.Rev as opposed to TANonEmptyList for the continuations. This is an optimization and isn’t strictly necessary. Prepending to a TANonEmptyList takes constant time, while appending takes linear time. So we prepend to the continuation list every time that flatMap is called, which leaves the continuations in reverse order. Operations that consume the free monad value will most likely need to reverse the list. That is a single linear-time pass, whereas appending on every flatMap would cost linear time per call (quadratic overall for a long chain).
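The cost argument can be illustrated with ordinary Scala Lists standing in for TANonEmptyList.Rev. This is a simplification: the continuations below are plain Int => Int functions rather than type-aligned ones, and PrependThenReverse, addStep and runAll are hypothetical names. Each step is prepended in constant time, and a single reverse at consumption time restores the order in which the steps were added.

```scala
// Hypothetical illustration: a plain List[Int => Int] stands in for TANonEmptyList.Rev.
object PrependThenReverse {
  type K = Int => Int

  // Each "flatMap" prepends in constant time, so the newest continuation ends up at the head.
  def addStep(reversed: List[K], k: K): List[K] = k :: reversed

  // Consumption reverses once (linear time) to recover the order in which steps were added.
  def runAll(reversed: List[K], input: Int): Int =
    reversed.reverse.foldLeft(input)((acc, k) => k(acc))

  // Three steps added in order: append the digit 1, then 2, then 3.
  val steps: List[K] =
    List(1, 2, 3).foldLeft(List.empty[K])((acc, d) => addStep(acc, (x: Int) => x * 10 + d))
}
```

Here runAll(steps, 0) yields 123. Folding over steps without the reverse would apply them newest-first and yield 321.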

putting the monad in free monad

Now we have all the pieces in place to define a Monad instance for FreeMonad!

implicit def monadForFreeMonad[F[_]]: Monad[FreeMonad[F, *]] =
  new StackSafeMonad[FreeMonad[F, *]] {

    override def pure[A](x: A): FreeMonad[F, A] = Pure(x)

    override def flatMap[A, B](fa: FreeMonad[F, A])(f: A => FreeMonad[F, B]): FreeMonad[F, B] =
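      fa match {
        // no continuations yet: start a one-element (reversed) continuation list
        case x: NoCont[F, A] => FlatMapped(x, TANonEmptyList.revOne[Cont[F, *, *], A, B](f))
        // already flatMapped: keep the original `init` and prepend `f` to the reversed list
        case FlatMapped(init, conts) => FlatMapped(init, f :: conts)
      }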
  }
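To see the shape this instance builds, here is an untyped model of the same flatMap. It is a hypothetical simplification: continuation types are erased to Any, and a plain List stands in for TANonEmptyList.Rev. Chaining three flatMap calls on a Pure value produces a single FlatMapped node rather than three nested ones. Its init is still the original Pure, and its list holds the three continuations, newest first.

```scala
// Untyped model of the flatMap above; FlatMapShape and its members are hypothetical names.
object FlatMapShape {
  sealed abstract class Free
  sealed abstract class NoCont extends Free
  final case class Pure(value: Any) extends NoCont
  // `revConts` holds continuations newest-first, like TANonEmptyList.Rev.
  final case class FlatMapped(init: NoCont, revConts: List[Any => Free]) extends Free

  def flatMap(fa: Free)(f: Any => Free): Free = fa match {
    // First flatMap on a value without continuations: start a one-element list.
    case x: NoCont => FlatMapped(x, List(f))
    // Later flatMaps keep the original init and prepend in constant time instead of nesting.
    case FlatMapped(init, conts) => FlatMapped(init, f :: conts)
  }

  // Three separate continuation values so their order in the list can be checked.
  val k1: Any => Free = a => Pure(a)
  val k2: Any => Free = a => Pure(a)
  val k3: Any => Free = a => Pure(a)

  val example: Free = flatMap(flatMap(flatMap(Pure(0))(k1))(k2))(k3)
}
```

The resulting value is FlatMapped(Pure(0), List(k3, k2, k1)): one level deep, no matter how many flatMap calls are chained.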

implementing foldMap

A free monad value is only useful if you can consume it. One of the most common ways to consume a free monad value is with a foldMap method that recursively steps through the free monad value, translating all F[_] values into G[_] values via a provided FunctionK[F, G]. The signature for foldMap looks like this:

def foldMap[F[_], G[_], A](fa: FreeMonad[F, A])(f2g: FunctionK[F, G])(implicit G: Monad[G]): G[A]

Implementing foldMap for NoCont values is pretty straightforward, so let’s start with that:

def foldMapNoCont[F[_], G[_], A](fa: NoCont[F, A])(f2g: FunctionK[F, G])(implicit
  G: Monad[G]): G[A] =
  fa match {
    case Pure(value) => G.pure(value)
    case Suspend(value) => f2g(value)
  }

Implementing foldMap for the FlatMapped case in a stack-safe manner is significantly more complicated. We’ll take the approach of keeping track of the current input value and a stack (a TANonEmptyList structure) of continuations that still need to be run. First let’s define some helper type aliases for the (input, continuation) stack pairs:

/**
 * A stack of continuations.
 *
 * @tparam I the input type of the first continuation
 * @tparam O the output type of the last continuation
 */
type Conts[F[_], I, O] = TANonEmptyList[Cont[F, *, *], I, O]

/**
 * A stack of continuations along with the input for the first continuation.
 *
 * This is essentially a tuple `(I, Conts[F, I, O])` of an input `I` value and a list of
 * continuations to pass the input value into. `TATuple2K` is used instead of a vanilla tuple
 * because the `I` type is an unknown/existential type; `TATuple2K` ensures that it lines up
 * between the input value and the first continuation.
 */
type InputAndConts[F[_], O] = TATuple2K[Id, Conts[F, *, O]]

/** Simple helper method to create `InputAndConts` values */
@inline def makeInputAndConts[F[_], I, A](input: I, conts: Conts[F, I, A]): InputAndConts[F, A] =
  TATuple2K[Id, Conts[F, *, A], I](input, conts)

We’ll use these helpers in a function that performs a single “step” of the foldMap:

  • Pop the next continuation off of the stack.
  • Pass the current input value into the popped continuation, which returns a FreeMonad.
  • If that FreeMonad is a FlatMapped value, prepend its continuations (reversed into execution order) to the stack.
  • Interpret the NoCont part of the returned value (the value itself, or the init of a FlatMapped) with foldMapNoCont to get the input for the next continuation.
  • If there are no more continuations, return a Right wrapping the final value. Otherwise return a Left with the remaining continuation stack and the input value for the next continuation.

def foldMapStep[F[_], G[_], A](inputAndConts: InputAndConts[F, A])(f2g: FunctionK[F, G])(implicit
  G: Monad[G]): G[Either[InputAndConts[F, A], A]] =
  inputAndConts.second.uncons1 match {
    // last continuation on the stack
    case Left(cont) =>
      // run the continuation on the input value
      cont(inputAndConts.first) match {
        // nothing to add to the stack; we are done
        case y: NoCont[F, A] => foldMapNoCont(y)(f2g).map(Right(_))
        // new continuations to add to the stack
        case FlatMapped(init, conts) =>
          foldMapNoCont(init)(f2g).map(x => Left(makeInputAndConts(x, conts.reverse)))
      }
    // multiple continuations on the stack
    case Right(conts) =>
      // run the first continuation on the input value
      conts.head(inputAndConts.first) match {
        // the result doesn't add any continuations to the stack
        case y: NoCont[F, conts.P] =>
          foldMapNoCont(y)(f2g).map(x => Left(makeInputAndConts(x, conts.tail)))
        // the result adds continuations to the stack
        case FlatMapped(init, additionalConts) =>
          foldMapNoCont(init)(f2g).map(x =>
            Left(makeInputAndConts(x, conts.tail.prependReversed(additionalConts))))
      }
  }
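foldMapStep never calls itself. It returns Left (keep going) or Right (done) and leaves the looping to tailRecM. The sketch below shows that contract in isolation, specialized to plain values (the identity monad) so that it needs no library. StepLoop, tailRecM and sumTo are hypothetical illustrations, not cats’ implementation. The loop is a self-recursive call in tail position, which @tailrec guarantees is compiled into a jump, so the call stack does not grow with the number of steps.

```scala
import scala.annotation.tailrec

object StepLoop {
  // Run `step` until it returns Right; each Left carries the next state.
  @tailrec def tailRecM[S, R](state: S)(step: S => Either[S, R]): R =
    step(state) match {
      case Left(next) => tailRecM(next)(step)
      case Right(r)   => r
    }

  // Toy step function: the state is (remaining, acc), and each step consumes one unit of work.
  def sumTo(n: Long): Long =
    tailRecM[(Long, Long), Long]((n, 0L)) { case (remaining, acc) =>
      if (remaining == 0L) Right(acc) else Left((remaining - 1, acc + remaining))
    }
}
```

A million steps run in constant stack space: StepLoop.sumTo(1000000L) returns 500000500000.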

Finally we can define foldMap:

def foldMap[F[_], G[_], A](fa: FreeMonad[F, A])(f2g: FunctionK[F, G])(implicit
  G: Monad[G]): G[A] =
  fa match {
    // simple case of no continuations; we are done!
    case x: NoCont[F, A] => foldMapNoCont(x)(f2g)
    // continuations exist; we'll need to step through the stack of continuations one at a time
    case FlatMapped(init, conts) =>
      // interpret the initial value, then let tailRecM repeatedly `step` through the continuation stack
      foldMapNoCont(init)(f2g).flatMap(v =>
        G.tailRecM(makeInputAndConts(v, conts.reverse))(foldMapStep(_)(f2g)))
  }
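To experiment with the whole algorithm without the library or cats, here is a dependency-free sketch. It specializes F to Function0 and interprets into plain values. All names are hypothetical and the code is a simplified analogue, not the library’s implementation:

  • RevConts and Conts play the roles of TANonEmptyList.Rev and TANonEmptyList.
  • InputAndConts plays the role of the TATuple2K-based alias.
  • step mirrors foldMapStep.
  • A small @tailrec loop stands in for tailRecM.

For brevity, prependReversed and reverse recurse once per continuation in a single reversed list. Very long left-nested flatMap chains could therefore still exhaust the stack. The trampoline, where each list holds one continuation, is unaffected.

```scala
import scala.annotation.tailrec

// Hypothetical, dependency-free sketch: F = Function0, interpreted directly into values.
object MiniFree {
  sealed abstract class Free[A] {
    // Keep `init` and add one continuation, as in the Monad instance above.
    def flatMap[B](f: A => Free[B]): Free[B] = this match {
      case x: NoCont[A]          => FlatMapped(x, ROne(f))
      case FlatMapped(init, rev) => FlatMapped(init, RSnoc(rev, f))
    }
  }
  sealed abstract class NoCont[A] extends Free[A]
  final case class Pure[A](value: A) extends NoCont[A]
  final case class Suspend[A](thunk: () => A) extends NoCont[A]
  final case class FlatMapped[X, B](init: NoCont[X], conts: RevConts[X, B]) extends Free[B]

  // Continuations, newest outermost (the role of TANonEmptyList.Rev).
  sealed abstract class RevConts[A, B]
  final case class ROne[A, B](f: A => Free[B]) extends RevConts[A, B]
  final case class RSnoc[A, X, B](init: RevConts[A, X], last: X => Free[B]) extends RevConts[A, B]

  // Continuations in application order (the role of TANonEmptyList).
  sealed abstract class Conts[A, B]
  final case class One[A, B](f: A => Free[B]) extends Conts[A, B]
  final case class Cons[A, X, B](head: A => Free[X], tail: Conts[X, B]) extends Conts[A, B]

  // Reverse `rev` onto the front of `acc`; recursive for brevity (depth = length of `rev`).
  def prependReversed[A, X, B](rev: RevConts[A, X], acc: Conts[X, B]): Conts[A, B] =
    rev match {
      case ROne(f)           => Cons(f, acc)
      case RSnoc(init, last) => prependReversed(init, Cons(last, acc))
    }

  def reverse[A, B](rev: RevConts[A, B]): Conts[A, B] = rev match {
    case ROne(f)           => One(f)
    case RSnoc(init, last) => prependReversed(init, One(last))
  }

  def runNoCont[A](fa: NoCont[A]): A = fa match {
    case Pure(a)        => a
    case Suspend(thunk) => thunk()
  }

  // An input paired with the continuations that accept it; the input type I is hidden.
  sealed abstract class InputAndConts[B] {
    type I
    def input: I
    def conts: Conts[I, B]
  }
  def makeInputAndConts[X, B](i: X, c: Conts[X, B]): InputAndConts[B] =
    new InputAndConts[B] { type I = X; val input: X = i; val conts: Conts[X, B] = c }

  // Interpret the NoCont part of `next` and push any continuations it carries onto `rest`.
  def push[X, B](next: Free[X], rest: Conts[X, B]): InputAndConts[B] = next match {
    case x: NoCont[X]          => makeInputAndConts(runNoCont(x), rest)
    case FlatMapped(init, rev) => makeInputAndConts(runNoCont(init), prependReversed(rev, rest))
  }

  // One step, mirroring foldMapStep: Left = keep going, Right = final value.
  def step[B](s: InputAndConts[B]): Either[InputAndConts[B], B] = s.conts match {
    case One(f) =>
      f(s.input) match {
        case x: NoCont[B]          => Right(runNoCont(x))
        case FlatMapped(init, rev) => Left(makeInputAndConts(runNoCont(init), reverse(rev)))
      }
    case Cons(head, tail) => Left(push(head(s.input), tail))
  }

  // Stand-in for tailRecM on plain values: @tailrec compiles this into a loop.
  @tailrec def tailRecM[S, R](s: S)(f: S => Either[S, R]): R = f(s) match {
    case Left(next) => tailRecM(next)(f)
    case Right(r)   => r
  }

  def run[A](fa: Free[A]): A = fa match {
    case x: NoCont[A] => runNoCont(x)
    case FlatMapped(init, rev) =>
      tailRecM(makeInputAndConts(runNoCont(init), reverse(rev)))(s => step(s))
  }

  // The even/odd trampoline from the next section, ported to this sketch.
  def pure[A](a: A): Free[A] = Pure(a)
  def defer[A](delayed: () => Free[A]): Free[A] = pure(()).flatMap(_ => delayed())
  def even(n: Int): Free[Boolean] = if (n == 0) pure(true) else defer(() => odd(n - 1))
  def odd(n: Int): Free[Boolean] = if (n == 0) pure(false) else defer(() => even(n - 1))
}
```

MiniFree.run(MiniFree.even(100000)) evaluates to true in constant stack depth, because each step hands control back to the loop instead of recursing.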

using foldMap

To convince ourselves that our foldMap implementation is correct and stack-safe, we can use a classic example. Using Function0 as the F[_] type turns FreeMonad into a trampoline. We then use it to compute a value via two mutually recursive functions that would overflow the stack without trampolining.

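import cats.Id
import cats.arrow.FunctionK
import cats.syntax.all._ // flatMap and eqv syntax used below

type Trampoline[A] = FreeMonad[Function0, A]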

def defer[A](delayed: () => Trampoline[A]): Trampoline[A] =
  FreeMonad.pure[Function0, Unit](()).flatMap(_ => delayed())

def even(n: Int): Trampoline[Boolean] =
  if (n eqv 0) FreeMonad.pure(true) else defer(() => odd(n - 1))

def odd(n: Int): Trampoline[Boolean] =
  if (n eqv 0) FreeMonad.pure(false) else defer(() => even(n - 1))

val evaluateFunction0: FunctionK[Function0, Id] = new FunctionK[Function0, Id] {
  def apply[A](value: Function0[A]): A = value()
}

def evalTrampoline[A](value: Trampoline[A]): A = FreeMonad.foldMap(value)(evaluateFunction0)

assert(evalTrampoline(even(100000)))
assert(!evalTrampoline(even(100001)))
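For comparison, the Scala standard library ships its own trampoline, scala.util.control.TailCalls. The sketch below is independent of the FreeMonad above, and TailCallsEvenOdd is a hypothetical name. done wraps a final value, tailcall defers the next step, and result runs the trampoline.

```scala
import scala.util.control.TailCalls._

// The same even/odd pair, trampolined with the standard library instead of FreeMonad.
object TailCallsEvenOdd {
  def even(n: Int): TailRec[Boolean] = if (n == 0) done(true) else tailcall(odd(n - 1))
  def odd(n: Int): TailRec[Boolean] = if (n == 0) done(false) else tailcall(even(n - 1))
}
```

TailCallsEvenOdd.even(100000).result returns true, matching evalTrampoline(even(100000)).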