Simple Automatic Differentiation in Haskell

Tuesday, May 16, 2023

Haskell is, in my opinion, one of the most unique languages ever made. It was originally developed for teaching and research purposes, but it brought forth a number of now-foundational ideas such as type classes and monadic IO (shoutout Wikipedia). Basically, Haskell is super fun to use!

Standard disclaimer: This is an exploration of Haskell and automatic differentiation, certainly not code anyone would want to use in production.

What We Are Building
====================================================================================================================================================================================================================================================================================================================================

Automatic differentiation is the heart and soul powering libraries like Tensorflow, PyTorch, Jax, and almost any other deep learning library. During the forward step of the training cycle of a feed-forward neural network, the operations performed are stored on what is typically called a tape. At the end of the feed-forward stage, the tape is played back, or back-propagated, to get the gradients for the network. Today we will be implementing this storing of operations and back-propagating in a small Haskell program.

Defining Our Types
====================================================================================================================================================================================================================================================================================================================================

The goal of our program is to perform some mathematical computations and get the gradients of that computation. Most deep learning libraries use the abstraction of tensors, and we will be no different.

data (Fractional a, Eq a) => Tensor0D a = Tensor0D
  { tid :: Int,
    value :: a
  }
  deriving (Show, Eq)

Our tensor is incredibly simple: it has a tid and a value. Note that we allow any type that is an instance of both Fractional and Eq. We will also stay zero-dimensional; each tensor holds a single scalar value.
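One note before we move on: the datatype context on Tensor0D (and on the types below) is a deprecated GHC extension, so the snippets in this post need a little ceremony at the top of the file to compile. Something like the following module header is assumed throughout; the exact pragma and import list is my guess, since the post never shows it (Control.Monad.State comes from mtl and Data.Map from containers):

{-# LANGUAGE DatatypeContexts #-}

module Main where

import Control.Monad.State (State, get, put, runState, evalState)
import qualified Data.Map as Map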

Tensors should be able to perform operations, and we need a way to store those operations for use later.

data Operator = MP | DV | AD | NA
  deriving (Eq)

data (Fractional a, Eq a) => Operation a = Operation Operator (Tensor0D a) (Tensor0D a) (Tensor0D a)
  deriving (Eq)

data (Fractional a, Eq a) => Tape a = Tape
  { operations :: [Operation a],
    nextTensorId :: Int
  }

The above code gives a few more types that we use to store tensor operations. Notice that, as mentioned above, we have our own Tape just like most deep learning libraries. In our case the Tape also stores the nextTensorId. If we were writing in an imperative language like Rust, we would probably keep nextTensorId in an atomic counter, but pure Haskell code has no mutable global counter to reach for, so the Tape carries the nextTensorId along instead.

We will also write a helper function to make creating tensors easier.

createTensor :: (Fractional a, Eq a) => a -> State (Tape a) (Tensor0D a)
createTensor value = do
  tape <- get
  let tensorId = nextTensorId tape
  put $ tape {nextTensorId = tensorId + 1}
  return $ Tensor0D tensorId value 

Notice that we are working with the State monad. Using the State monad to wrap the tape is similar in spirit to Tensorflow's "with tf.GradientTape()". For instance, if we wanted to create a tensor in Tensorflow inside the context that records operations, it might look like the following:

with tf.GradientTape() as tape:
    newTensor = tf.constant(1)
    # Some series of operations that will be added to the tape
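On the Haskell side, "entering the tape context" just means running a State computation against some starting Tape. The main function at the end of this post uses a newTape value that is never shown; presumably it is simply an empty tape, something like this sketch:

newTape :: (Fractional a, Eq a) => Tape a
newTape = Tape {operations = [], nextTensorId = 0}

-- Create a single tensor "inside" the tape context and pull out its value
firstValue :: Double
firstValue = value $ evalState (createTensor 1) newTape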

Our Operations
====================================================================================================================================================================================================================================================================================================================================

We will pursue relative simplicity and only implement three operations:

tAdd :: (Fractional a, Eq a) => Tensor0D a -> Tensor0D a -> State (Tape a) (Tensor0D a)
tAdd t1@Tensor0D { tid = id1, value = value1 } t2@Tensor0D { tid = id2, value = value2 } = do
  tape <- get
  let tensorId = nextTensorId tape
  let ops = operations tape
  let newTensor = Tensor0D tensorId (value1 + value2)
  put $ tape {nextTensorId = tensorId + 1, operations = Operation AD newTensor t1 t2 : ops}
  return newTensor

tMul :: (Fractional a, Eq a) => Tensor0D a -> Tensor0D a -> State (Tape a) (Tensor0D a)
tMul t1@Tensor0D { tid = id1, value = value1 } t2@Tensor0D { tid = id2, value = value2 } = do
  tape <- get
  let tensorId = nextTensorId tape
  let ops = operations tape
  let newTensor = Tensor0D tensorId (value1 * value2)
  put $ tape {nextTensorId = tensorId + 1, operations = Operation MP newTensor t1 t2 : ops}
  return newTensor

tDiv :: (Fractional a, Eq a) => Tensor0D a -> Tensor0D a -> State (Tape a) (Tensor0D a)
tDiv t1@Tensor0D { tid = id1, value = value1 } t2@Tensor0D { tid = id2, value = value2 } = do
  tape <- get
  let tensorId = nextTensorId tape
  let ops = operations tape
  let newTensor = Tensor0D tensorId (value1 / value2)
  put $ tape {nextTensorId = tensorId + 1, operations = Operation DV newTensor t1 t2 : ops}
  return newTensor

Each operation performs the same process: read the tape from the State monad, grab the next tensor id, build a new tensor holding the result, push an Operation recording the new tensor and its two inputs onto the tape, and return the new tensor.
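As a quick sanity check that an operation both computes a value and records itself on the tape, here is a small example of my own (not from the original post):

opDemo :: (Tensor0D Double, Tape Double)
opDemo = runState ops newTape
  where
    ops = do
      a <- createTensor 2
      b <- createTensor 3
      tMul a b

-- fst opDemo is Tensor0D {tid = 2, value = 6.0},
-- and the tape in snd opDemo now holds exactly one Operation.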

Going Backwards
====================================================================================================================================================================================================================================================================================================================================

We have our Tape storing our Operations; now we need to go backwards through those operations to get our gradients. What does it mean to go backwards? Let's say we have the following Haskell code (this is valid code in the context of our program):

doComputations = do
  t0 <- createTensor 1
  t1 <- createTensor 2
  t2 <- createTensor 3
  t3 <- tMul t0 t1
  t4 <- tMul t3 t2
  return t4

Let's imagine computing the gradients for this by hand. We might choose to draw out a parse tree.

    *
   / \
  *  t2
 / \
t0 t1

With the values filled in (tensor | operation, value):

         (*, 6)
          / \
         /   \
        /     \
     (*, 2) (t2, 3)
      / \ 
     /   \      
    /     \
(t0, 1) (t1, 2)

If we start at the top, we can go backwards down the tree, filling in the derivatives (represented as b on each edge) and multiplying through the operations exactly as the chain rule teaches us. This is really just a different way to view the chain rule.

           (*, 6)
            / \
    (b, 3) /   \ (b, 2)
          /     \
       (*, 2) (t2, 3)
        / \ 
(b, 2) /   \ (b, 1)      
      /     \
  (t0, 1) (t1, 2)

To calculate the derivative for a tensor, we simply follow the path from the top of the tree down to that tensor, multiplying the b values together along the way.
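Working that out by hand for the tree above:

grad t0 = 3 * 2 = 6
grad t1 = 3 * 1 = 3
grad t2 = 2

which matches the partial derivatives of (t0 * t1) * t2: t1 * t2, t0 * t2, and t0 * t1 respectively.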

We can utilize this exact method to calculate the derivatives programmatically. Recall that Tape stores a list of Operations. We want to convert that list into a tree that follows the structure we wrote above, and then go backwards down the tree to get the derivatives.

Let's first build the tree.

data (Fractional a, Eq a) => TensorTree a = Empty | Cons (Tensor0D a) Operator (TensorTree a) (TensorTree a) deriving (Eq)

appendTree :: (Fractional a, Eq a) => Operation a -> TensorTree a -> TensorTree a
appendTree (Operation op t1 t2 t3) Empty = Cons t1 op (Cons t2 NA Empty Empty) (Cons t3 NA Empty Empty)
appendTree fullOp@(Operation op t1@Tensor0D { tid = opId } t2 t3) tree@(Cons treeTop@Tensor0D { tid = id } treeOp leftTree@(Cons Tensor0D { tid = leftId, value = leftValue } _ _ _) rightTree@(Cons Tensor0D { tid = rightId, value = rightValue } _ _ _))
  | opId == leftId = Cons treeTop treeOp (Cons t1 op (Cons t2 NA Empty Empty) (Cons t3 NA Empty Empty)) rightTree
  | opId == rightId = Cons treeTop treeOp leftTree (Cons t1 op (Cons t2 NA Empty Empty) (Cons t3 NA Empty Empty))
  | otherwise =
    let newLeftTree = appendTree fullOp leftTree
        newRightTree = appendTree fullOp rightTree
    in if newLeftTree /= leftTree
          then Cons treeTop treeOp newLeftTree rightTree
          else Cons treeTop treeOp leftTree newRightTree
appendTree _ tree@(Cons _ _ Empty Empty) = tree

buildTree :: (Fractional a, Eq a) => [Operation a] -> TensorTree a -> TensorTree a
buildTree (x:y) tree = buildTree y $ appendTree x tree
buildTree _ tree = tree

We introduced one new type, TensorTree: a recursive data structure that is either Empty or a node holding a tensor, the Operator that produced it, and left and right subtrees.

The function buildTree takes a list of Operations and a current TensorTree, and returns a new TensorTree. It folds the operations into the tree with appendTree, which finds the leaf whose tensor matches an operation's output and expands that leaf into a node whose children are the operation's two inputs. This relies on the operations list being newest-first (each operation was prepended to the tape), so the first operation handled becomes the root of the tree, i.e. the final result. The function itself is pretty boring and kind of gross; further exploration of this monstrosity doesn't feel necessary.
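To make the shape concrete: for the small (t0 * t1) * t2 example from earlier, buildTree produces a tree that looks roughly like this (written out by hand, with the tensor records abbreviated to their names):

Cons t4 MP
  (Cons t3 MP (Cons t0 NA Empty Empty) (Cons t1 NA Empty Empty))
  (Cons t2 NA Empty Empty)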

applyGrads :: (Fractional a) => Operator -> a -> a -> a -> (a, a)
applyGrads op parentGrads leftValue rightValue
  | op == MP = (parentGrads * rightValue, parentGrads * leftValue)
  | op == DV = (parentGrads * (1 / rightValue), parentGrads * (-1) * (leftValue / (rightValue * rightValue)))
  | op == AD = (parentGrads, parentGrads)

backTree :: (Fractional a, Eq a) => TensorTree a -> Map.Map Int a -> Map.Map Int a
backTree (Cons Tensor0D { tid = id } op leftTree@(Cons Tensor0D { tid = leftId, value = leftValue } _ _ _) rightTree@(Cons Tensor0D { tid = rightId, value = rightValue } _ _ _)) map =
  let pGrads = Map.findWithDefault 1 id map
      (leftGrads, rightGrads) = applyGrads op pGrads leftValue rightValue 
      leftMap = Map.delete id $ Map.insert leftId leftGrads map
      rightMap = Map.insert rightId rightGrads map
  in Map.unionWith (+) (backTree leftTree leftMap) (backTree rightTree rightMap)
backTree (Cons Tensor0D { tid = id } op Empty Empty) map = map

The function backTree takes a TensorTree and a map, and returns an updated map containing the gradients of the tensors in the TensorTree.

We have also created a helper function applyGrads, which takes an Operator, the gradients flowing down from the parent node, and the left and right values, and returns the gradients for the left and right operands of that operation.
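As a quick check of the division case (my own example, not from the post): the derivative of x / y is 1 / y with respect to x and -x / y² with respect to y, so in GHCi:

ghci> applyGrads DV 1.0 3.75 3.5
(0.2857142857142857,-0.30612244897959184)

These are exactly the values that will show up for tensors 4 and 2 in the final output below.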

Tying it all together
====================================================================================================================================================================================================================================================================================================================================

Let's augment our doComputations function to include more computations and return the final tensor along with the gradients. We will also write a helper function, aptly called backward, that builds the TensorTree from the tape and then goes backwards through it.

backward :: (Fractional a, Eq a) => State (Tape a) (Map.Map Int a)
backward = do
  tape <- get
  let ops = operations tape
  let tree = buildTree ops Empty
  return $ backTree tree Map.empty

doComputations :: (Fractional a, Eq a) => State (Tape a) (Tensor0D a, Map.Map Int a)
doComputations = do
  t0 <- createTensor 1.5
  t1 <- createTensor 2.5
  t2 <- createTensor 3.5
  t3 <- createTensor 4.5
  t4 <- tMul t0 t1
  t5 <- tDiv t4 t2
  t6 <- tAdd t5 t3
  grads <- backward
  return (t6, grads)

To execute this code, we include this very simple main function:

main :: IO ()
main = do
  let (tensor, grads) = evalState doComputations newTape 
  print tensor 
  print grads

Running the final program produces:

Tensor0D {tid = 6, value = 5.571428571428571}
fromList [(0,0.7142857142857142),(1,0.42857142857142855),(2,-0.30612244897959184),(3,1.0),(4,0.2857142857142857),(5,1.0)]

Checked against https://www.derivative-calculator.net/, these gradients are correct!
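For completeness, here is the hand check. The final tensor is f = (t0 * t1) / t2 + t3, so:

df/dt0 = t1 / t2            = 2.5 / 3.5     ≈  0.7142857
df/dt1 = t0 / t2            = 1.5 / 3.5     ≈  0.4285714
df/dt2 = -(t0 * t1) / t2^2  = -3.75 / 12.25 ≈ -0.3061224
df/dt3 = 1.0

which lines up with the entries for tensor ids 0 through 3 above (ids 4 and 5 are the gradients of the intermediate tensors t4 and t5).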

Thank you for reading!

the repo