RANSAC in Scala by Nicolau Werneck

Have you ever implemented RANSAC in Scala? In this article, Nicolau Werneck tells us about his experience doing exactly that.


A colleague of mine was going to give Scala a chance the other day, and tried to search for implementations of classic algorithms. He looked for RANSAC, and couldn't find it… This post seeks to rectify that situation!

If you know I am a computer vision doctor, you may be surprised if I tell you I hold very little love for RANSAC, and also for another technique that frequently walks together with it: the Harris corner detector. I don't want to sound ungrateful or hypocritical, as I have actually used RANSAC in my research. But I just can't understand why so many people hold this algorithm in high regard. I am fascinated by the problem it solves, yes, but when you start using and reading about it, and get to know the basics of the alternatives to RANSAC, you should eventually start feeling like it is the bubble-sort of robust estimation… And it's just such an ugly, inelegant random search. It's not even like the Hough transform, which is pretty cool even though the only version most people learn is also not very sophisticated. My favorite method is perhaps RUDR, maybe because it actually mixes the generate-and-test strategy from RANSAC with the parameter space search from the Hough transform; many people don't even quite seem to realize these two algorithms are actually two approaches to the same problem. At least I had to take my time to really grok this. And finally I must confess my ignorance of the branch-and-bound methods, which now seem to be the really good ones for these problems.

As for the Harris detector, don't even get me started… First of all, there is a common misconception that it is used to detect x-corners in checkerboard patterns for calibration. This is entirely false: the Harris detector is completely blind to regions with an indefinite Hessian. Furthermore, it will usually not give you many points that you would naturally consider to be a "corner", and it has a tendency to provide you with lots of points that you would never think should be a "corner". Frustrated by this low recall, we then tend to lower the detection threshold until we start to generate lots of edge points, and go on calling those "corners", which is completely ridiculous. But then you put some RANSAC on top of it, and it works, because if two sampled points over an edge happen to be close to the same point in space, they will be matched. Still, these inliers don't give us permission to become liars and keep calling them "corners", right? The greatest open question in computer vision research for me right now is: why don't we all just go ahead and accept edge-based vision as part of life instead of insisting on the "corner" paradigm?

But I digress. The fact is that even with so much to criticize, RANSAC, as much as the Harris detector (which was actually pioneered by Moravec), has great pedagogical value. And implementing RANSAC lets you build a whole framework on top of which you can implement better algorithms. The one redeeming quality of RANSAC is probably this: it is easy to understand and to implement, and this is precisely why it makes an interesting example for learning a new language, as we shall do right now.

I have implemented RANSAC in Scala and left the code in a GitHub repo. In the rest of this article I will go through the code, making some remarks.


Algorithm overview

RANSAC takes as input a set of data points. From this data we sample a number of minimal sets, from which geometrical models are calculated. We then go back to the data set and count the inliers of each model, and in the end we return the model with the highest inlier count.

We will start the analysis of our code in a very top-down and abstract way. The first thing I considered when implementing was how I might change the whole thing either to adapt the algorithm to different problems (e.g. line fitting, circle fitting, essential matrix estimation) or to implement modifications to the RANSAC algorithm. For that I created a small class that puts together all the information necessary to run the algorithm, and that contains a single method which performs the estimation. It doesn't do any real work itself, though: that is all done by the functions sent as arguments to the class. Here is the complete file:

class RobustEstimator[Data, Hypothesis, Model](sampler: Seq[Data] => Hypothesis,
                                               model_generator: Hypothesis => Model,
                                               inlier_detector: Model => Data => Boolean) {
  def estimate(data: Seq[Data], iterations: Int): Model = {
    val minimal_sets = Seq.fill(iterations)(sampler(data))      // one minimal set per iteration
    val hypothetical_models = minimal_sets map model_generator  // one model per minimal set
    hypothetical_models.maxBy(m => data count inlier_detector(m)) // keep the model with most inliers
  }
}
The class takes three functions as parameters. The first one samples "minimal sets" from a data set. I've called the type of such a minimal set a Hypothesis, because we first hypothesize that these are inlier data points. The next function takes the minimal set and produces our hypothetical Model. The third function is an inlier detector: it takes a model as a parameter, then a data point, and tells us whether the point is an inlier of this model or not. Take line fitting, for example. The first function just samples pairs of points from a list of points. The second one produces a line from a pair of points. The third one measures the distance from a point to a line.

Putting these functions together to implement RANSAC is pretty straightforward. The first line of the estimate method produces all the hypotheses we need, their number controlled by the iterations parameter. Then we produce a model for each sampled minimal set. The last line is the tricky one: it employs the very useful maxBy method to go through all the models and pick the one with the largest number of inliers. The inlier count is performed by the count method, which uses the inlier detector function.

If we were writing this in a very procedural, non-functional way, the first line would be something like a for loop that appends new minimal sets to a list in each iteration. The second would iterate over that list and append models to a new one. Then we might iterate over it, calculating the inlier count and storing the current best guess and best count in a couple of variables… But reading the code here, loop-free, makes it very clear how the problem can easily be parallelized, for instance, or how you might use lazy evaluation and avoid keeping all these lists in memory. A "for" can be anything in principle; it is not obvious whether you can or cannot parallelize it. A "for" is still almost as generic and non-descriptive as a "goto". But once you write it as a "map", it becomes quite obvious what is happening in the code.

This is especially important in the case of the last "for". In the procedural implementation, the need to store and compare each result with a variable makes everything quite sequential. But if you express yourself in terms of operators such as maxBy, it becomes clear and explicit that we are performing a kind of "reduce" operation, which can also be parallelized. It also says something about the nature of the algorithm. These are probably the main interesting changes that come with working in Scala compared to traditional procedural languages.
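To make that concrete, here is a minimal sketch of my own (not code from the repo) of what a parallel estimate could look like inside RobustEstimator, assuming Scala 2.12, where .par ships with the standard collections:

// Hypothetical parallel variant of estimate, added inside RobustEstimator.
// .par turns the Seq of minimal sets into a ParSeq, so model generation
// and inlier counting spread across cores; maxBy acts as the final reduce.
def estimate_par(data: Seq[Data], iterations: Int): Model = {
  val minimal_sets = Seq.fill(iterations)(sampler(data))  // sampling stays sequential
  val hypothetical_models = minimal_sets.par map model_generator
  hypothetical_models.maxBy(m => data count inlier_detector(m))
}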

Now we have implemented the algorithm in principle, but we still have a lot of work to do. What we need next is to make it a little more concrete by implementing a real-life example, using those function signatures to guide our work. This is done in the next file:


import estimation.RobustEstimator
import geometry.{Line, Point}

import scala.math.abs
import scala.util.Random

object TestRANSAC extends App {

  val n_outliers = 100
  val n_inliers = 100
  val aux = Line(Point.nextGaussian)
  val original_model = if (abs(aux.x) < abs(aux.y)) aux else Line(aux.y, aux.x)

  val sigma = 0.2
  val data = generate_data(n_outliers, n_inliers, original_model, sigma)

  val ransac = new RobustEstimator(pick_point_pair, get_line_from_point_pair, test_point_closer_than(3 * sigma))

  val iterations = 10
  val estimated_model = ransac.estimate(data, iterations)

  println(s"Original model $original_model")
  println(s"Estimated model: $estimated_model")

  def pick_point_pair(data: Seq[Point]): (Point, Point) = {
    val List(p1, p2) = sample_without_replacement(2, data)
    (p1, p2)
  }

  def get_line_from_point_pair(point_pair: (Point, Point)): Line = {
    val (p1, p2) = point_pair
    val delta = p2 - p1
    val x = delta.y * (p1 cross p2) / delta.sqnorm
    val y = -delta.x * (p1 cross p2) / delta.sqnorm
    Line(x, y)
  }

  def test_point_closer_than(threshold: Double)(l: Line)(p: Point): Boolean = {
    abs(l distance p) < threshold
  }

  def sample_without_replacement[D](N: Int, data: Seq[D], sample: List[D] = List.empty): List[D] = {
    if (N <= 0) sample else {
      val el = Random.nextInt(data.size)
      val remaining_data = data.take(el) ++ data.drop(el + 1)
      sample_without_replacement(N - 1, remaining_data, sample :+ data(el))
    }
  }

  def generate_data(n_outliers: Int, n_inliers: Int, line: Line, noise: Double) = {
    val r = 10.0 // Size of test space

    val outliers = List.fill(n_outliers) { (Point.nextUniform - Point(0.5, 0.5)) * r }
    val inliers = List.fill(n_inliers) {
      val x = (Random.nextDouble - 0.5) * r / 2
      Point(x, line(x)) + Point.nextGaussian * noise
    }

    outliers ++ inliers
  }
}

Our test starts by generating some data: random points spread over a plane plus points along a line, all put together by the generate_data function. We then create an estimator object, passing it the functions written to solve our problem, apply the estimation method to the generated data, and print the result.

The three necessary functions were named to be pretty explicit about what they do: pick a pair of points, make a line out of a pair of points, and test the distance from a point to a line.

This is the part of the code where we start to get a little more low-level… The first function relies on a helper that randomly samples from a list without replacement. I've used a recursive function for that, which progresses by appending sampled points to a list while removing them from the input list… Notice that the way we do this in Scala is to just assume our immutable data structures will perform well, so we go ahead and build a "new" remaining list to continue the recursion. Supposedly this "new" list will reuse a lot of the existing data, and not really occupy much memory or take much processing time as the program runs. But conceptually, it is just like a completely new list. And immutable…

Now compare this to what would be a typical procedural implementation: you might count on mutability, and move each sample to the end of the list, putting the element from the end in the place you picked the sample from, then continue with the second-to-last position, and finally return a copy of those last elements. It looks clever, but from the FP point of view it is actually pretty gross… It is so much more elegant to write the immutable code. It says what we want to do in a very explicit and formal way; there are no hacks moving data around to special parts of a list. It makes C feel like assembly programming.
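For contrast, here is a sketch of that mutable approach (my own reconstruction, not code from the repo): a partial Fisher–Yates shuffle that swaps each pick to the end of a working buffer.

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

// Hypothetical mutable counterpart to sample_without_replacement: swap
// each sampled element to the end of the buffer, shrink the unsampled
// prefix, and return the accumulated tail at the end.
def sample_mutable[D](n: Int, data: Seq[D]): List[D] = {
  val buf = ArrayBuffer(data: _*)
  var end = buf.length                    // buf(end..) holds the picks made so far
  for (_ <- 0 until n) {
    val i = Random.nextInt(end)           // choose among the unsampled prefix
    end -= 1
    val tmp = buf(i); buf(i) = buf(end); buf(end) = tmp
  }
  buf.drop(end).toList
}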

The function that generates a model is only complicated because of the geometrical calculations in it. Notice that we represent a line by its point closest to the origin, which can be quite useful in such problems. This shows up in the implementation of the distance method, which is called in the third function of this solution, only to be thresholded to give us the Boolean result for the detector.
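For instance (a quick check of my own, assuming the Line and Point classes defined below are in scope): Line(0, 1) is the horizontal line through (0, 1), and the signed distance falls out of the projection onto that closest point.

val l = Line(0, 1)                // closest point to the origin is (0, 1), i.e. the line y = 1
println(l distance Point(3, 4))   // 3.0: the point sits 3 units past the line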

All that is left in the code is the implementation of the Point and Line classes, which are "case classes", one of Scala's strongest tools. Just the fact that it automatically implements a nice printing method for the class makes it very interesting to use! But it's much more than that…

import scala.math.sqrt
import scala.util.Random

trait Vec2d {
  def x: Double

  def y: Double

  def -(that: Point): Point = Point(this.x - that.x, this.y - that.y)

  def +(that: Point): Point = Point(this.x + that.x, this.y + that.y)

  def *(that: Point): Double = this.x * that.x + this.y * that.y

  def cross(that: Point): Double = this.x * that.y - this.y * that.x

  def *(that: Double): Point = Point(this.x * that, this.y * that)

  def /(that: Double): Point = Point(this.x / that, this.y / that)

  def sqnorm = x * x + y * y

  def norm = sqrt(sqnorm)
}


case class Point(override val x: Double, override val y: Double) extends Vec2d

object Point {
  def nextGaussian = Point(Random.nextGaussian, Random.nextGaussian)

  def nextUniform = Point(Random.nextDouble, Random.nextDouble)
}

case class Line(override val x: Double, override val y: Double) extends Vec2d {
  def distance(p: Point) = {
    val lp = Point(x, y)
    val proj = lp * (lp * p) / lp.sqnorm
    (proj - lp) * lp / lp.norm
  }

  def apply(ix: Double) = {
    y + x * (x - ix) / y
  }
}

object Line {
  def apply(p: Point): Line = Line(p.x, p.y)
}

Another interesting thing to notice in the Vec2d trait is how easily we can define operators. There are really no "operators" in Scala: it's all functions, methods. You can call any method that takes a single argument in infix position, as if it were an operator, and you can use almost any symbol in method names, including Unicode. Scala does take care of operator precedence, but apart from that all methods are very much alike.
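A toy illustration of the idea (my example, not from the repo):

// Hypothetical example: symbolic method names and infix calls.
case class Meters(v: Double) {
  def +(that: Meters): Meters = Meters(this.v + that.v)  // "operator" is just a method
  def within(that: Meters): Boolean = this.v <= that.v   // plain name, also callable infix
}

val d = Meters(2) + Meters(3)   // symbolic method, looks like a built-in operator
val ok = d within Meters(10)    // ordinary method written infix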


Some output

We conclude with a graphic I made using Breeze. This is a great library for scientific computing that enables you to do all kinds of things you would usually do with Matlab, Numpy or R. I haven't used it much yet, but it looks terrific. The plotting part still lacks a lot compared to matplotlib, but I trust it will become great in the near future.
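I haven't reproduced the exact plotting code here, but a minimal breeze-viz sketch along these lines produces that kind of figure (assuming breeze-viz is on the classpath and the values from TestRANSAC are in scope):

import breeze.plot._

// Hypothetical sketch of the figure: the noisy points, plus the
// simulated and estimated lines drawn from two endpoints each.
val fig = Figure()
val sp = fig.subplot(0)
sp += plot(data.map(_.x), data.map(_.y), '.')   // data points
val lx = Seq(-5.0, 5.0)
sp += plot(lx, lx.map(original_model(_)))       // simulated line
sp += plot(lx, lx.map(estimated_model(_)))      // estimated line
fig.saveas("ransac.png")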


The blue line shows the simulated model, and the red line the estimated one. As we can see, it is pretty close, although the noise in the inliers pulls the estimate slightly away. Our program output in that case was:

Original model Line(0.4374995149274083,-1.0920011922871506)
Estimated model: Line(0.470293650646997,-0.9313502646150857)

To get better and faster results, we might extend the algorithm to run an optimization after each model is created, before we count the inliers. We can also restrict the hypotheses we generate using some kind of heuristic, and the tests can avoid using all the available data points. Some of these changes would require modifications to the RobustEstimator class; others, only changes to the functions we provide to it. I think seeing where these changes happen, and seeing the function signatures, helps a lot in understanding the differences between algorithms. These changes may be quite important, but with lower-level procedural languages it is not so easy to see that kind of thing.
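As a sketch of the first of those changes (my own variant, not in the repo), the estimator could take one extra function that refines each model over its inliers before the final count:

// Hypothetical variant of RobustEstimator with a refinement hook: each
// hypothetical model is optimized over its inliers (e.g. by least
// squares) before the inlier count picks a winner.
class RefiningEstimator[Data, Hypothesis, Model](sampler: Seq[Data] => Hypothesis,
                                                 model_generator: Hypothesis => Model,
                                                 inlier_detector: Model => Data => Boolean,
                                                 refine: (Model, Seq[Data]) => Model) {
  def estimate(data: Seq[Data], iterations: Int): Model = {
    val models = Seq.fill(iterations)(sampler(data)) map model_generator
    val refined = models.map(m => refine(m, data filter inlier_detector(m)))
    refined.maxBy(m => data count inlier_detector(m))
  }
}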

Programming in Scala really helps me see these things better. I not only understand things better theoretically, but I also have better insights about performance. I am usually more productive, and just plain happier with my results, after I see the code I wrote. That is what I wish for any programmer, and I can't recommend Scala enough. It can even make a dreadful technique such as RANSAC look good!

This article was written by Nicolau Werneck and originally posted on xor0110.wordpress.com