free monad implemented with a type-aligned list
This document provides an example of how a free monad could be implemented using a type-aligned list. It introduces lower-level techniques that can be used with type-aligned lists, such as destructuring via the uncons1 method.
This document assumes that the reader is already familiar with the concept of free monads. If you want an introduction to free monads and their common use-cases, there are a number of resources available online such as the Cats free monad documentation.
There are many ways to implement free monads in Scala, and this particular implementation is mostly meant for demonstration and learning purposes. That said, it is a fairly reasonable way to implement a free monad in Scala. Unlike many naive Scala implementations of free monads, it is stack-safe (that is, it won’t throw a StackOverflowError on deeply nested computations). Its performance should also be reasonable, though this should be verified with benchmarks.
the FreeMonad data structure
We start off by defining an abstract FreeMonad type for an effect type F[_] and a result value of type A:
sealed abstract class FreeMonad[F[_], A]
This is a bare-bones FreeMonad type that doesn’t define anything useful yet; we’ll add that later.
We’ll need some ways to construct a FreeMonad. The two most straightforward ways to create a FreeMonad[F, A] value are by simply wrapping an A value (the Pure constructor) or an F[A] value (the Suspend constructor). Both Pure and Suspend wrap a value without providing a “continuation” function to later apply to the value. We’ll make them extend a common NoCont (short for “no continuation”) supertype, which will turn out to be useful later on.
/**
 * A free monad that wraps a value and does not provide any "continuation" functions to apply to
 * the value
 */
sealed abstract class NoCont[F[_], A] extends FreeMonad[F, A]
/** A pure `A` value */
final case class Pure[F[_], A](value: A) extends NoCont[F, A]
/** An `A` value in an `F` effect context */
final case class Suspend[F[_], A](value: F[A]) extends NoCont[F, A]
def pure[F[_], A](value: A): FreeMonad[F, A] = Pure(value)
def suspend[F[_], A](value: F[A]): FreeMonad[F, A] = Suspend(value)
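As a quick sanity check of the two NoCont constructors, here is a self-contained sketch. It copies the class definitions into a hypothetical NoContDemo object so that it compiles with the standard library alone, and it fixes F to Option purely for illustration. The describe helper is also hypothetical.

```scala
object NoContDemo {
  sealed abstract class FreeMonad[F[_], A]
  sealed abstract class NoCont[F[_], A] extends FreeMonad[F, A]
  final case class Pure[F[_], A](value: A) extends NoCont[F, A]
  final case class Suspend[F[_], A](value: F[A]) extends NoCont[F, A]

  // F is fixed to Option only for this illustration
  val pureOne: FreeMonad[Option, Int] = Pure[Option, Int](1)
  val suspendedOne: FreeMonad[Option, Int] = Suspend[Option, Int](Some(1))

  // In this sketch every FreeMonad is a NoCont: a plain value or a single effect, nothing left to run
  def describe[F[_], A](fa: FreeMonad[F, A]): String = fa match {
    case Pure(a)    => s"Pure($a)"
    case Suspend(f) => s"Suspend($f)"
  }
}
```

Neither constructor holds a function, so there is no further computation to perform once the value (or effect) is extracted.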
The third and final way to construct a FreeMonad value is to call flatMap on a FreeMonad[F, A] value, providing a function of type A => FreeMonad[F, B] to produce a FreeMonad[F, B] value. We’ll refer to functions of this A => FreeMonad[F, B] form as continuations, as they describe a way to continue the computation once an A value has been produced. In fact, let’s go ahead and define a Cont type alias for continuation functions:
/** A continuation that takes an A value and returns a new FreeMonad to continue the computation */
type Cont[F[_], A, B] = A => FreeMonad[F, B]
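Now we can define a constructor for free monads that are created via flatMap. It will have two fields:
- The initial value upon which flatMap was called (init). Its type is NoCont[F, A] rather than FreeMonad[F, A]. When flatMap is called on a value that is already a FlatMapped, it adds to that value's continuation list instead of nesting, so init never has continuations of its own.
- A non-empty list of continuation functions (conts). This will have a single value if flatMap was only called once, and an additional value for each subsequent call to flatMap.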
final private case class FlatMapped[F[_], A, B](
    init: NoCont[F, A],
    conts: TANonEmptyList.Rev[Cont[F, *, *], A, B])
    extends FreeMonad[F, B]
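This document doesn't show the TANonEmptyList API itself, so the sketch below makes no assumptions about it. Instead it defines a hypothetical, minimal type-aligned non-empty list (TANel, Last and Cons are illustrative names) to show the core idea. Each element is an F[X, Y], and adjacent elements must agree on the type they pass along. With F fixed to Function1, the list is just a chain of functions whose types line up.

```scala
object TypeAlignedSketch {
  // Hypothetical minimal type-aligned non-empty list: the output type of one element
  // must equal the input type of the next, and X is hidden inside each Cons
  sealed abstract class TANel[F[_, _], A, B]
  final case class Last[F[_, _], A, B](value: F[A, B]) extends TANel[F, A, B]
  final case class Cons[F[_, _], A, X, B](head: F[A, X], tail: TANel[F, X, B])
      extends TANel[F, A, B]

  // Run a type-aligned list of plain functions on an input, front to back
  def runAll[A, B](fs: TANel[Function1, A, B], a: A): B = fs match {
    case Last(f)       => f(a)
    case Cons(f, rest) => runAll(rest, f(a))
  }

  // Int => String, then String => Int: this only compiles because the types line up
  val chain: TANel[Function1, Int, Int] =
    Cons[Function1, Int, String, Int](
      (i: Int) => i.toString * 2,
      Last[Function1, String, Int](_.length))
}
```

Swapping the two functions in chain would not compile. That guarantee lets a list like conts hold continuations of different types while still being checked end to end.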
The astute reader may have noticed that we used TANonEmptyList.Rev as opposed to TANonEmptyList for the continuations. This is an optimization and isn’t strictly necessary. Prepending to a TANonEmptyList is more efficient than appending to it (constant time as opposed to linear time), so we prepend to the continuation list every time that flatMap is called, resulting in a reversed list of continuations. Operations that consume the free monad value will most likely need to reverse the list. That is a linear-time operation, but a single linear-time reversal is cheaper than paying for a linear-time append on every flatMap call.
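An ordinary List of functions shows the prepend-then-reverse trick. In this sketch a plain Scala List stands in for TANonEmptyList.Rev, and the names are hypothetical. Each recorded call is prepended in constant time, and the consumer reverses the list once before running the functions in call order.

```scala
object ReversedConts {
  // Each "flatMap" prepends its function in O(1), so the list ends up newest-first
  def record(calls: List[Int => Int]): List[Int => Int] =
    calls.foldLeft(List.empty[Int => Int])((acc, f) => f :: acc)

  // The consumer reverses once (O(n)) and then applies the functions in call order
  def runInOrder(reversed: List[Int => Int], start: Int): Int =
    reversed.reverse.foldLeft(start)((x, f) => f(x))

  val addOne: Int => Int = _ + 1
  val timesTen: Int => Int = _ * 10
}
```

Recording addOne and then timesTen and running from 2 gives (2 + 1) * 10 = 30. Skipping the reversal would apply them in the wrong order and give 2 * 10 + 1 = 21.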
putting the monad in free monad
Now we have all the pieces in place to define a Monad instance for FreeMonad!
implicit def monadForFreeMonad[F[_]]: Monad[FreeMonad[F, *]] =
  new StackSafeMonad[FreeMonad[F, *]] {
    override def pure[A](x: A): FreeMonad[F, A] = Pure(x)
    override def flatMap[A, B](fa: FreeMonad[F, A])(f: A => FreeMonad[F, B]): FreeMonad[F, B] =
      fa match {
        // no continuations yet: start a new (reversed) continuation list containing f
        case x: NoCont[F, A] => FlatMapped(x, TANonEmptyList.revOne[Cont[F, *, *], A, B](f))
        // continuations already exist: prepend f instead of nesting another FlatMapped
        case FlatMapped(init, conts) => FlatMapped(init, f :: conts)
      }
  }
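The type-erased sketch below shows why init is always a NoCont. It is a simplification, not the implementation above. Continuations are untyped Any => Free values in an ordinary List, so the compiler no longer checks that their types line up, but the shape of the structure matches FlatMapped. The object and method names are hypothetical.

```scala
object FlatMapShape {
  // Type-erased stand-ins (assumption): only the shape of FlatMapped is modelled
  sealed trait Free
  sealed trait NoCont extends Free
  final case class Pure(value: Any) extends NoCont
  final case class FlatMapped(init: NoCont, revConts: List[Any => Free]) extends Free

  // Mirrors the two cases of the Monad instance's flatMap
  def flatMap(fa: Free)(f: Any => Free): Free = fa match {
    case x: NoCont            => FlatMapped(x, List(f))    // start a new list
    case FlatMapped(init, cs) => FlatMapped(init, f :: cs) // prepend, don't nest
  }
}
```

Three left-nested flatMap calls on Pure(0) produce one FlatMapped whose init is Pure(0) and whose list holds three continuations, rather than three nested layers.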
implementing foldMap
A free monad value is only useful if you can consume it. One of the most common ways to consume a free monad value is with a foldMap method that steps through the free monad value, translating all F[_] values into G[_] values via a provided FunctionK[F, G]. The signature for foldMap looks like this:
def foldMap[F[_], G[_], A](fa: FreeMonad[F, A])(f2g: FunctionK[F, G])(implicit G: Monad[G]): G[A]
Implementing foldMap for NoCont values is pretty straightforward, so let’s start with that:
def foldMapNoCont[F[_], G[_], A](fa: NoCont[F, A])(f2g: FunctionK[F, G])(implicit
    G: Monad[G]): G[A] =
  fa match {
    case Pure(value)    => G.pure(value)
    case Suspend(value) => f2g(value)
  }
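foldMapNoCont needs nothing from Monad[G] except pure. The sketch below runs without Cats by defining hypothetical stand-ins. NatTrans mirrors the shape of cats.arrow.FunctionK (a polymorphic apply[A](fa: F[A]): G[A]), and HasPure supplies only a pure method. F = Option and G = List are chosen purely for illustration.

```scala
object FoldMapNoContDemo {
  sealed abstract class FreeMonad[F[_], A]
  sealed abstract class NoCont[F[_], A] extends FreeMonad[F, A]
  final case class Pure[F[_], A](value: A) extends NoCont[F, A]
  final case class Suspend[F[_], A](value: F[A]) extends NoCont[F, A]

  // Hypothetical stand-in for cats.arrow.FunctionK
  trait NatTrans[F[_], G[_]] { def apply[A](fa: F[A]): G[A] }
  // Hypothetical stand-in for the only part of Monad[G] that foldMapNoCont uses
  trait HasPure[G[_]] { def pure[A](a: A): G[A] }

  def foldMapNoCont[F[_], G[_], A](fa: NoCont[F, A])(f2g: NatTrans[F, G])(implicit
      G: HasPure[G]): G[A] =
    fa match {
      case Pure(value)    => G.pure(value) // lift a plain value into G
      case Suspend(value) => f2g(value)    // translate the F effect into G
    }

  val optionToList: NatTrans[Option, List] = new NatTrans[Option, List] {
    def apply[A](fa: Option[A]): List[A] = fa.toList
  }
  implicit val listHasPure: HasPure[List] = new HasPure[List] {
    def pure[A](a: A): List[A] = List(a)
  }
}
```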
Implementing foldMap for the FlatMapped case in a stack-safe manner is significantly more complicated. We’ll take the approach of keeping track of the current input value and a stack (a TANonEmptyList structure) of continuations that still need to be run. First, let’s define some helper type aliases for the (input, continuation stack) pairs:
/**
 * A stack of continuations.
 *
 * @tparam I the input type of the first continuation
 * @tparam O the output type of the last continuation
 */
type Conts[F[_], I, O] = TANonEmptyList[Cont[F, *, *], I, O]
/**
 * A stack of continuations along with the input for the first continuation.
 *
 * This is essentially a tuple `(I, Conts[F, I, O])` of an input `I` value and a list of
 * continuations to pass the input value into. `TATuple2K` is used instead of a vanilla tuple,
 * because the `I` type is an unknown/existential type; `TATuple2K` ensures that it lines up
 * between the input value and the first continuation.
 */
type InputAndConts[F[_], O] = TATuple2K[Id, Conts[F, *, O]]
/** Simple helper method to create `InputAndConts` values */
@inline def makeInputAndConts[F[_], I, A](input: I, conts: Conts[F, I, A]): InputAndConts[F, A] =
  TATuple2K[Id, Conts[F, *, A], I](input, conts)
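TATuple2K pairs a value with a function while hiding the value's type. The same idea can be expressed with only the standard library by using an abstract type member. In the sketch below (InputAndCont and make are hypothetical, not part of the type-aligned library), callers see only the output type O, yet input and cont are still guaranteed to agree on I.

```scala
object ExistentialPair {
  // Hypothetical simplification of TATuple2K[Id, Conts[F, *, O]]: I is a hidden type member
  trait InputAndCont[O] {
    type I
    def input: I
    def cont: I => O
    def run: O = cont(input) // always well-typed, whatever I happens to be
  }

  def make[I0, O](i: I0, f: I0 => O): InputAndCont[O] = new InputAndCont[O] {
    type I = I0
    val input: I = i
    val cont: I => O = f
  }

  // Pairs with different hidden input types can share one collection
  val pending: List[InputAndCont[String]] = List(
    make[Int, String](42, _.toString),
    make[Boolean, String](true, b => if (b) "yes" else "no"))
}
```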
We’ll use these helpers in a function that performs a single “step” of the foldMap:
- Pop the next continuation off of the stack.
- Pass the current input value into the popped continuation.
- If the continuation returns further continuations, prepend them to the stack.
- Interpret the resulting NoCont value with foldMapNoCont to find the input value for the next continuation.
- If there are no more continuations, return a Right wrapping the final value. Otherwise, return a Left with the remaining continuation stack and the input value for the next continuation.
def foldMapStep[F[_], G[_], A](inputAndConts: InputAndConts[F, A])(f2g: FunctionK[F, G])(implicit
    G: Monad[G]): G[Either[InputAndConts[F, A], A]] =
  inputAndConts.second.uncons1 match {
    // last continuation on the stack
    case Left(cont) =>
      // run the continuation on the input value
      cont(inputAndConts.first) match {
        // nothing to add to the stack; we are done
        case y: NoCont[F, A] => foldMapNoCont(y)(f2g).map(Right(_))
        // new continuations to add to the stack
        case FlatMapped(init, conts) =>
          foldMapNoCont(init)(f2g).map(x => Left(makeInputAndConts(x, conts.reverse)))
      }
    // multiple continuations on the stack
    case Right(conts) =>
      // run the first continuation on the input value
      conts.head(inputAndConts.first) match {
        // the result doesn't add any continuations to the stack
        case y: NoCont[F, conts.P] =>
          foldMapNoCont(y)(f2g).map(x => Left(makeInputAndConts(x, conts.tail)))
        // the result adds continuations on the stack
        case FlatMapped(init, additionalConts) =>
          foldMapNoCont(init)(f2g).map(x =>
            Left(makeInputAndConts(x, conts.tail.prependReversed(additionalConts))))
      }
  }
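The step logic can be traced without the type-aligned machinery. The following type-erased sketch makes simplifying assumptions. It drops F entirely, so init is always a Pure value and foldMapNoCont reduces to reading .value. It also keeps the stack as a plain List in execution order. It returns the same Left/Right shape as foldMapStep. All names are hypothetical.

```scala
object StepSketch {
  // Type-erased stand-ins (assumption): no F effects and no type alignment
  sealed trait Free
  final case class Pure(value: Any) extends Free
  final case class FlatMapped(init: Pure, revConts: List[Any => Free]) extends Free

  type Cont = Any => Free

  // One step: pop a continuation, run it, push any continuations it returns,
  // then report Left((next input, remaining stack)) or Right(final value)
  def step(input: Any, stack: List[Cont]): Either[(Any, List[Cont]), Any] =
    stack match {
      case Nil => Right(input)
      case cont :: rest =>
        cont(input) match {
          case Pure(v) =>
            if (rest.isEmpty) Right(v) else Left((v, rest))
          case FlatMapped(init, more) =>
            // more is newest-first, so reverse it before prepending to the stack
            Left((init.value, more.reverse ::: rest))
        }
    }
}
```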
Finally, we can define foldMap:
def foldMap[F[_], G[_], A](fa: FreeMonad[F, A])(f2g: FunctionK[F, G])(implicit
    G: Monad[G]): G[A] =
  fa match {
    // simple case of no continuations; we are done!
    case x: NoCont[F, A] => foldMapNoCont(x)(f2g)
    // continuations exist; we'll need to step through the stack of continuations one at a time
    case FlatMapped(init, conts) =>
      // calculate the initial value and then repeatedly `step` through the continuation stack
      foldMapNoCont(init)(f2g).flatMap(v =>
        G.tailRecM(makeInputAndConts(v, conts.reverse))(foldMapStep(_)(f2g)))
  }
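Stack safety here rests on G.tailRecM. It keeps calling foldMapStep while the result is a Left and stops at the first Right. For G = Id, as in the trampoline example, that contract amounts to a plain loop. The sketch below is a hypothetical standalone tailRecM for the identity case, not the Cats implementation.

```scala
import scala.annotation.tailrec

object TailRecMForId {
  // Hypothetical identity-case tailRecM: apply f until it returns Right
  @tailrec
  def tailRecM[A, B](a: A)(f: A => Either[A, B]): B =
    f(a) match {
      case Left(next) => tailRecM(next)(f)
      case Right(b)   => b
    }

  // One million steps; a non-tail-recursive version would likely overflow the stack
  val countdown: String =
    tailRecM[Int, String](1000000)(n => if (n == 0) Right("done") else Left(n - 1))
}
```

The recursive call is in tail position, and the @tailrec annotation makes compilation fail if scalac cannot turn it into a loop. The million steps therefore run in constant stack space.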
using foldMap
To convince ourselves that our foldMap implementation is correct and stack-safe, we can use a classic example. We use Function0 as the F[_] type to create a trampoline, and compute a value via two mutually recursive functions that would not be stack-safe without trampolining.
type Trampoline[A] = FreeMonad[Function0, A]
def defer[A](delayed: () => Trampoline[A]): Trampoline[A] =
  FreeMonad.pure[Function0, Unit](()).flatMap(_ => delayed())
def even(n: Int): Trampoline[Boolean] =
  if (n eqv 0) FreeMonad.pure(true) else defer(() => odd(n - 1))
def odd(n: Int): Trampoline[Boolean] =
  if (n eqv 0) FreeMonad.pure(false) else defer(() => even(n - 1))
val evaluateFunction0: FunctionK[Function0, Id] = new FunctionK[Function0, Id] {
  def apply[A](value: Function0[A]): A = value()
}
def evalTrampoline[A](value: Trampoline[A]): A = FreeMonad.foldMap(value)(evaluateFunction0)
assert(evalTrampoline(even(100000)))
assert(!evalTrampoline(even(100001)))
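Readers without the type-aligned library at hand can reproduce the same experiment with a type-erased model. This sketch is a simplification, not the implementation above. F is fixed to Function0, continuations are untyped Any => Free values in an ordinary List, and the tailRecM-driven loop becomes a @tailrec method. It keeps the same overall structure: a NoCont init, reversed continuations, and an explicit stack that is consumed iteratively. All names are hypothetical.

```scala
import scala.annotation.tailrec

object ErasedTrampoline {
  // Type-erased stand-ins (assumption): F = Function0, continuations are Any => Free
  sealed trait Free
  sealed trait NoCont extends Free
  final case class Pure(value: Any) extends NoCont
  final case class Suspend(thunk: () => Any) extends NoCont
  final case class FlatMapped(init: NoCont, revConts: List[Any => Free]) extends Free

  def flatMap(fa: Free)(f: Any => Free): Free = fa match {
    case x: NoCont            => FlatMapped(x, List(f))
    case FlatMapped(init, cs) => FlatMapped(init, f :: cs)
  }

  // Plays the role of foldMapNoCont with the Function0 ~> Id interpreter
  def evalNoCont(fa: NoCont): Any = fa match {
    case Pure(v)    => v
    case Suspend(t) => t()
  }

  // Explicit stack of continuations in execution order; each iteration is one "step"
  @tailrec
  def loop(input: Any, stack: List[Any => Free]): Any = stack match {
    case Nil => input
    case cont :: rest =>
      cont(input) match {
        case x: NoCont              => loop(evalNoCont(x), rest)
        case FlatMapped(init, more) => loop(evalNoCont(init), more.reverse ::: rest)
      }
  }

  def run(fa: Free): Any = fa match {
    case x: NoCont            => evalNoCont(x)
    case FlatMapped(init, cs) => loop(evalNoCont(init), cs.reverse)
  }

  def defer(delayed: () => Free): Free = flatMap(Pure(()))(_ => delayed())
  def even(n: Int): Free = if (n == 0) Pure(true) else defer(() => odd(n - 1))
  def odd(n: Int): Free = if (n == 0) Pure(false) else defer(() => even(n - 1))
}
```

Calling a continuation only builds the next layer and never recurses into it, so the depth of the even/odd chain does not grow the call stack.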