Iterative methods in Rust: Conjugate Gradient
Introduction: the Beauty and the Trap
Iterative methods (IMs) power many magical-seeming uses of computers, from deep learning to PageRank. An IM repeatedly applies a simple recipe that improves an approximate solution to a problem. As a block of wood in the right hands gradually transforms into a figurine, the right method will produce a sequence of solutions approaching perfection from an arbitrarily bad initial value. IMs are quite unlike most CS algorithms, and take some getting used to, but for many problems they are the only game in town. Unfortunately, implementations of IMs are often hard to maintain (read, test, benchmark) because they mix the recipe with other concerns.
We start with a simple concrete example. Implementing an iterative method often starts as a `for` loop around the recipe:
// Problem: minimize the convex parabola f(x) = x^2 + x
// An iterative solution by gradient descent
let mut x = 2.0;
for i in 0..10 {
    // "2.0*x + 1.0" is the derivative of f(x),
    // so moving a little bit in the opposite direction
    x -= 0.2 * (2.0 * x + 1.0);
    // should reduce the value of f(x)!
    println!("x_{} = {:.2}; f(x_{}) = {:.4}", i, x, i, x * x + x);
}
// An alternative that does not scale well to harder problems is to
// use the fact `f` does not decrease/increase at the minimum to
// find its location directly:
//
// `f` at a minimum zeros its derivative: 2*x + 1 = 0.
// Rearranging terms gives the solution x = -1/2.
... but that loop body already mixes multiple concerns. If it didn't `println!` progress, we wouldn't even see it working. If it didn't stop after 10 iterations, it would go on forever. But a reader might justly be confused about why that limit is there: does the method work only for 10 iterations? And perhaps a different user does not want any intermediate output; must they reimplement the recipe? How would they ensure the reimplementation is also correct? The approach taken in this example does not allow additional functionality to be composed cleanly, and so features build up like barnacles on a ship. Is there a way to keep the recipe separate from the other concerns?
The answer is yes. This series of posts is about such a way to implement IMs in Rust using iterators. This design, realized for Rust in the `iterative_methods` crate (repo), follows (and expands on) the approach of "Iterative Methods Done Right" (IMDR) by Lorenzo Stella, demonstrated there with Julia iterables. The main idea is that each IM is an iterator over algorithm states, and each reusable "utility" is an iterator adaptor augmenting or processing them. These components are highly reusable, composable, and can be tailored using closures. This is basically an extension to this domain of an approach the Rust standard library already follows. If that seems a bit dense, don't worry: we'll unpack it over this post and the next, introducing along the way a common iterative method as well as measures of IM quality.
Beyond reuse of methods and utilities, there is another reason to separate them: iterative methods are often highly sensitive beasts. Small changes can cause subtle numerical and/or dynamic effects that are very difficult to trace and resolve. Thus a primary design goal is to minimize the need for modifications to the method implementation, even when attempting to study/debug it.
Our first full iterative method
Our main guest IM for this post and the next will be Conjugate Gradient (CG). What is CG for? Here is one example: to simulate physics like the spread of heat (or deformation, fluid flow, etc.) over a complex shape, we break the shape up into simpler pieces to apply the Finite Element Method (FEM). In FEM, the relations between temperatures at different elements implied by the heat equation are encoded into a matrix \(A\) and a vector \(b\). To find the vector of temperatures \(x\), it is enough to solve the matrix equation \(Ax = b\). CG is an IM for solving such equations when \(A\) is positive definite (PD)1, which it is for a wide variety of domains.
Why the conjugate gradient method works is beyond the scope of this post, but good sources include the Wikipedia exposition and also these lecture notes.
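For orientation when reading the implementation excerpts below, the textbook CG recipe itself is short. Writing \(r_k\) for the residual and \(p_k\) for the search direction, with \(r_0 = p_0 = b - Ax_0\), each iteration computes:

\[
\begin{aligned}
\alpha_k &= \frac{r_k^\top r_k}{p_k^\top A p_k}, \\
x_{k+1} &= x_k + \alpha_k p_k, \\
r_{k+1} &= r_k - \alpha_k A p_k, \\
\beta_k &= \frac{r_{k+1}^\top r_{k+1}}{r_k^\top r_k}, \\
p_{k+1} &= r_{k+1} + \beta_k p_k.
\end{aligned}
\]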
Positive Definite matrices
A positive definite matrix \(A\) only scales vectors, by different amounts in different directions; no rotation or flipping allowed.
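In symbols, the defining property is that \(A\) is symmetric and satisfies

\[ x^\top A x > 0 \quad \text{for every vector } x \neq 0. \]

For example, any diagonal matrix with strictly positive diagonal entries is PD.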
Implementation and a general interface
To store the state our method maintains, we define a struct `ConjugateGradient`, for which we implement the `StreamingIterator` trait. This trait is simple, and requires us to implement two methods:

- `advance` applies one iteration of the algorithm, updating the state.
- `get` returns a borrow of the `Item` type, generally some part of that state.

The benefit of the `StreamingIterator` trait over the ubiquitous `Iterator` is that `get` exposes information by reference; this leaves the decision of whether to copy state up to the implementor.
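To make the shape of this trait concrete, here is a minimal sketch (assuming the `streaming-iterator` crate, which provides the trait) that recasts the gradient-descent example from the introduction in this style:

use streaming_iterator::StreamingIterator;

struct GradientDescent {
    x: f64,
}

impl StreamingIterator for GradientDescent {
    type Item = f64;

    fn advance(&mut self) {
        // One step of the recipe: move against the derivative of x^2 + x.
        self.x -= 0.2 * (2.0 * self.x + 1.0);
    }

    fn get(&self) -> Option<&Self::Item> {
        Some(&self.x)
    }
}

fn main() {
    let mut gd = GradientDescent { x: 2.0 };
    for i in 0..10 {
        // `next`, provided by the trait, calls `advance` then `get`.
        let x = *gd.next().unwrap();
        println!("x_{} = {:.2}; f(x_{}) = {:.4}", i, x, i, x * x + x);
    }
}

Note that external code only ever calls `next`; the recipe itself lives in `advance`, and the reporting and stopping concerns stay outside.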
The signatures for implementing an iterative method in this style are as follows:
#[derive(Clone, Debug)]
pub struct ConjugateGradient {
    // State of the algorithm
}

impl ConjugateGradient {
    /// Initialize a conjugate gradient iterative solver to solve linear system `p`.
    pub fn for_problem(p: &LinearSystem) -> ConjugateGradient {
        // Problem data such as the LinearSystem is often large and should
        // not be duplicated. This implementation uses ndarray's ArcArrays,
        // which are cheap to clone as they share the data they point to.
    }
}

impl StreamingIterator for ConjugateGradient {
    type Item = Self;

    fn advance(&mut self) {
        // the improvement recipe goes here
    }

    fn get(&self) -> Option<&Self::Item> {
        // Return self, immutably borrowed. This allows callers read-only
        // access to method state. The following is a bit simplified:
        Some(self)
    }
}
Note a few design decisions in the above:

- The problem is a distinct concept from any method for solving it. The same problem representation (here, `LinearSystem`) often can and should be reused to initialize different methods.
- The constructor method `for_problem` is responsible for setting up the initial state for the first iteration, and so is part of the method definition.
- Another constructor responsibility is to perform applicable and cheap checks of the input problem; expensive initialization is a bad fit for an iterative method.
- `Item` is set to the whole `ConjugateGradient`, i.e., all algorithm state. We could instead set the `Item` type returned by the `get` method to be only a result field, thus hiding implementation details from downstream code. Similarly, there is some flexibility in defining the iterable struct: beyond a minimal representation of the state required for the next iteration, should we add fields to store intermediate steps of calculations? How about auxiliary information not needed at all in the method itself? Consider the following excerpt from the implementation of `advance`:
// while r_k != 0:
//     alpha_k = ||r_k||^2 / ||p_k||^2_A
self.alpha_k = self.r_k2 / self.pap_k;
if (!too_small(self.r_k2)) && (!too_small(self.pap_k)) {
    // x_{k+1} = x_k + alpha_k*p_k
    ...
Here `self.alpha_k` is only read in the remainder of the recipe. So why not make it a temporary instead of a field of `ConjugateGradient`? That would seem to shrink the struct, saving memory, and hide an unnecessary detail: generally positive outcomes, right? But soon after implementing this code I found myself wanting to print `alpha_k`, which is impossible for a local without modifying the `advance` method! By storing more intermediate state in the iterator, exposing all of it via `get`, and inspecting it externally, we avoid modifying the method for our inspection and the dreaded Heisenbugs that could ensue. On top of a solid whitebox implementation, we can always build an interface that abstracts away some aspects.
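For instance, once `alpha_k` is a field exposed through `get`, printing it requires no change to `advance` at all. A sketch, assuming (as in the excerpt above) that `alpha_k` is accessible on the state:

// Inspect method internals from the outside; `advance` stays untouched.
while let Some(state) = cg_iter.next() {
    println!("alpha_k = {:.6}", state.alpha_k);
}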
Running an Iterative Method
How do we call such an implementation? The example below illustrates a common workflow:
// First we generate a problem, which consists of the pair (A,b).
let p = make_3x3_pd_system_2();
// Next convert it into an iterator
let mut cg_iter = ConjugateGradient::for_problem(&p);
// and loop over intermediate solutions.
// Note `next` is provided by the StreamingIterator trait using
// `advance` then `get`.
while let Some(result) = cg_iter.next() {
    // We want to find x such that a.dot(x) = b,
    // so the difference between the two sides (called the residual)
    // is a good measure of the error in a solution.
    let res = result.a.dot(&result.solution) - &result.b;
    // The (squared) length of the residual is a cost, a number
    // summarizing how bad a solution is. When working on iterative
    // methods, we want to see these numbers decrease quickly.
    let res_squared_length = res.dot(&res);
    // || ... ||_2 is notation for the euclidean length of what
    // lies between the vertical lines.
    println!(
        "||Ax - b||_2 = {:.5}, for x = {:.4}, residual = {:.7}",
        res_squared_length.sqrt(),
        result.solution,
        res
    );
    // Stop if the residual is small enough.
    if res_squared_length < 1e-3 {
        break;
    }
}
Indeed the output shows nice convergence, with the residual \(Ax - b\) tending quickly to zero:
||Ax - b||_2 = 1.00000, for x = [+0.000, +0.000, +0.000], residual = [+0.000, -1.000, +0.000]
||Ax - b||_2 = 0.94281, for x = [+0.000, +0.667, +0.000], residual = [+0.667, +0.000, +0.667]
||Ax - b||_2 = 0.00000, for x = [-4.000, +6.000, -4.000], residual = [+0.000, +0.000, +0.000]
In terms of the code, notice that the algorithm is taken out of the loop! We do not modify it merely to report progress, nor even to decide when to stop. But we do change that loop body, which gets a bit messy. Once we start looking for such niceties, soon we'll want to:
- look at only every Nth iteration,
- measure the runtime of an iteration (excluding the cost of reporting itself),
- plot progress over time,
- save progress in case of a power failure...
for basically every method we work on, and we certainly don't want all of those tangled up in our loops. We will want reusable components, named to convey intention! As mentioned above, the idea of representing processes with streaming iterators applies to such utilities as well, cleanly and orthogonally. We demonstrate this in the next post.
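As a small taste, some of this composition already falls out of the trait's built-in adaptors. A sketch, assuming the `streaming-iterator` crate's `take` combinator:

// Cap the number of iterations from the outside; the recipe is untouched.
let mut bounded = ConjugateGradient::for_problem(&p).take(20);
while let Some(result) = bounded.next() {
    // report, plot, or checkpoint here
}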
Looking beyond design for code reuse, IMs also put a new twist on benchmarking and testing. How does one time or test code that doesn't really want to stop, and whose solutions only approach correctness? We'll get to those questions as well.
Thanks to Daniel Fox (a collaborator on this project) and Yevgenia Vainsencher for feedback on early versions of this post.