I recently had to implement linear regression for a hobby project. As I’d never learned how to calculate it, this was a great opportunity to learn. As I dug in, I saw that I’d need to make two passes over my data. Let’s see if we can fix that.

## The Problem

Simple linear regression is where we take a set of points in the plane and compute the slope and $$y$$-intercept of the line that most closely fits the points.

Precisely, if our data is $$\left\{ (x_i, y_i) \mid i \in I \right\}$$, then we are looking for the pair $$m$$, $$b$$ that minimizes the cost

$\sum_{i \in I} \left( m x_i + b - y_i \right)^2 \text{,}$

where $$m$$ is the slope of our model, $$b$$ is the $$y$$-intercept of our model, and $$y = mx + b$$ is the value our model predicts for arbitrary input $$x$$.

There are many approaches to this problem, ranging from multivariable calculus to geometry to linear algebra. We will use a statistical approach.

Stating without proof, we have the following:

$m = r \frac{s_y}{s_x}$

and

$b = \bar{y} - m \bar{x}$

where

• $$\bar{x}$$, $$\bar{y}$$ are the means of $$x_i$$, $$y_i$$,
• $$s_x$$, $$s_y$$ are the standard deviations of $$x_i$$, $$y_i$$, and
• $$r$$ is Pearson’s correlation coefficient, defined as

$r = \frac{ \sum_i \left( (x_i - \bar{x}) (y_i - \bar{y}) \right) }{ \sqrt{ \sum_i (x_i - \bar{x})^2 \sum_i (y_i - \bar{y})^2 } }\text{.}$

(This useful video walks you through an example by hand.)

Recall the formula for standard deviation $$s_x$$ (and analogously for $$s_y$$),

$s_x = \sqrt{\frac{\sum_i (x_i - \bar{x})^2}{n - 1}} \text{,}$

where $$n$$ is the size of the sample $$I$$.

First, let’s simplify the formula for $$m$$.

$m = r \frac{s_y}{s_x} = r \frac{ \sqrt{\frac{\sum_i (y_i - \bar{y})^2}{n - 1}} }{ \sqrt{\frac{\sum_i (x_i - \bar{x})^2}{n - 1}} } = r \sqrt{\frac{\sum_i (y_i - \bar{y})^2}{\sum_i (x_i - \bar{x})^2}}$

$= \frac{ \sum_i \left( (x_i - \bar{x}) (y_i - \bar{y}) \right) }{ \sqrt{ \sum_i (x_i - \bar{x})^2 \sum_i (y_i - \bar{y})^2 } } \sqrt{\frac{\sum_i (y_i - \bar{y})^2}{\sum_i (x_i - \bar{x})^2}}$

$= \sum_i \left( (x_i - \bar{x}) (y_i - \bar{y}) \right) \sqrt{\frac{ \sum_i (y_i - \bar{y})^2 }{ \sum_i (x_i - \bar{x})^2 \sum_i (y_i - \bar{y})^2 \sum_i (x_i - \bar{x})^2 }}$

$= \sum_i \left( (x_i - \bar{x}) (y_i - \bar{y}) \right) \sqrt{\frac{1}{ \left( \sum_i (x_i - \bar{x})^2 \right)^2 }}$

$= \frac{ \sum_i \left( (x_i - \bar{x}) (y_i - \bar{y}) \right) }{ \sum_i (x_i - \bar{x})^2 }$

This final form is much nicer. To work with it, we no longer need to know the standard deviations $$s_x$$, $$s_y$$, though we still need to know the means $$\bar{x}$$, $$\bar{y}$$.

## Two-pass Implementation

The formulas for $$m$$ and $$b$$ above give us a straightforward way to implement simple linear regression:

1. Find avg_x and avg_y,

2. Use avg_x and avg_y in the above formula for slope $$m$$, and

3. Use avg_x, avg_y, and slope in the above formula for intercept $$b$$.
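The steps above can be sketched as a direct port of the math (a minimal sketch; the names mean and linearRegression are my own):

```haskell
import Data.List (genericLength)

-- Mean of a list: one pass for the sum, one for the length.
mean :: [Double] -> Double
mean xs = sum xs / genericLength xs

-- A direct port of the math: find the averages, then the slope, then the intercept.
linearRegression :: [(Double, Double)] -> (Double, Double)
linearRegression points = (m, b)
  where
    xs   = map fst points
    ys   = map snd points
    avgX = mean xs  -- passes 1 and 2
    avgY = mean ys  -- passes 3 and 4
    xys  = sum [ (x - avgX) * (y - avgY) | (x, y) <- points ]  -- pass 5
    xxs  = sum [ (x - avgX) ^ 2 | x <- xs ]                    -- pass 6
    m    = xys / xxs
    b    = avgY - m * avgX
```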

This works, but it makes six passes over the points list. If this list is very long, then each pass will take a non-trivial amount of time. We can refactor this using monoids and foldMap to reduce the number of list passes we need to make.

A type a is a monoid if it has an associative binary operation <> with an identity element mempty.

foldMap takes a list [x] and a mapping function x -> a and, if a is a monoid, applies the function to the list elements and combines the results using <>. Concretely, foldMap f [x1, x2, x3] == f x1 <> f x2 <> f x3. foldMap of an empty list gives mempty.
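For example, using the Sum monoid from Data.Monoid, whose <> adds the wrapped numbers:

```haskell
import Data.Monoid (Sum (..))

-- foldMap Sum [1, 2, 3] == Sum 1 <> Sum 2 <> Sum 3 == Sum 6
sumViaFoldMap :: [Int] -> Int
sumViaFoldMap = getSum . foldMap Sum
```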

We’ll use the monoids Sum a and (a, b) in our refactor.
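Pairs combine componentwise when both components are monoids, so (Sum a, Sum b) lets us accumulate two sums at once:

```haskell
import Data.Monoid (Sum (..))

-- <> on pairs applies <> in each component.
pairExample :: (Sum Int, Sum Int)
pairExample = (Sum 1, Sum 10) <> (Sum 2, Sum 20)  -- (Sum 3, Sum 30)
```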

Armed with these components, let’s take a closer look at calculating avg_x and avg_y. We’re currently passing over the list four times to calculate them.
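Here is the four-pass version (two sums and two lengths; a sketch with my own naming):

```haskell
averages :: [(Double, Double)] -> (Double, Double)
averages points =
  ( sum xs / fromIntegral (length xs)  -- passes 1 and 2
  , sum ys / fromIntegral (length ys)  -- passes 3 and 4
  )
  where
    xs = map fst points
    ys = map snd points
```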

We can get this down to three passes if we share the length.
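Both components have the same number of elements, so one length suffices:

```haskell
averages :: [(Double, Double)] -> (Double, Double)
averages points = (sum xs / n, sum ys / n)
  where
    xs = map fst points
    ys = map snd points
    n  = fromIntegral (length points)  -- the length pass, now shared
```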

Next, we can use the Sum a and (a, b) monoids to get both sums in one pass.
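A sketch of that step:

```haskell
import Data.Monoid (Sum (..))

averages :: [(Double, Double)] -> (Double, Double)
averages points = (sx / n, sy / n)
  where
    -- One pass accumulates both sums in the (Sum a, Sum b) monoid.
    (Sum sx, Sum sy) = foldMap (\(x, y) -> (Sum x, Sum y)) points
    n = fromIntegral (length points)  -- second pass
```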

We’re down to two passes, but we can get the length at the same time we get the sums with this one neat trick: count the number of list elements by packing a Sum 1 into the result of our mapping function.
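Triples of monoids are also a monoid (componentwise, like pairs), so we can carry the count alongside the sums (the name avg is from the discussion below; the rest is my own sketch):

```haskell
import Data.Monoid (Sum (..))

avg :: (Double, Double) -> (Sum Double, Sum Double, Sum Double)
avg (x, y) = (Sum 1, Sum x, Sum y)  -- the Sum 1 component counts elements

averages :: [(Double, Double)] -> (Double, Double)
averages points = (sx / n, sy / n)
  where
    (Sum n, Sum sx, Sum sy) = foldMap avg points  -- the only pass
```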

This is a bit tricky, so let’s look at an example to see how it works:
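Tracing the fold on a tiny two-element list:

```haskell
import Data.Monoid (Sum (..))

-- foldMap (\(x, y) -> (Sum 1, Sum x, Sum y)) [(1, 3), (2, 5)]
--   == (Sum 1, Sum 1, Sum 3) <> (Sum 1, Sum 2, Sum 5)
--   == (Sum 2, Sum 3, Sum 8)
-- so n = 2, sx = 3, and sy = 8, giving averages 1.5 and 4.0.
traced :: (Sum Double, Sum Double, Sum Double)
traced = foldMap (\(x, y) -> (Sum 1, Sum x, Sum y)) [(1, 3), (2, 5)]
```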

So we’ve gotten both averages in just one pass! Let’s refactor the rest of the program.
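A sketch of the two-pass program (the names avg, reg, xys, and xxs follow the discussion; the exact layout is my own):

```haskell
import Data.Monoid (Sum (..))

avg :: (Double, Double) -> (Sum Double, Sum Double, Sum Double)
avg (x, y) = (Sum 1, Sum x, Sum y)

linearRegression :: [(Double, Double)] -> (Double, Double)
linearRegression points = (m, b)
  where
    (Sum n, Sum sx, Sum sy) = foldMap avg points  -- pass 1: count and sums
    (avgX, avgY) = (sx / n, sy / n)
    reg (x, y) =
      (Sum ((x - avgX) * (y - avgY)), Sum ((x - avgX) ^ 2))
    (Sum xys, Sum xxs) = foldMap reg points       -- pass 2: slope sums
    m = xys / xxs
    b = avgY - m * avgX
```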

We make two passes over our list, one using the mapping function avg and one using the mapping function reg. This is already a huge improvement over our first draft.

## Monoid-valued Functions

At first glance, it seems like two passes is the best we can hope for, since reg needs to use the results avg_x and avg_y of the first pass. Fortunately, as programmers we have a one-size-fits-all tool for abstracting away information that we don’t yet have: functions. We can make reg a function that gets avg_x and avg_y as inputs.
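Sketched (specialized to Double for concreteness), the new reg returns a function waiting for the averages:

```haskell
import Data.Monoid (Sum (..))

-- reg now returns a function waiting for the pair of averages.
reg :: (Double, Double) -> ((Double, Double) -> (Sum Double, Sum Double))
reg (x, y) (avgX, avgY) =
  (Sum ((x - avgX) * (y - avgY)), Sum ((x - avgX) ^ 2))
```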

I assert that this solves all of our problems, but the astute reader might now notice a subtlety. reg has the special requirement that it needs to map our list elements into a monoid. In its original form we have reg :: (a, a) -> (Sum a, Sum a), which does indeed map a list element (a, a) into the monoid (Sum a, Sum a). Our new version of reg, though, has a different signature. We now have reg :: (a, a) -> ( (a, a) -> (Sum a, Sum a) ), which maps a list element (a, a) into the function type (a, a) -> (Sum a, Sum a). It turns out that the type (a, a) -> (Sum a, Sum a) is a perfectly reasonable monoid already.
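In fact, base already provides the instance (shown roughly in the comments below), and we can see it in action by combining two Sum-valued functions pointwise:

```haskell
import Data.Monoid (Sum (..))

-- base provides, roughly:
--   instance Semigroup a => Semigroup (x -> a) where
--     f <> g = \v -> f v <> g v
--   instance Monoid a => Monoid (x -> a) where
--     mempty = const mempty

-- Two Sum-valued functions combined pointwise:
combined :: Double -> Sum Double
combined = (\x -> Sum (x + 1)) <> (\x -> Sum (2 * x))
```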

The instance above says that any function type where the target type a is a monoid is itself a monoid in a formulaic way, with no conditions on the source type x. In other words, monoid-valued functions always form a monoid. We just need to make sure that our definition of <> is associative and that const mempty is an identity for <>.
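The checks go through by equational reasoning (a sketch; each step applies either the definition of <> on functions or a monoid law of a):

```haskell
-- Associativity:
--   (f <> (g <> h)) v
--     == f v <> (g v <> h v)   -- definition of <> on functions
--     == (f v <> g v) <> h v   -- associativity of <> in a
--     == ((f <> g) <> h) v     -- definition of <> on functions

-- Right identity:
--   (f <> const mempty) v
--     == f v <> mempty         -- definitions of <> and const
--     == f v                   -- right identity in a

-- Left identity:
--   (const mempty <> f) v
--     == mempty <> f v         -- definitions of <> and const
--     == f v                   -- left identity in a
```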

The three stanzas above demonstrate that f <> (g <> h) is (f <> g) <> h, f <> const mempty is f, and const mempty <> f is f. Thus, x -> a is a valid monoid.

## One-pass Implementation

Armed with this new method of building up monoids, we refactor our program to make just one pass over points.
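A sketch of the one-pass version, pairing the averages accumulator with the monoid-valued function (names follow the earlier discussion; the layout is my own):

```haskell
import Data.Monoid (Sum (..))

linearRegression :: [(Double, Double)] -> (Double, Double)
linearRegression points = (m, b)
  where
    avg (x, y) = (Sum 1, Sum x, Sum y)
    reg (x, y) (avgX', avgY') =
      (Sum ((x - avgX') * (y - avgY')), Sum ((x - avgX') ^ 2))
    -- The single pass: accumulate the averages monoid and the
    -- monoid-valued function side by side.
    ((Sum n, Sum sx, Sum sy), regF) =
      foldMap (\p -> (avg p, reg p)) points
    (avgX, avgY) = (sx / n, sy / n)
    (Sum xys, Sum xxs) = regF (avgX, avgY)  -- feed the averages in at the end
    m = xys / xxs
    b = avgY - m * avgX
```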

## Analysis

Using the ideas above, I wrote three implementations of simple linear regression and compared their performance on a sample data set. I compiled each implementation to a standalone macOS binary using GHC 8.8.3 and the -O2 optimization flag.

The data set (provided by Kaggle) features ocean temperatures vs. salinity and consists of 864,862 data points, which I then quadrupled to 3,459,448 data points using Unix cat. I ran each implementation four times, using Unix time to record how long each run took.

Notably, I’m ignoring memory in this analysis, though memory complexity would be an important question to answer in a more comprehensive study. Also, for reasons unknown to me, the second implementation gives a slightly different result than the first and third. I should probably look into why that is.

### Naive Implementation

We directly port the math over to Haskell, ignoring performance concerns such as multiple list passes and common subexpressions. This one tends to finish in about 15 seconds.

### One-pass Implementation

We use built-in Monoid instances to condense our calculations into one list pass. Surprisingly, this one takes longer than the naive implementation, tending to finish in about 22 seconds.

### One-pass Unboxed Implementation

Not content to have the naive implementation win, I still had a few tricks up my sleeve to squeeze more calculations out of each CPU cycle. For one, since we know we’re consuming our whole input list, we can switch from foldMap to the strict foldMap'. For another, the one-pass implementation achieves a satisfying level of code reuse through polymorphism and composition of tuples, but those extra pointers add up. We can gain some performance by monomorphising and unboxing all of our data, at the cost of implementing Semigroup and Monoid instances by hand. This one tends to finish in under 12 seconds.
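As a sketch of the idea, here is a hand-rolled, strict, monomorphic accumulator for the averages part (the name Acc and the field layout are my own):

```haskell
import Data.Foldable (foldMap')

-- Monomorphic, strict accumulator: count, sum of x, sum of y.
data Acc = Acc {-# UNPACK #-} !Double
               {-# UNPACK #-} !Double
               {-# UNPACK #-} !Double

-- Hand-written instances replace the polymorphic tuple monoids.
instance Semigroup Acc where
  Acc n1 x1 y1 <> Acc n2 x2 y2 = Acc (n1 + n2) (x1 + x2) (y1 + y2)

instance Monoid Acc where
  mempty = Acc 0 0 0

averages :: [(Double, Double)] -> (Double, Double)
averages points = (sx / n, sy / n)
  where
    -- foldMap' forces the accumulator at each step.
    Acc n sx sy = foldMap' (\(x, y) -> Acc 1 x y) points
```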

## Conclusions

I was caught off-guard by the test results. I thought it’d go without saying that the one-pass version would trounce the naive version. Instead, the opposite was true. The naive version did better than the polymorphic one-pass version, and I found myself having to really reach to beat the naive implementation in the one-pass unboxed version. It’d be interesting to compare the Core GHC produces for these three implementations to see how many list passes the naive version really ends up doing.

I think this blog post is less a story about how awesome monoids are (They are.) and more a story about how awesome GHC is (It is.). To me, it’s amazing that GHC can take the straightforward (and, frankly, kinda sloppy) naive implementation and compile it down to efficient code. To me, this reinforces a general theme in Haskell: do the obvious, simple thing first.

## Appendix

In trying to find out why the six-pass naive version did so well compared to the first one-pass version, I went ahead and implemented the six-pass naive version in Python.

The six-pass Python version runs in about 3 seconds.

So my new question is, why the hell is my Haskell so slow? What am I doing wrong here? I’ll update when I find out.

## Appendix B

Profiling immediately pointed to readM as the culprit, taking 78% of the time. After a quick search for “Haskell is read slow?”, I found this SO question. The answer suggested refactoring to use bytestring instead, and that cut a big chunk of time out.

Refactored to use bytestring, the naive implementation takes just over 4 seconds, the polymorphic one-pass version takes just over 7 seconds, and the strict unboxed one-pass version takes just under 4 seconds, still slower than the naive Python implementation, which takes just under 3.5 seconds.

I think the data set I’m using (about 3.5 million list elements, about 42 megabytes) is just too small to make multiple passes a problem; however, on a significantly larger dataset, such as one that could not fit in memory, a one-pass implementation would OOM building up the closure needed to compute xys and xxs, so I’m not sure what to do in that situation. If you have an idea, leave a comment or reply on Twitter.

## Appendix C

I sat down with my math friends (virtually), and we came up with a way to do it in constant memory. Expect that, and some interesting observations made by my Twitter friends, in a follow-up post.