Least Squares Regression for Quadratic Curve Fitting

9 Mar 2010 | CPOL | 2 min read
A C# class for Least Squares Regression for Quadratic Curve Fitting.

Introduction

A recent software project had a requirement to derive the equation of a quadratic curve from a series of data points. That is to say, to determine a, b, and c, where y = ax^2 + bx + c. Having determined a, b, and c, I would also need a value for R-squared (the coefficient of determination).

A quick search of Google failed to bring up a suitable C# class; quite possibly, a more thorough search would have done so.

I had a vague recollection of something called 'Least Squares Regression', so back to Google I went.

Using the class

Declare an instance:

C#
LstSquQuadRegr solvr = new LstSquQuadRegr();

Pass it at least 3 point pairs; with fewer, the methods below throw an InvalidOperationException:

C#
solvr.AddPoints(x1, y1);
solvr.AddPoints(x2, y2);
solvr.AddPoints(x3, y3);
solvr.AddPoints(x4, y4);

Get the values:

C#
double the_a_term = solvr.aTerm();  
double the_b_term = solvr.bTerm();
double the_c_term = solvr.cTerm();
double the_rSquare = solvr.rSquare();

The Theory

y = ax^2 + bx + c

We have a series of n points (xi,yi), for i = 1 to n: (x1,y1), (x2,y2) ... (xn,yn).

We want the values of a, b, and c that minimise the sum of squares of the deviations of yi from a*xi^2 + b*xi + c. Such values will give the best-fitting quadratic equation.

Let the sum of the squares of the deviations be:

               n
   F(a,b,c) = SUM (a*xi^2 + b*xi + c - yi)^2.
              i=1

F is at its minimum where its partial derivatives with respect to a, b, and c are all zero:

dF/da = SUM 2*(a*xi^2+b*xi+c-yi)*xi^2 = 0,
dF/db = SUM 2*(a*xi^2+b*xi+c-yi)*xi = 0,
dF/dc = SUM 2*(a*xi^2+b*xi+c-yi) = 0.

(Here, all sums range over i = 1, 2, ..., n.) Dividing by 2 and rearranging gives these three simultaneous linear equations containing the three unknowns a, b, and c:

(SUM xi^4)*a + (SUM xi^3)*b + (SUM xi^2)*c = SUM xi^2*yi,
(SUM xi^3)*a + (SUM xi^2)*b +   (SUM xi)*c = SUM xi*yi,
(SUM xi^2)*a +   (SUM xi)*b +    (SUM 1)*c = SUM yi.

Using notation Sjk to mean the sum of x_i^j*y_i^k:

a*S40 + b*S30 + c*S20 = S21
a*S30 + b*S20 + c*S10 = S11
a*S20 + b*S10 + c*S00 = S01
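
As a rough illustration of the Sjk notation, the sketch below computes any Sjk with a single generic loop. The names PowerSums and S are hypothetical and are not part of the LstSquQuadRegr class in this article, which uses one helper method per sum instead. It assumes C# 7 or later for the tuple syntax.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper for illustration only; not part of LstSquQuadRegr.
public static class PowerSums
{
    // Returns Sjk, the sum over all points of x^j * y^k.
    // With j = 0 and k = 0 every term is 1, so S00 is the number of points.
    public static double S(IEnumerable<(double X, double Y)> points, int j, int k)
    {
        double total = 0.0;
        foreach (var (x, y) in points)
        {
            total += Math.Pow(x, j) * Math.Pow(y, k);
        }
        return total;
    }
}
```

For the points (1,2), (2,3), (3,5), S(points, 4, 0) is 1 + 16 + 81 = 98 and S(points, 2, 1) is 2 + 12 + 45 = 59.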

Solve the simultaneous equations using Cramer's rule:

  [ S40  S30  S20 ] [ a ]   [ S21 ]
  [ S30  S20  S10 ] [ b ] = [ S11 ]
  [ S20  S10  S00 ] [ c ]   [ S01 ] 

  D = [ S40  S30  S20 ] 
      [ S30  S20  S10 ] 
      [ S20  S10  S00 ]  
  
    = S40(S20*S00 - S10*S10) - S30(S30*S00 - S10*S20) + S20(S30*S10 - S20*S20)

 Da = [ S21  S30  S20 ]
      [ S11  S20  S10 ] 
      [ S01  S10  S00 ]  

    = S21(S20*S00 - S10*S10) - S11(S30*S00 - S10*S20) + S01(S30*S10 - S20*S20)

 Db = [ S40  S21  S20 ] 
      [ S30  S11  S10 ] 
      [ S20  S01  S00 ]  

    = S40(S11*S00 - S01*S10) - S30(S21*S00 - S01*S20) + S20(S21*S10 - S11*S20)
  
 Dc = [ S40  S30  S21 ] 
      [ S30  S20  S11 ] 
      [ S20  S10  S01 ] 
  
    = S40(S20*S01 - S10*S11) - S30(S30*S01 - S10*S21) + S20(S30*S11 - S20*S21)  

a = Da/D
b = Db/D
c = Dc/D
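
The determinant expansions above can be sketched in a more generic form as follows. This is a minimal illustration, not the code used by the class in this article: the names Cramer3, Det, and Solve are hypothetical, and the check for a zero determinant is an addition of mine (the original class does not make one).

```csharp
using System;

// Hypothetical helper for illustration only; not part of LstSquQuadRegr.
public static class Cramer3
{
    // Determinant of a 3x3 matrix, expanded along the first row.
    public static double Det(double[,] m)
    {
        return m[0, 0] * (m[1, 1] * m[2, 2] - m[1, 2] * m[2, 1])
             - m[0, 1] * (m[1, 0] * m[2, 2] - m[1, 2] * m[2, 0])
             + m[0, 2] * (m[1, 0] * m[2, 1] - m[1, 1] * m[2, 0]);
    }

    // Solves m * [a, b, c]^T = rhs by Cramer's rule and returns { a, b, c }.
    public static double[] Solve(double[,] m, double[] rhs)
    {
        double d = Det(m);
        if (d == 0.0)
        {
            // Assumption: treat an exactly zero determinant as an error.
            // A determinant very close to zero also gives unreliable results.
            throw new InvalidOperationException("Singular system: no unique solution.");
        }

        double[] result = new double[3];
        for (int col = 0; col < 3; col++)
        {
            // Replace one column of the matrix with the right-hand side.
            double[,] replaced = (double[,])m.Clone();
            for (int row = 0; row < 3; row++)
            {
                replaced[row, col] = rhs[row];
            }
            result[col] = Det(replaced) / d;
        }
        return result;
    }
}
```

As a check, the points (0,1), (1,6), (2,15), (3,28) lie on y = 2x^2 + 3x + 1. They give S40 = 98, S30 = 36, S20 = 14, S10 = 6, S00 = 4, S21 = 318, S11 = 120, S01 = 50, and D = 80. Passing the matrix { {98,36,14}, {36,14,6}, {14,6,4} } and the right-hand side { 318, 120, 50 } to Solve should recover a = 2, b = 3, c = 1.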

R square

R^2 = 1 - (residual sum of squares / total sum of squares).

                        n
total sum of squares = SUM (yi - y_mean)^2.
                       i=1

This is the sum of the squares of the differences between the measured y values and the mean y value.

                           n
residual sum of squares = SUM (yi - yi_predicted)^2.
                          i=1

This is the sum of the squares of the differences between the measured y values and the values of y predicted by the equation.
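
The two sums can be computed together in one pass once a, b, and c are known. The sketch below is a standalone illustration under that assumption; the names GoodnessOfFit and RSquared are hypothetical, and the explicit error for a zero total sum of squares is my addition rather than something the class in this article does.

```csharp
using System;

// Hypothetical helper for illustration only; not part of LstSquQuadRegr.
public static class GoodnessOfFit
{
    // Returns R^2 for the fit y = a*x^2 + b*x + c over the given points.
    public static double RSquared(double[] xs, double[] ys, double a, double b, double c)
    {
        if (xs.Length != ys.Length || xs.Length == 0)
        {
            throw new ArgumentException("xs and ys must be non-empty and of equal length.");
        }

        // Mean of the measured y values.
        double mean = 0.0;
        foreach (double y in ys)
        {
            mean += y;
        }
        mean /= ys.Length;

        double ssTot = 0.0; // total sum of squares
        double ssRes = 0.0; // residual sum of squares
        for (int i = 0; i < xs.Length; i++)
        {
            double predicted = a * xs[i] * xs[i] + b * xs[i] + c;
            ssRes += (ys[i] - predicted) * (ys[i] - predicted);
            ssTot += (ys[i] - mean) * (ys[i] - mean);
        }

        if (ssTot == 0.0)
        {
            // Assumption: all y values equal, so R^2 is undefined.
            throw new InvalidOperationException("All y values are equal; R^2 is undefined.");
        }
        return 1.0 - ssRes / ssTot;
    }
}
```

A perfect fit gives a residual sum of zero and therefore R^2 = 1. A "fit" that always predicts the mean of y (a = 0, b = 0, c = mean) gives equal residual and total sums, and therefore R^2 = 0.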

The Code

A set of helper methods calculates the various sums. When calculating the values of a, b, and c, I used the Sjk notation above (lower-case sjk in the code), as I found it easier to keep track.

C#
/******************************************************************************
                          Class LstSquQuadRegr
     A C#  Class for Least Squares Regression for Quadratic Curve Fitting
                          Alex Etchells  2010    
******************************************************************************/

using System;
using System.Collections;


public  class LstSquQuadRegr
{
     /* instance variables */
    ArrayList pointArray = new ArrayList(); 
    private int numOfEntries; 
    private double[] pointpair;          

    /*constructor */
    public LstSquQuadRegr()
    {
        numOfEntries = 0;
        pointpair = new double[2];
    }

    /*instance methods */    
    /// <summary>
    /// add point pairs
    /// </summary>
    /// <param name="x">x value</param>
    /// <param name="y">y value</param>
    public void AddPoints(double x, double y) 
    {
        pointpair = new double[2]; 
        numOfEntries +=1; 
        pointpair[0] = x; 
        pointpair[1] = y;
        pointArray.Add(pointpair);
    }

    /// <summary>
    /// returns the a term of the equation ax^2 + bx + c
    /// </summary>
    /// <returns>a term</returns>
    public double aTerm()
    {
        if (numOfEntries < 3)
        {
            throw new InvalidOperationException(
               "Insufficient pairs of co-ordinates");
        }
        //notation sjk to mean the sum of x_i^j*y_i^k. 
        double s40 = getSx4(); //sum of x^4
        double s30 = getSx3(); //sum of x^3
        double s20 = getSx2(); //sum of x^2
        double s10 = getSx();  //sum of x
        double s00 = numOfEntries;
        //sum of x^0 * y^0  ie 1 * number of entries

        double s21 = getSx2y(); //sum of x^2*y
        double s11 = getSxy();  //sum of x*y
        double s01 = getSy();   //sum of y

        //a = Da/D
        return (s21*(s20 * s00 - s10 * s10) - 
                s11*(s30 * s00 - s10 * s20) + 
                s01*(s30 * s10 - s20 * s20))
                /
                (s40*(s20 * s00 - s10 * s10) -
                 s30*(s30 * s00 - s10 * s20) + 
                 s20*(s30 * s10 - s20 * s20));
    }

    /// <summary>
    /// returns the b term of the equation ax^2 + bx + c
    /// </summary>
    /// <returns>b term</returns>
    public double bTerm()
    {
        if (numOfEntries < 3)
        {
            throw new InvalidOperationException(
               "Insufficient pairs of co-ordinates");
        }
        //notation sjk to mean the sum of x_i^j*y_i^k.
        double s40 = getSx4(); //sum of x^4
        double s30 = getSx3(); //sum of x^3
        double s20 = getSx2(); //sum of x^2
        double s10 = getSx();  //sum of x
        double s00 = numOfEntries;
        //sum of x^0 * y^0  ie 1 * number of entries

        double s21 = getSx2y(); //sum of x^2*y
        double s11 = getSxy();  //sum of x*y
        double s01 = getSy();   //sum of y

        //b = Db/D
        return (s40*(s11 * s00 - s01 * s10) - 
                s30*(s21 * s00 - s01 * s20) + 
                s20*(s21 * s10 - s11 * s20))
                /
                (s40 * (s20 * s00 - s10 * s10) - 
                 s30 * (s30 * s00 - s10 * s20) + 
                 s20 * (s30 * s10 - s20 * s20));
    }

    /// <summary>
    /// returns the c term of the equation ax^2 + bx + c
    /// </summary>
    /// <returns>c term</returns>
    public double cTerm()
    {
        if (numOfEntries < 3)
        {
            throw new InvalidOperationException(
                       "Insufficient pairs of co-ordinates");
        }
        //notation sjk to mean the sum of x_i^j*y_i^k.
        double s40 = getSx4(); //sum of x^4
        double s30 = getSx3(); //sum of x^3
        double s20 = getSx2(); //sum of x^2
        double s10 = getSx();  //sum of x
        double s00 = numOfEntries;
        //sum of x^0 * y^0  ie 1 * number of entries

        double s21 = getSx2y(); //sum of x^2*y
        double s11 = getSxy();  //sum of x*y
        double s01 = getSy();   //sum of y

        //c = Dc/D
        return (s40*(s20 * s01 - s10 * s11) - 
                s30*(s30 * s01 - s10 * s21) + 
                s20*(s30 * s11 - s20 * s21))
                /
                (s40 * (s20 * s00 - s10 * s10) - 
                 s30 * (s30 * s00 - s10 * s20) + 
                 s20 * (s30 * s10 - s20 * s20));
    }
    
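    /// <summary>
    /// returns R-squared (the coefficient of determination) of the fit
    /// </summary>
    /// <returns>R-squared; not a finite number if all y values are equal</returns>
    public double rSquare() // get r-squared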
    {
        if (numOfEntries < 3)
        {
            throw new InvalidOperationException(
               "Insufficient pairs of co-ordinates");
        }
        // 1 - (residual sum of squares / total sum of squares)
        return 1 - getSSerr() / getSStot();
    }
   

    /*helper methods*/
    private double getSx() // get sum of x
    {
        double Sx = 0;
        foreach (double[] ppair in pointArray)
        {
            Sx += ppair[0];
        }
        return Sx;
    }

    private double getSy() // get sum of y
    {
        double Sy = 0;
        foreach (double[] ppair in pointArray)
        {
            Sy += ppair[1];
        }
        return Sy;
    }

    private double getSx2() // get sum of x^2
    {
        double Sx2 = 0;
        foreach (double[] ppair in pointArray)
        {
            Sx2 += Math.Pow(ppair[0], 2); // sum of x^2
        }
        return Sx2;
    }

    private double getSx3() // get sum of x^3
    {
        double Sx3 = 0;
        foreach (double[] ppair in pointArray)
        {
            Sx3 += Math.Pow(ppair[0], 3); // sum of x^3
        }
        return Sx3;
    }

    private double getSx4() // get sum of x^4
    {
        double Sx4 = 0;
        foreach (double[] ppair in pointArray)
        {
            Sx4 += Math.Pow(ppair[0], 4); // sum of x^4
        }
        return Sx4;
    }

    private double getSxy() // get sum of x*y
    {
        double Sxy = 0;
        foreach (double[] ppair in pointArray)
        {
            Sxy += ppair[0] * ppair[1]; // sum of x*y
        }
        return Sxy;
    }

    private double getSx2y() // get sum of x^2*y
    {
        double Sx2y = 0;
        foreach (double[] ppair in pointArray)
        {
            Sx2y += Math.Pow(ppair[0], 2) * ppair[1]; // sum of x^2*y
        }
        return Sx2y;
    }

    private double getYMean() // mean value of y
    {
        double y_tot = 0;
        foreach (double[] ppair in pointArray)
        {
            y_tot += ppair[1]; 
        }
        return y_tot/numOfEntries;
    }

    private double getSStot() // total sum of squares
    {
        //the sum of the squares of the differences between 
        //the measured y values and the mean y value
        double ss_tot = 0;
        double y_mean = getYMean(); //compute the mean once, not on every iteration
        foreach (double[] ppair in pointArray)
        {
            ss_tot += Math.Pow(ppair[1] - y_mean, 2);
        }
        return ss_tot;
    }

    private double getSSerr() // residual sum of squares
    {
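        //the sum of the squares of the differences between 
        //the measured y values and the values of y predicted by the equation
        double ss_err = 0;
        //solve for the coefficients once, not once per point
        double a = aTerm();
        double b = bTerm();
        double c = cTerm();
        foreach (double[] ppair in pointArray)
        {
            double predicted = a * Math.Pow(ppair[0], 2) + b * ppair[0] + c;
            ss_err += Math.Pow(ppair[1] - predicted, 2);
        }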
        return ss_err;
    }

    private double getPredictedY(double x)
    {
        //returns value of y predicted by the equation for a given value of x
        return aTerm() * Math.Pow(x, 2) + bTerm() * x + cTerm();
    }
}

Points of Interest

That's it really - it seems to agree pretty closely with the values given by Excel. I hope it saves someone else the bother of working it out.

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)


Written By
Software Developer University Of East Anglia, UK
United Kingdom