Computing Symbolic Gradient Vectors with Plain Haskell
Posted April 27, 2016 ‐ 5 min read
While writing my previous post, I was curious how easy it would be to implement TensorFlow's automatic differentiation for back propagation. On TensorFlow's web site they call it 'automatic differentiation', but in fact they probably do 'symbolic differentiation'. The difference between the two is whether the differentiation is done during the original computation or beforehand. It makes sense to do the latter, because then you can maintain a separate computational graph of the back propagation to perform the updates.
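To make the "during the computation" side of that distinction concrete, here is a minimal sketch of forward-mode automatic differentiation with dual numbers. It is a different technique from the symbolic approach in this post, it is not a claim about how TensorFlow works internally, and the names `Dual`, `dualSin`, `f` and `dfdx1` are hypothetical:

```haskell
-- Forward-mode AD with dual numbers: a sketch, not TensorFlow's method.
-- A Dual holds a value together with its derivative w.r.t. one seeded input.
data Dual = Dual Double Double
  deriving Show

instance Num Dual where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' * Dual b b' = Dual (a * b) (a' * b + a * b')  -- product rule
  negate (Dual a a')    = Dual (negate a) (negate a')
  abs (Dual a a')       = Dual (abs a) (a' * signum a)
  signum (Dual a _)     = Dual (signum a) 0
  fromInteger n         = Dual (fromInteger n) 0           -- constants: zero derivative

-- Chain rule for sine: (sin u)' = cos u * u'
dualSin :: Dual -> Dual
dualSin (Dual a a') = Dual (sin a) (cos a * a')

-- f(x1, x2) = sin x1 + x1 * x2, evaluated on dual numbers
f :: Dual -> Dual -> Dual
f x1 x2 = dualSin x1 + x1 * x2

-- ∂f/∂x1 at (a, b): seed x1's derivative with 1 and x2's with 0.
-- The derivative comes out of the same pass that computes the value.
dfdx1 :: Double -> Double -> Double
dfdx1 a b = let Dual _ d = f (Dual a 1) (Dual b 0) in d
```

In GHCi, `dfdx1 0 2` evaluates to `3.0`, i.e. cos 0 + 2. No expression for the derivative is ever built: each evaluation point yields a number, which is exactly what a symbolic approach avoids by producing a reusable derivative expression up front.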
I've looked into this topic in the Haskell ecosystem and found many useful and extensive libraries, notably those by Edward Kmett. However, using these libraries, or understanding many of the blog posts on the subject, requires some advanced Haskell, and I was wondering whether one can get going with differentiation using only very basic, lean Haskell.
So, how easy would it be to compute gradients of single-output functions using Haskell, with only the basic arsenal at our hands?
Data and imports
First, we perform some imports and declare the basic data type to hold our expression tree:
import qualified Data.Map as Map
import Control.Monad (forM_)

data Expr = Term Int      -- 'Term 0' is x0, 'Term 1' is x1, etc.
          | Lit Float     -- Constant numbers
          | Neg Expr      -- -a
          | Mul Expr Expr -- a * b
          | Add Expr Expr -- a + b
          | Sin Expr      -- sin a
          | Cos Expr      -- cos a
We shall be able to derive symbolic gradients for any function built with this data type.
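Since each `Term i` stands for the variable $x_i$, an expression tree only gets a numeric value once we say what each $x_i$ is. As an illustration, a minimal evaluator could look like the following; the `eval` name and the `Map`-based environment are assumptions for illustration, not part of the original code:

```haskell
import qualified Data.Map as Map

data Expr = Term Int | Lit Float | Neg Expr | Mul Expr Expr
          | Add Expr Expr | Sin Expr | Cos Expr

-- Hypothetical helper: evaluate an expression, looking up the value of
-- each Term in an environment map (missing terms raise an error).
eval :: Map.Map Int Float -> Expr -> Float
eval env (Term i)    = Map.findWithDefault (error ("unbound x" ++ show i)) i env
eval _   (Lit v)     = v
eval env (Neg e)     = negate (eval env e)
eval env (Mul e1 e2) = eval env e1 * eval env e2
eval env (Add e1 e2) = eval env e1 + eval env e2
eval env (Sin e)     = sin (eval env e)
eval env (Cos e)     = cos (eval env e)
```

For example, `eval (Map.fromList [(1, 2), (2, 3)]) (Add (Term 1) (Mul (Term 1) (Term 2)))` gives `8.0`, i.e. $2 + 2 \cdot 3$.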
Poor man's pretty-printing
One cannot get by without a nice String representation:
fshow :: Expr -> String
fshow (Term v)    = concat ["x", show v]
fshow (Lit v)     = show v
fshow (Mul e1 e2) = concat ["(", fshow e1, " * ", fshow e2, ")"]
fshow (Add e1 e2) = concat ["(", fshow e1, " + ", fshow e2, ")"]
fshow (Neg e)     = concat ["-(", fshow e, ")"]
fshow (Sin e)     = concat ["sin(", fshow e, ")"]
fshow (Cos e)     = concat ["cos(", fshow e, ")"]
This implementation is basic, in that a sequence of summations gets an ugly, deeply nested representation such as (x1 + (x2 + (x3 + (...)))). However, it's enough to get us going.
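If the nested parentheses become a problem, one option is a precedence-aware printer that only parenthesizes when the surrounding context binds tighter. The sketch below is an assumption, not part of the original post; `pshow`, `paren` and the precedence numbers are hypothetical choices:

```haskell
data Expr = Term Int | Lit Float | Neg Expr | Mul Expr Expr
          | Add Expr Expr | Sin Expr | Cos Expr

-- Hypothetical precedence-aware printer. The Int is the precedence of the
-- surrounding context; a node is parenthesized only when that context binds
-- tighter. Because + and * are associative, chains of them print flat.
pshow :: Int -> Expr -> String
pshow _ (Term v)    = "x" ++ show v
pshow p (Lit v)     = paren (v < 0 && p > 6) (show v)  -- guard negative literals
pshow p (Add e1 e2) = paren (p > 6) (pshow 6 e1 ++ " + " ++ pshow 6 e2)
pshow p (Mul e1 e2) = paren (p > 7) (pshow 7 e1 ++ " * " ++ pshow 7 e2)
pshow p (Neg e)     = paren (p > 8) ("-" ++ pshow 9 e)
pshow _ (Sin e)     = "sin(" ++ pshow 0 e ++ ")"
pshow _ (Cos e)     = "cos(" ++ pshow 0 e ++ ")"

-- Wrap a string in parentheses when the flag is set.
paren :: Bool -> String -> String
paren True s  = "(" ++ s ++ ")"
paren False s = s
```

With it, `pshow 0 (Add (Term 1) (Add (Term 2) (Term 3)))` prints `x1 + x2 + x3` rather than `(x1 + (x2 + x3))`. The trade-off is that the printed string no longer reveals the exact shape of the tree, which the fully parenthesized version does.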
Wikipedia demonstrates automatic differentiation with the following function:
$$ f(x_1, x_2) = \sin x_1 + x_1x_2 $$
It should be easy enough to represent it with our Haskell data type, and use fshow from above:
λ> let wikipediaFunc = (Sin (Term 1)) `Add` ((Term 1) `Mul` (Term 2))
λ> fshow wikipediaFunc
"(sin(x1) + (x1 * x2))"
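For reference, differentiating by hand gives the gradient we expect the code to reproduce:

$$ \nabla f(x_1, x_2) = \left( \frac{\partial f}{\partial x_1}, \frac{\partial f}{\partial x_2} \right) = \left( \cos x_1 + x_2,\; x_1 \right) $$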
The gradient function below takes an expression and returns a map from each term number to the expression that computes the partial derivative with respect to that term. The function is recursive and based on well-known differentiation rules:
gradient :: Expr -> Map.Map Int Expr
gradient (Neg e)     = Map.map Neg (gradient e)
gradient (Cos e)     = Map.map (Mul (Neg (Sin e))) (gradient e)  -- chain rule
gradient (Sin e)     = Map.map (Mul (Cos e)) (gradient e)        -- chain rule
gradient (Term i)    = Map.fromList [(i, Lit 1.0)]
gradient (Lit _)     = Map.empty
gradient (Add e1 e2) = Map.unionWith Add (gradient e1) (gradient e2)
gradient (Mul e1 e2) = Map.unionWith Add (Map.map (Mul e2) (gradient e1))  -- product rule
                                         (Map.map (Mul e1) (gradient e2))
The interesting parts are where Map.unionWith is used for addition and multiplication. Notice how directly the Mul case mirrors the familiar product rule:
$$(f(x)g(x))' = g(x)f'(x) + g'(x)f(x)$$
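Written in gradient-vector form, the clauses of gradient correspond one-to-one to the following rules, where $e_i$ is the unit vector selecting $x_i$ and a missing map entry stands for a zero component:

$$ \nabla x_i = e_i, \qquad \nabla c = 0, \qquad \nabla(-f) = -\nabla f, \qquad \nabla(f + g) = \nabla f + \nabla g $$

$$ \nabla(fg) = g\,\nabla f + f\,\nabla g, \qquad \nabla \sin f = \cos f\,\nabla f, \qquad \nabla \cos f = -\sin f\,\nabla f $$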
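The documentation for Data.Map describes unionWith as a union with a combining function: keys present in only one map are kept as they are, and values for keys present in both maps are merged with the given function (here, Add).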
Before testing it, we'll add just two helper functions. The first one simplifies expressions by getting rid of the multiplications by the $1.0$ literals that gradient introduces for each term.
simplify :: Expr -> Expr
simplify (Mul (Lit 1.0) e) = simplify e
simplify (Mul e (Lit 1.0)) = simplify e
simplify (Add e1 e2)       = Add (simplify e1) (simplify e2)
simplify (Mul e1 e2)       = Mul (simplify e1) (simplify e2)
simplify e                 = e
The second function will do all the work at the program's top level to compute the gradient and print it:
showGradient :: Expr -> IO ()
showGradient func = do
  putStrLn $ "f(..) = " ++ fshow func
  -- Print each partial derivative, ordered by term number
  forM_ (Map.toList $ gradient func) $ \(k, v) -> do
    putStrLn $ "∂f / ∂" ++ fshow (Term k) ++ " = " ++ (fshow . simplify) v
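Note that simplify only looks through Add and Mul nodes, so a $1.0$ literal under a Neg, Sin or Cos node survives. For instance, gradient (Neg (Sin (Term 1))) contains Neg (Mul (Cos (Term 1)) (Lit 1.0)), which simplify leaves untouched. If that matters, a bottom-up variant is one option. The sketch below is an assumption for illustration: the names `simplifyDeep` and `rewrite` are hypothetical, and the extra identities (multiplication by zero, addition of zero, double negation, constant folding) go beyond the original helper. Rewriting $x \cdot 0$ to $0$ also ignores floating-point special cases such as NaN and infinity.

```haskell
data Expr = Term Int | Lit Float | Neg Expr | Mul Expr Expr
          | Add Expr Expr | Sin Expr | Cos Expr
  deriving (Show, Eq)

-- Hypothetical bottom-up simplifier: simplify the children first,
-- then apply one rewrite at the current node.
simplifyDeep :: Expr -> Expr
simplifyDeep (Neg e)     = rewrite (Neg (simplifyDeep e))
simplifyDeep (Sin e)     = Sin (simplifyDeep e)
simplifyDeep (Cos e)     = Cos (simplifyDeep e)
simplifyDeep (Add e1 e2) = rewrite (Add (simplifyDeep e1) (simplifyDeep e2))
simplifyDeep (Mul e1 e2) = rewrite (Mul (simplifyDeep e1) (simplifyDeep e2))
simplifyDeep e           = e

-- Algebraic identities applied at the root, assuming the children
-- are already simplified.
rewrite :: Expr -> Expr
rewrite (Mul (Lit 1) e)       = e
rewrite (Mul e (Lit 1))       = e
rewrite (Mul (Lit 0) _)       = Lit 0
rewrite (Mul _ (Lit 0))       = Lit 0
rewrite (Add (Lit 0) e)       = e
rewrite (Add e (Lit 0))       = e
rewrite (Neg (Neg e))         = e
rewrite (Mul (Lit a) (Lit b)) = Lit (a * b)
rewrite (Add (Lit a) (Lit b)) = Lit (a + b)
rewrite e                     = e
```

With this variant, simplifyDeep (Neg (Mul (Cos (Term 1)) (Lit 1))) yields Neg (Cos (Term 1)).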
Does it work?
> showGradient wikipediaFunc
f(..) = (sin(x1) + (x1 * x2))
∂f / ∂x1 = (cos(x1) + x2)
∂f / ∂x2 = x1
Looks like it does: we have arrived at the same results as Wikipedia.
Will it work with something more complex?
> showGradient $ (Sin (Mul (Term 2 `Add` Lit 5.1) $ Cos (Term 1))) `Mul` (Term 1) `Mul` (Term 3)
f(..) = ((sin(((x2 + 5.1) * cos(x1))) * x1) * x3)
∂f / ∂x1 = (x3 * ((x1 * (cos(((x2 + 5.1) * cos(x1))) * ((x2 + 5.1) * -(sin(x1))))) + sin(((x2 + 5.1) * cos(x1)))))
∂f / ∂x2 = (x3 * (x1 * (cos(((x2 + 5.1) * cos(x1))) * cos(x1))))
∂f / ∂x3 = (sin(((x2 + 5.1) * cos(x1))) * x1)
Comparing with an independent calculation, it seems to get it right.
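Another way to gain confidence, independent of any hand derivation, is to compare each symbolic partial derivative with a central finite-difference estimate $\left(f(x + h e_i) - f(x - h e_i)\right) / 2h$. The sketch below repeats gradient so it stands alone, and assumes a hypothetical `eval` helper that substitutes values for the terms (computing in Double). The names `numericPartial` and `checkGradient`, the step $h = 10^{-5}$ and the tolerance are arbitrary choices for illustration:

```haskell
import qualified Data.Map as Map

data Expr = Term Int | Lit Float | Neg Expr | Mul Expr Expr
          | Add Expr Expr | Sin Expr | Cos Expr

-- gradient, repeated from the post so this block is self-contained
gradient :: Expr -> Map.Map Int Expr
gradient (Neg e)     = Map.map Neg (gradient e)
gradient (Cos e)     = Map.map (Mul (Neg (Sin e))) (gradient e)
gradient (Sin e)     = Map.map (Mul (Cos e)) (gradient e)
gradient (Term i)    = Map.fromList [(i, Lit 1.0)]
gradient (Lit _)     = Map.empty
gradient (Add e1 e2) = Map.unionWith Add (gradient e1) (gradient e2)
gradient (Mul e1 e2) = Map.unionWith Add (Map.map (Mul e2) (gradient e1))
                                         (Map.map (Mul e1) (gradient e2))

-- Hypothetical evaluator: substitute values for the terms, computing in Double.
eval :: Map.Map Int Double -> Expr -> Double
eval env (Term i)    = Map.findWithDefault (error ("unbound x" ++ show i)) i env
eval _   (Lit v)     = realToFrac v
eval env (Neg e)     = negate (eval env e)
eval env (Mul e1 e2) = eval env e1 * eval env e2
eval env (Add e1 e2) = eval env e1 + eval env e2
eval env (Sin e)     = sin (eval env e)
eval env (Cos e)     = cos (eval env e)

-- Central-difference estimate of ∂f/∂x_i at the point env, with step h.
numericPartial :: Double -> Expr -> Map.Map Int Double -> Int -> Double
numericPartial h f env i = (eval (bump h) f - eval (bump (-h)) f) / (2 * h)
  where bump d = Map.adjust (+ d) i env

-- True when every symbolic partial agrees with its numeric estimate within tol.
checkGradient :: Double -> Expr -> Map.Map Int Double -> Bool
checkGradient tol f env = and
  [ abs (eval env df - numericPartial 1e-5 f env i) <= tol
  | (i, df) <- Map.toList (gradient f) ]
```

Evaluating checkGradient at a few random points is a cheap smoke test. It cannot prove the symbolic rules correct, but a sign error in one of the gradient clauses will usually show up as a mismatch.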
Advanced extensions of what I illustrated here can add a considerable amount of functionality and ease of use. For instance, we would definitely need to support matrices if we wanted to derive a back-propagation graph. You can browse Edward Kmett's libraries mentioned above to get some ideas.