Simple Linear Regression in One Pass
Previous: The IO Rosetta Stone
Next: Proving My Point
I recently had to implement linear regression for a hobby project. As I’d never learned how to calculate it, this was a great opportunity to learn. As I dug in, I saw that I’d need to be making two passes over my data. Let’s see if we can fix that.
The Problem
Simple linear regression is where we take a set of points in the plane and compute the slope and \(y\)-intercept of the line that most closely fits the points.
Precisely, if our data is \( \left\{ (x_i, y_i) \vert i \in I \right\} \), then we are looking for the pair \(m\), \(b\) which minimize the cost
\[ \sum_{i \in I} \left( m x_i + b - y_i \right)^2 \text{,} \]
where \(m\) is the slope of our model, \(b\) is the \(y\)-intercept of our model, and \(y = mx + b\) is the value our model predicts for arbitrary input \(x\).
There are many approaches to this problem, ranging from multivariable calculus to geometry to linear algebra. We will use a statistical approach.
(Stating without proof) We have the following:
\[ m = r \frac{s_y}{s_x} \]
and
\[ b = \bar{y} - m \bar{x} \]
where
- \(\bar{x}\), \(\bar{y}\) are the means of \(x_i\), \(y_i\),
- \(s_x\), \(s_y\) are the standard deviations of \(x_i\), \(y_i\), and
- \(r\) is Pearson’s correlation coefficient, defined as
\[ r = \frac{ \sum_i \left( (x_i - \bar{x}) (y_i - \bar{y}) \right) }{ \sqrt{ \sum_i (x_i - \bar{x})^2 \sum_i (y_i - \bar{y})^2 } }\text{.} \]
(This useful video walks you through an example by hand.)
Recall the formula for standard deviation \(s_x\) (and analogously for \(s_y\)),
\[ s_x = \sqrt{\frac{\sum_i (x_i - \bar{x})^2}{n - 1}} \text{,} \]
where \(n\) is the size of the sample \(I\).
First, let’s simplify the formula for \(m\).
\[ m = r \frac{s_y}{s_x} = r \frac{ \sqrt{\frac{\sum_i (y_i - \bar{y})^2}{n - 1}} }{ \sqrt{\frac{\sum_i (x_i - \bar{x})^2}{n - 1}} } = r \sqrt{\frac{\sum_i (y_i - \bar{y})^2}{\sum_i (x_i - \bar{x})^2}} \]
\[ = \frac{ \sum_i \left( (x_i - \bar{x}) (y_i - \bar{y}) \right) }{ \sqrt{ \sum_i (x_i - \bar{x})^2 \sum_i (y_i - \bar{y})^2 } } \sqrt{\frac{\sum_i (y_i - \bar{y})^2}{\sum_i (x_i - \bar{x})^2}} \]
\[ = \sum_i \left( (x_i - \bar{x}) (y_i - \bar{y}) \right) \sqrt{\frac{ \sum_i (y_i - \bar{y})^2 }{ \sum_i (x_i - \bar{x})^2 \sum_i (y_i - \bar{y})^2 \sum_i (x_i - \bar{x})^2 }} \]
\[ = \sum_i \left( (x_i - \bar{x}) (y_i - \bar{y}) \right) \sqrt{\frac{1}{ \left( \sum_i (x_i - \bar{x})^2 \right)^2 }} \]
\[ = \frac{ \sum_i \left( (x_i - \bar{x}) (y_i - \bar{y}) \right) }{ \sum_i (x_i - \bar{x})^2 } \]
This final form is much nicer. To work with it, we no longer need to know the standard deviations \(s_x\), \(s_y\), though we still need to know the means \(\bar{x}\), \(\bar{y}\).
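As a quick sanity check of this formula, take three points that lie exactly on the line \(y = 2x + 1\) (sample values chosen here purely for illustration): \((0, 1)\), \((1, 3)\), \((2, 5)\). Then \(\bar{x} = 1\) and \(\bar{y} = 3\), so
\[ m = \frac{(-1)(-2) + (0)(0) + (1)(2)}{(-1)^2 + 0^2 + 1^2} = \frac{4}{2} = 2 \]
and
\[ b = \bar{y} - m \bar{x} = 3 - 2 \cdot 1 = 1 \text{,} \]
which recovers the line we started with.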
Two-pass Implementation
The formulas for \(m\) and \(b\) above give us a straightforward way to implement simple linear regression:
- Find `avg_x` and `avg_y`,
- Use `avg_x` and `avg_y` in the above formula for `slope` \(m\), and
- Use `avg_x`, `avg_y`, and `slope` in the above formula for `intercept` \(b\).
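A direct transcription of these three steps might look like the sketch below. It is the same function body used in the naive implementation in the Analysis section; counting the traversals gives two `sum`s and two `length`s for the averages, plus one pass each for `xys` and `xxs`.

```haskell
-- Step 1: averages. Step 2: slope. Step 3: intercept.
simpleLinearRegression :: Fractional a => [(a, a)] -> (a, a)
simpleLinearRegression points =
  (slope, intercept)
  where
    -- Each sum and each length is its own traversal of points.
    avg_x = sum [x | (x, _) <- points] / fromIntegral (length points)
    avg_y = sum [y | (_, y) <- points] / fromIntegral (length points)
    -- Numerator and denominator of the simplified formula for m.
    xys = sum [(x - avg_x) * (y - avg_y) | (x, y) <- points]
    xxs = sum [(x - avg_x) * (x - avg_x) | (x, _) <- points]
    slope = xys / xxs
    intercept = avg_y - slope * avg_x
```

Loaded into GHCi, `simpleLinearRegression [(0, 1), (1, 3), (2, 5)]` should return the slope and intercept of the line \(y = 2x + 1\).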
This works, but it makes six passes over the list `points`. If this list is very long, then each pass will take a non-trivial amount of time. We can refactor this using monoids and `foldMap` to reduce the number of list passes we need to make.
A type `a` is a monoid if it has an associative binary operation `<>` with an identity element `mempty`.
`foldMap` takes a list `[x]` and a mapping function `x -> a` and, if `a` is a monoid, applies the function to the list elements and combines the results using `<>`. Concretely, `foldMap f [x1, x2, x3] == f x1 <> f x2 <> f x3`. `foldMap` of an empty list gives `mempty`.
We’ll use the monoids `Sum a` and `(a, b)` in our refactor.
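To make these two monoids concrete, here is a small sketch. The function names `total` and `sumAndSumOfSquares` are made up for this illustration; `Sum` comes from `Data.Monoid`, and the pair instance combines componentwise.

```haskell
import Data.Monoid (Sum(..))

-- Sum wraps a number so that <> is addition and mempty is Sum 0.
total :: [Int] -> Int
total = getSum . foldMap Sum

-- A pair of monoids is a monoid: <> combines each component separately,
-- so one foldMap can accumulate two results at once.
sumAndSumOfSquares :: [Int] -> (Int, Int)
sumAndSumOfSquares xs = (s, q)
  where
    (Sum s, Sum q) = foldMap (\x -> (Sum x, Sum (x * x))) xs
```

For example, `total [1, 2, 3]` is `6`, `total []` is `0`, and `sumAndSumOfSquares [1, 2, 3]` is `(6, 14)`.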
Armed with these components, let’s take a closer look at calculating `avg_x` and `avg_y`. We’re currently passing over the list four times to calculate them.
We can get this down to three passes if we share the length.
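For reference, the four-pass starting point and the three-pass version that shares the length could be sketched like this. The names `averages4` and `averages3` are hypothetical, introduced only to put the two variants side by side.

```haskell
-- Four passes: two sums and two lengths.
averages4 :: Fractional a => [(a, a)] -> (a, a)
averages4 points = (avg_x, avg_y)
  where
    avg_x = sum [x | (x, _) <- points] / fromIntegral (length points)
    avg_y = sum [y | (_, y) <- points] / fromIntegral (length points)

-- Three passes: the length is computed once and shared by both averages.
averages3 :: Fractional a => [(a, a)] -> (a, a)
averages3 points = (avg_x, avg_y)
  where
    n = fromIntegral (length points)
    avg_x = sum [x | (x, _) <- points] / n
    avg_y = sum [y | (_, y) <- points] / n
```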
Next, we can use the `Sum a` and `(a, b)` monoids to get both sums in one pass.
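A sketch of that step, assuming the same list-of-pairs input as before (the name `averages2` is hypothetical): one `foldMap` produces both sums, and `length` accounts for the second pass.

```haskell
import Data.Monoid (Sum(..))

-- Two passes: one foldMap for both sums, one length.
averages2 :: Fractional a => [(a, a)] -> (a, a)
averages2 points = (xs / n, ys / n)
  where
    n = fromIntegral (length points)
    -- The pair monoid adds the x's and the y's side by side.
    (Sum xs, Sum ys) = foldMap (\(x, y) -> (Sum x, Sum y)) points
```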
We’re down to two passes, but we can get the length at the same time we get the sums with this one neat trick: count the number of list elements by packing a `Sum 1` into the result of our mapping function.
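In code, the trick might look like the following sketch. The mapping function is called `avg` to match the name used later in the text; `averages1` is a hypothetical wrapper added for illustration.

```haskell
import Data.Monoid (Sum(..))

-- Mapping function: a count of 1 alongside the two coordinates.
avg :: Num a => (a, a) -> (Sum a, Sum a, Sum a)
avg (x, y) = (Sum 1, Sum x, Sum y)

-- One pass: the count, the sum of x's, and the sum of y's together.
averages1 :: Fractional a => [(a, a)] -> (a, a)
averages1 points = (xs / n, ys / n)
  where
    (Sum n, Sum xs, Sum ys) = foldMap avg points
```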
This is a bit tricky, so let’s look at an example to see how it works:
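Here is one such example, traced by hand in the comments. The three sample points are made up for illustration, and `avg` is assumed to be the mapping function described above.

```haskell
import Data.Monoid (Sum(..))

avg :: Num a => (a, a) -> (Sum a, Sum a, Sum a)
avg (x, y) = (Sum 1, Sum x, Sum y)

-- foldMap avg [(1, 10), (2, 20), (3, 30)]
--   == avg (1, 10) <> avg (2, 20) <> avg (3, 30)
--   == (Sum 1, Sum 1, Sum 10) <> (Sum 1, Sum 2, Sum 20) <> (Sum 1, Sum 3, Sum 30)
--   == (Sum (1 + 1 + 1), Sum (1 + 2 + 3), Sum (10 + 20 + 30))
--   == (Sum 3, Sum 6, Sum 60)
example :: (Sum Double, Sum Double, Sum Double)
example = foldMap avg [(1, 10), (2, 20), (3, 30)]
```

The first component is the number of points, so the averages are `6 / 3` and `60 / 3`.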
So we’ve gotten both averages in just one pass! Let’s refactor the rest of the program.
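Putting the pieces together, the two-pass refactor might look roughly like this sketch. It reuses the names `avg` and `reg` from the text; the exact layout is an assumption.

```haskell
import Data.Monoid (Sum(..))

simpleLinearRegression :: Fractional a => [(a, a)] -> (a, a)
simpleLinearRegression points =
  (slope, intercept)
  where
    -- First pass: count and coordinate sums.
    avg (x, y) = (Sum 1, Sum x, Sum y)
    (Sum n, Sum xs, Sum ys) = foldMap avg points
    avg_x = xs / n
    avg_y = ys / n
    -- Second pass: needs avg_x and avg_y from the first pass.
    reg (x, y) = (Sum $ x' * y', Sum $ x' * x')
      where
        x' = x - avg_x
        y' = y - avg_y
    (Sum xys, Sum xxs) = foldMap reg points
    slope = xys / xxs
    intercept = avg_y - slope * avg_x
```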
We make two passes over our list, one using the mapping function `avg` and one using the mapping function `reg`. This is already a huge improvement over our first draft.
Monoid-valued Functions
At first glance, it seems like two passes is the best we can hope for, since `reg` needs to use the results `avg_x` and `avg_y` of the first pass. Fortunately, as programmers we have a one-size-fits-all tool for abstracting away information that we don’t yet have: functions. We can make `reg` a function that gets `avg_x` and `avg_y` as inputs.
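As a standalone sketch, this version of `reg` takes the point first and the pair of averages second, so that `reg p` is a function still waiting for the averages. The body matches the `reg` used in the one-pass listing in the Analysis section.

```haskell
import Data.Monoid (Sum(..))

-- reg p is a function from the (not yet known) averages to the two
-- summands this point contributes to xys and xxs.
reg :: Num a => (a, a) -> (a, a) -> (Sum a, Sum a)
reg (x, y) (avg_x, avg_y) =
  (Sum $ x' * y', Sum $ x' * x')
  where
    x' = x - avg_x
    y' = y - avg_y
```

For instance, `reg (2, 5) (1, 3)` gives `(Sum 2, Sum 1)`, since the deviations are `1` and `2`.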
I assert that this solves all of our problems, but the astute reader might now notice a subtlety. `reg` has the special requirement that it needs to map our list elements into a monoid. In its original form we have `reg :: (a, a) -> (Sum a, Sum a)`, which does indeed map a list element `(a, a)` into the monoid `(Sum a, Sum a)`. Our new version of `reg`, though, has a different signature. We now have `reg :: (a, a) -> ( (a, a) -> (Sum a, Sum a) )`, which maps a list element `(a, a)` into the function type `(a, a) -> (Sum a, Sum a)`. It turns out that the type `(a, a) -> (Sum a, Sum a)` is a perfectly reasonable monoid already.
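`base` already ships this instance for plain functions, so redefining it verbatim would not compile. The sketch below shows its shape in a comment and then restates it on a hypothetical newtype wrapper `Fn` (not part of any library) so that it can be loaded and tried out.

```haskell
-- In base, the instances for plain functions read roughly:
--
--   instance Semigroup a => Semigroup (x -> a) where
--     f <> g = \x -> f x <> g x
--
--   instance Monoid a => Monoid (x -> a) where
--     mempty = const mempty

-- Hypothetical wrapper with the same behaviour, for experimentation only.
newtype Fn x a = Fn { runFn :: x -> a }

instance Semigroup a => Semigroup (Fn x a) where
  -- Feed the same input to both functions and combine their outputs.
  Fn f <> Fn g = Fn (\x -> f x <> g x)

instance Monoid a => Monoid (Fn x a) where
  -- The identity ignores its input and returns the target's mempty.
  mempty = Fn (const mempty)
```

With the built-in instance, `(show <> const "!") 5` evaluates to `"5!"`: both functions receive `5`, and their `String` results are concatenated.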
The instance above says that any function type where the target type `a` is a monoid is itself a monoid in a formulaic way and with no conditions on the source type `x`. In other words, monoid-valued functions always form a monoid. We just need to make sure that our definition of `<>` is associative and that `const mempty` is an identity for `<>`.
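The equational reasoning can be laid out pointwise, at an arbitrary input `x`, as in the comments below. The three helper functions (`assocAt`, `rightIdAt`, `leftIdAt`) are hypothetical names for spot checks at a single input; they illustrate the laws but do not prove them.

```haskell
-- Associativity:
--   (f <> (g <> h)) x
--     == f x <> (g <> h) x
--     == f x <> (g x <> h x)
--     == (f x <> g x) <> h x      -- <> on the target type a is associative
--     == (f <> g) x <> h x
--     == ((f <> g) <> h) x
--
-- Right identity:
--   (f <> const mempty) x
--     == f x <> const mempty x
--     == f x <> mempty
--     == f x
--
-- Left identity:
--   (const mempty <> f) x
--     == const mempty x <> f x
--     == mempty <> f x
--     == f x

assocAt :: (Semigroup a, Eq a) => (x -> a) -> (x -> a) -> (x -> a) -> x -> Bool
assocAt f g h x = (f <> (g <> h)) x == ((f <> g) <> h) x

rightIdAt :: (Monoid a, Eq a) => (x -> a) -> x -> Bool
rightIdAt f x = (f <> const mempty) x == f x

leftIdAt :: (Monoid a, Eq a) => (x -> a) -> x -> Bool
leftIdAt f x = (const mempty <> f) x == f x
```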
The three stanzas above demonstrate that `f <> (g <> h)` is `(f <> g) <> h`, `f <> const mempty` is `f`, and `const mempty <> f` is `f`. Thus, `x -> a` is a valid monoid.
One-pass Implementation
Armed with this new method of building up monoids, we refactor our program to make just one pass over `points`.
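A sketch of that refactor is below; it matches the function in the one-pass listing in the Analysis section. The single `foldMap` builds a pair: the count and sums on the left, and on the right a function (`getReg`) that is only applied once the averages are known.

```haskell
import Data.Monoid (Sum(..))

simpleLinearRegression :: Fractional a => [(a, a)] -> (a, a)
simpleLinearRegression points =
  (slope, intercept)
  where
    avg (x, y) = (Sum 1, Sum x, Sum y)
    -- reg p is a function awaiting the averages; functions into a monoid
    -- are themselves a monoid, so foldMap can combine them.
    reg (x, y) (avg_x, avg_y) =
      (Sum $ x' * y', Sum $ x' * x')
      where
        x' = x - avg_x
        y' = y - avg_y
    -- The only traversal of points.
    ((Sum n, Sum xs, Sum ys), getReg) =
      foldMap (\p -> (avg p, reg p)) points
    avg_x = xs / n
    avg_y = ys / n
    (Sum xys, Sum xxs) = getReg (avg_x, avg_y)
    slope = xys / xxs
    intercept = avg_y - slope * avg_x
```

Note that `getReg` is a closure built up from every list element, which is worth keeping in mind when thinking about memory.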
Analysis
Using the ideas above, I wrote three implementations of simple linear regression and compared their performance on a sample data set. I compiled each implementation to a standalone macOS binary using GHC 8.8.3 and the `-O2` optimization flag.
The data set (provided by Kaggle) features ocean temperatures vs. salinity and consists of 864,862 data points, which I then quadrupled to 3,459,448 data points using Unix `cat`. I ran each implementation four times, using Unix `time` to record how long each run took.
Notably, I’m ignoring memory in this analysis, though memory complexity would be an important question to answer in a more comprehensive study. Also, for reasons unknown to me, the second implementation gives a slightly different result than the first and third. I should probably look into why that is.
Naive Implementation
We directly port the math over to Haskell, ignoring performance concerns such as multiple list passes and common subexpressions. This one tends to finish in about 15 seconds.
```haskell
-- naive-linear-regression.hs
import Data.Maybe (mapMaybe)

main :: IO ()
main =
  print . simpleLinearRegression . parse =<< getContents

readM :: Read a => String -> Maybe a
readM str =
  case reads str of
    [(x, "")] -> Just x
    _ -> Nothing

parse :: String -> [(Double, Double)]
parse =
  mapMaybe parseLine . lines . filter (/= '\r')
  where
    parseLine raw =
      let (x, y) = break (== ',') raw
      in (,) <$> readM x <*> readM (tail y)

simpleLinearRegression :: Fractional a => [(a, a)] -> (a, a)
simpleLinearRegression points =
  (slope, intercept)
  where
    avg_x = sum [x | (x, _) <- points] / fromIntegral (length points)
    avg_y = sum [y | (_, y) <- points] / fromIntegral (length points)
    xys = sum [(x - avg_x) * (y - avg_y) | (x, y) <- points]
    xxs = sum [(x - avg_x) * (x - avg_x) | (x, _) <- points]
    slope = xys / xxs
    intercept = avg_y - slope * avg_x
```
One-pass Implementation
We use built-in `Monoid` instances to condense our calculations into one list pass. Surprisingly, this one takes longer than the naive implementation, tending to finish in about 22 seconds.
```haskell
-- one-pass-linear-regression.hs
import Data.Maybe (mapMaybe)
import Data.Monoid (Sum(Sum))

main :: IO ()
main =
  print . simpleLinearRegression . parse =<< getContents

readM :: Read a => String -> Maybe a
readM str =
  case reads str of
    [(x, "")] -> Just x
    _ -> Nothing

parse :: String -> [(Double, Double)]
parse =
  mapMaybe parseLine . lines . filter (/= '\r')
  where
    parseLine raw =
      let (x, y) = break (== ',') raw
      in (,) <$> readM x <*> readM (tail y)

simpleLinearRegression :: Fractional a => [(a, a)] -> (a, a)
simpleLinearRegression points =
  (slope, intercept)
  where
    avg (x, y) = (Sum 1, Sum x, Sum y)
    reg (x, y) (avg_x, avg_y) =
      (Sum $ x' * y', Sum $ x' * x')
      where
        x' = x - avg_x
        y' = y - avg_y
    ((Sum n, Sum xs, Sum ys), getReg) =
      foldMap (\p -> (avg p, reg p)) points
    avg_x = xs / n
    avg_y = ys / n
    (Sum xys, Sum xxs) = getReg (avg_x, avg_y)
    slope = xys / xxs
    intercept = avg_y - slope * avg_x
```
One-pass Unboxed Implementation
Not content to have the naive implementation win, I still had a few tricks up my sleeve to squeeze more calculations out of each CPU cycle. For one, since we know we’re consuming our whole input list, we can switch from `foldMap` to the strict `foldMap'`. For another, the one-pass implementation achieves a satisfying level of code reuse through polymorphism and composition of tuples, but those extra pointers add up. We can gain some performance by monomorphising and unboxing all of our data, at the cost of implementing `Semigroup` and `Monoid` instances by hand. This one tends to finish in under 12 seconds.
```haskell
-- one-pass-strict-unboxed-linear-regression.hs
{-# LANGUAGE BangPatterns #-}
import Data.Foldable (foldMap')
import Data.Maybe (mapMaybe)

main :: IO ()
main =
  print . simpleLinearRegression . parse =<< getContents

readM :: Read a => String -> Maybe a
readM str =
  case reads str of
    [(x, "")] -> Just x
    _ -> Nothing

parse :: String -> [Pair]
parse =
  mapMaybe parseLine . lines . filter (/= '\r')
  where
    parseLine raw =
      let (x, y) = break (== ',') raw
      in Pair <$> readM x <*> readM (tail y)

data Pair =
  Pair {-# UNPACK #-} !Double {-# UNPACK #-} !Double
  deriving Show

instance Semigroup Pair where
  Pair x1 y1 <> Pair x2 y2 = Pair (x1 + x2) (y1 + y2)

instance Monoid Pair where
  mempty = Pair 0 0

data Stat =
  Stat
    {-# UNPACK #-} !Double
    {-# UNPACK #-} !Double
    {-# UNPACK #-} !Double
    !(Pair -> Pair)

instance Semigroup Stat where
  Stat n1 x1 y1 f1 <> Stat n2 x2 y2 f2 =
    Stat (n1 + n2) (x1 + x2) (y1 + y2) (f1 <> f2)

instance Monoid Stat where
  mempty = Stat 0 0 0 (const mempty)

simpleLinearRegression :: [Pair] -> Pair
simpleLinearRegression points =
  Pair slope intercept
  where
    avgReg !(Pair x y) =
      Stat 1 x y $ \(Pair !avg_x !avg_y) ->
        let
          x' = x - avg_x
          y' = y - avg_y
        in
          Pair (x' * y') (x' * x')
    Stat n xs ys getReg = foldMap' avgReg points
    avg_x = xs / n
    avg_y = ys / n
    Pair xys xxs = getReg (Pair avg_x avg_y)
    slope = xys / xxs
    intercept = avg_y - slope * avg_x
```
Conclusions
I was caught off-guard by the test results. I thought it’d go without saying that the one-pass version would trounce the naive version. Instead, the opposite was true. The naive version did better than the polymorphic one-pass version, and I found myself having to really reach to beat the naive implementation in the one-pass unboxed version. It’d be interesting to compare the Core GHC produces for these three implementations to see how many list passes the naive version really ends up doing.
I think this blog post is less a story about how awesome monoids are (They are.) and more a story about how awesome GHC is (It is.). To me, it’s amazing that GHC can take the straightforward (and, frankly, kinda sloppy) naive implementation and compile it down to efficient code. To me, this reinforces a general theme in Haskell: do the obvious, simple thing first.
Appendix
In trying to find out why the six-pass naive version did so well compared to the first one-pass version, I went ahead and implemented the six-pass naive version in Python.
The six-pass Python version runs in about 3 seconds.
So my new question is, why the hell is my Haskell so slow? What am I doing wrong here? Will update when I find out.
Appendix B
Profiling immediately pointed to `readM` as the culprit, taking 78% of the time. After a quick search for “Haskell is read slow?”, I found this SO question. The answer suggested refactoring to use bytestring instead, and that cut a big chunk of time out.
Refactored to use bytestring, the naive implementation takes just over 4 seconds, the polymorphic one-pass version takes just over 7 seconds, and the strict unboxed one-pass version takes just under 4 seconds, still slower than the naive Python implementation, which takes just under 3.5 seconds.
I think the data set I’m using (about 3.5 million list elements, about 42 megabytes) is just too small to make multiple passes a problem; however, on a significantly larger dataset, such as one that could not fit in memory, a one-pass implementation would OOM building up the closure needed to compute `xys` and `xxs`, so I’m not sure what to do in that situation. If you have an idea leave a comment, or reply on Twitter.
Appendix 3
Sat down with my math friends (virtually) and we came up with a way to do it in constant memory. Expect that, and some interesting observations made by my Twitter friends, in a follow-up post.