[Haskell-cafe] Hit a wall with the type system

Chris Smith cdsmith at twu.net
Wed Nov 28 23:02:20 EST 2007


I was talking to a few people about this on #haskell, and it was 
suggested I ask here.  I should say that I'm playing around here; don't 
mistake this for an urgent request or a serious problem.

Suppose I wanted to implement automatic differentiation of simple 
functions on real numbers.  I'd take the operations from Num, 
Fractional, and Floating, define how to perform them on pairs of 
values and their differentials, and then write a differentiate 
function.  Finding an appropriate type for that function, however, 
seems to be a challenge.
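
To spell out the idea behind those pairs: a value x with differential e 
can be read as x + e eps, where eps squared is taken to be zero.  A 
sketch of the resulting rules, in the usual dual-number notation (the 
symbol g stands for any differentiable function and is only there for 
illustration):

```latex
\begin{align*}
(x + e\,\varepsilon) + (y + f\,\varepsilon) &= (x + y) + (e + f)\,\varepsilon \\
(x + e\,\varepsilon)\,(y + f\,\varepsilon) &= x y + (e y + x f)\,\varepsilon
    \qquad \text{since } \varepsilon^2 = 0 \\
g(x + e\,\varepsilon) &= g(x) + e\,g'(x)\,\varepsilon
\end{align*}
```

Feeding x + 1 eps to a function and reading off the eps coefficient of 
the result then gives the derivative at x.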

We have:

1. Differentiating a function of the most general type (Num a => a -> a) 
should produce a result of type (Num a => a -> a).

2. Differentiating a function of the more specific type (Fractional a => 
a -> a) should produce a result of that type (Fractional a => a -> a).

3. Differentiating a function of the most specific type (Floating a => a 
-> a) should produce a result of type (Floating a => a -> a).

4. BUT, differentiating a function which is of a more specific type than 
(Floating a => a -> a) is not, in general, possible.

So differentiate should have type A a => (forall b. A b => b -> b) -> a 
-> a, but ONLY if the type class A is Floating or one of its numeric 
superclasses (Num or Fractional).
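
A rough sketch of how a type of that shape might be written, assuming a 
GHC that supports the ConstraintKinds, AllowAmbiguousTypes, and 
TypeApplications extensions.  The name diffVia and the trimmed-down AD 
type are hypothetical and only for illustration; the zero checks on abs 
and signum are left out so that no Eq constraint is needed.  The 
constraint c (AD b) is what limits c to classes that AD has an instance 
of.

```haskell
{-# LANGUAGE RankNTypes, ConstraintKinds, KindSignatures,
             FlexibleContexts, AllowAmbiguousTypes, TypeApplications #-}

import Data.Kind (Constraint, Type)

-- Trimmed dual number: a value paired with its differential.
data AD a = AD a a

instance Num a => Num (AD a) where
    AD x e + AD y f = AD (x + y) (e + f)
    AD x e - AD y f = AD (x - y) (e - f)
    AD x e * AD y f = AD (x * y) (e * y + x * f)
    negate (AD x e) = AD (negate x) (negate e)
    abs (AD x e)    = AD (abs x) (e * signum x)   -- no check at 0 here
    signum (AD x _) = AD (signum x) 0             -- no check at 0 here
    fromInteger i   = AD (fromInteger i) 0

instance Fractional a => Fractional (AD a) where
    AD x e / AD y f = AD (x / y) ((e * y - x * f) / (y * y))
    fromRational r  = AD (fromRational r) 0

-- Hypothetical class-polymorphic differentiate.  The class c is chosen
-- by the caller with a type application; c (AD b) demands that the
-- class actually has an instance for dual numbers over b.
diffVia :: forall (c :: Type -> Constraint) b.
           (Num b, c (AD b))
        => (forall a. c a => a -> a) -> b -> b
diffVia f x = let AD _ dy = f (AD x 1) in dy
```

With this, diffVia @Num (\x -> x * x + 1) can be applied to an Integer, 
while diffVia @Fractional (\x -> 1 / x) needs a Fractional argument such 
as a Double.  The class has to be named explicitly at each call site, 
since nothing else determines c.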

Two partial solutions: I can define the differentiate function only 
for Floating, but then differentiating (\x -> x + 1) gives a function 
that works only on floating point numbers, which is less than 
desirable.  Or I can define several functions, say diffNum, 
diffFractional, and diffFloating, all of which have precisely the same 
implementation but different types, and which require copy/paste to 
make them work.

Any thoughts?

For reference, here's the code I kludged together.  (Again, I'm only 
playing around... so I wrote this very quickly and may have gotten some 
things wrong; don't use my code without checking it first!  In 
particular, I know that this code produces derivative functions whose 
domain is too large.)

> {-# LANGUAGE RankNTypes #-}
> 
> data AD a = AD a a deriving Eq
> 
> instance Show a => Show (AD a) where
>     show (AD x e) = show x ++ " + " ++ show e ++ " eps"
> 
> instance Num a => Num (AD a) where
>     (AD x e) + (AD y f)       = AD (x + y)         (e + f)
>     (AD x e) - (AD y f)       = AD (x - y)         (e - f)
>     (AD x e) * (AD y f)       = AD (x * y)         (e * y + x * f)
>     negate (AD x e)           = AD (negate x)      (negate e)
>     abs (AD 0 _)              = error "not differentiable: |0|"
>     abs (AD x e)              = AD (abs x)         (e * signum x)
>     signum (AD 0 _)           = error "not differentiable: signum(0)"
>     signum (AD x e)           = AD (signum x)      0
>     fromInteger i             = AD (fromInteger i) 0
> 
> instance Fractional a => Fractional (AD a) where
>     (AD x e) / (AD y f)       = AD (x / y) ((e * y - x * f) / (y * y))
>     recip (AD x e)            = AD (1 / x) ((-e) / (x * x))
>     fromRational x            = AD (fromRational x) 0
> 
> instance Floating a => Floating (AD a) where
>     pi                        = AD pi        0
>     exp (AD x e)              = AD (exp x)   (e * exp x)
>     sqrt (AD x e)             = AD (sqrt x)  (e / (2 * sqrt x))
>     log (AD x e)              = AD (log x)   (e / x)
>     (AD x e) ** (AD y f)      = AD (x ** y)  (e * y * (x ** (y-1)) +
>                                               f * (x ** y) * log x)
>     sin (AD x e)              = AD (sin x)   (e * cos x)
>     cos (AD x e)              = AD (cos x)   (-e * sin x)
>     asin (AD x e)             = AD (asin x)  (e / sqrt (1 - x ** 2))
>     acos (AD x e)             = AD (acos x)  (-e / sqrt (1 - x ** 2))
>     atan (AD x e)             = AD (atan x)  (e / (1 + x ** 2))
>     sinh (AD x e)             = AD (sinh x)  (e * cosh x)
>     cosh (AD x e)             = AD (cosh x)  (e * sinh x)
>     asinh (AD x e)            = AD (asinh x) (e / sqrt (x^2 + 1))
>     acosh (AD x e)            = AD (acosh x) (e / sqrt (x^2 - 1))
>     atanh (AD x e)            = AD (atanh x) (e / (1 - x^2))
> 
> diffNum        :: Num b        => (forall a. Num a        => a -> a) -> b -> b
> diffFractional :: Fractional b => (forall a. Fractional a => a -> a) -> b -> b
> diffFloating   :: Floating b   => (forall a. Floating a   => a -> a) -> b -> b
> 
> -- Seed the differential with 1 and read the derivative off the result.
> diffNum f x        = let AD _ dy = f (AD x 1) in dy
> diffFractional f x = let AD _ dy = f (AD x 1) in dy
> diffFloating f x   = let AD _ dy = f (AD x 1) in dy
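
As a small usage sketch, here is diffNum exercised at two different 
result types.  To keep the block self-contained it repeats only the Num 
instance, without the zero checks on abs and signum (so no Eq 
constraint is assumed); the names exampleInteger and exampleDouble are 
made up for illustration.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Trimmed dual number: a value paired with its differential.
data AD a = AD a a

instance Num a => Num (AD a) where
    AD x e + AD y f = AD (x + y) (e + f)
    AD x e - AD y f = AD (x - y) (e - f)
    AD x e * AD y f = AD (x * y) (e * y + x * f)
    negate (AD x e) = AD (negate x) (negate e)
    abs (AD x e)    = AD (abs x) (e * signum x)   -- no check at 0 here
    signum (AD x _) = AD (signum x) 0             -- no check at 0 here
    fromInteger i   = AD (fromInteger i) 0

diffNum :: Num b => (forall a. Num a => a -> a) -> b -> b
diffNum f x = let AD _ dy = f (AD x 1) in dy

-- d/dx (x + 1) = 1, and the result is usable at Integer.
exampleInteger :: Integer
exampleInteger = diffNum (\x -> x + 1) 5

-- d/dx (x * x * x) = 3 * x^2, evaluated at x = 2.
exampleDouble :: Double
exampleDouble = diffNum (\x -> x * x * x) 2
```

The same function argument style works at any Num type, which is the 
behaviour a Floating-only differentiate would lose.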

-- 
Chris Smith


