Posted 3 Sep 2019

Solving Iris classification using XGBoost and C#



Introduction

[Image: Iris flower. Source: Wikipedia]

In this article I demonstrate how to use the C# wrapper of the popular XGBoost unmanaged library. XGBoost stands for "Extreme Gradient Boosting". I have used the famous IRIS dataset to train and test a model. My objective is to share what I learned about embedding a machine learning algorithm such as extreme gradient boosting in a C# application. Before moving forward I must extend my gratitude to the developers of the XGBoost unmanaged library and of the .NET wrapper library.


Background

This article assumes intermediate knowledge of the following:

  • Decision tree algorithm
  • Gradient boosting algorithm
  • Data normalization
  • C#

This article and the accompanying code refrain from providing an in-depth tutorial on decision trees and gradient boosting. Instead, I have provided links to YouTube training videos which, in my opinion, are of immense educational value.


Overview of Gradient Boost Classification algorithm

Intro to decision trees (StatQuest)


Understanding Gini index while constructing a decision tree
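For quick reference: the Gini index of a tree node is 1 minus the sum of the squared class proportions in that node; a pure node scores 0 and an even two-class split scores 0.5. A minimal stdlib-Python sketch of the formula (illustrative only, not the criterion XGBoost itself optimizes):

```python
from collections import Counter

def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    n = len(labels)
    counts = Counter(labels)
    return 1.0 - sum((c / n) ** 2 for c in counts.values())

# A pure node has impurity 0; a 50/50 two-class node has 0.5
print(gini([0, 0, 0, 0]))  # 0.0
print(gini([0, 0, 1, 1]))  # 0.5
```

A decision-tree learner picks the split whose child nodes have the lowest weighted average impurity.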


Intro to AdaBoost


Intro to Gradient Boost
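The core idea behind gradient boosting for squared error is simple: each new weak learner is fitted to the residuals of the current ensemble, and its (shrunk) prediction is added to the running total. A toy stdlib-Python sketch using the simplest possible base learner, a constant equal to the mean residual (real implementations such as XGBoost fit regression trees instead):

```python
def boost(y, rounds=10, lr=0.5):
    """Each round, 'fit' a constant learner (the mean residual)
    and add it to the ensemble prediction, shrunk by the learning rate."""
    pred = [0.0] * len(y)
    for _ in range(rounds):
        residuals = [yi - pi for yi, pi in zip(y, pred)]
        step = sum(residuals) / len(residuals)  # the constant base learner
        pred = [pi + lr * step for pi in pred]
    return pred

# With a constant learner the predictions converge to the mean of y
y = [1.0, 2.0, 3.0]
pred = boost(y)
```

With tree base learners the same loop drives the model toward the training targets per-region rather than globally, which is what gives boosted trees their power.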


XGBoost library (C#)

Managed wrapper

The C/C++ source code for the original XGBoost library is available on GitHub, along with build instructions for Windows. Thanks to the efforts of PicNet, we can skip compiling the unmanaged sources and jump directly to the managed wrapper.


Simple linear classification problem

Image 2

We will carry out a simple exercise: train a model to classify 2 clusters of points that are nicely linearly separable.

C#
/// <summary>
/// Two classes of vectors - Class-Blue and Class-Red
/// Class-Blue  - The vectors are centered around the point (+0.5,+0.5) and label value=1
/// Class-Red   - The vectors are centered around the point (-0.5,-0.5) and label value=0
/// </summary>
[TestMethod]
public void LinearClassification1()
{
    var xgb = new XGBoost.XGBClassifier();
    float[][] vectorsTrain = new float[][]
    {
        new[] {0.5f,0.5f},
        new[] {0.6f,0.6f},
        new[] {0.6f,0.4f},
        new[] {0.4f,0.6f},
        new[] {0.4f,0.4f},

        new[] {-0.5f,-0.5f},
        new[] {-0.6f,-0.6f},
        new[] {-0.6f,-0.4f},
        new[] {-0.4f,-0.6f},
        new[] {-0.4f,-0.4f},
    };
    var labelsTrain = new[]
    {
        1.0f,
        1.0f,
        1.0f,
        1.0f,
        1.0f,

        0.0f,
        0.0f,
        0.0f,
        0.0f,
        0.0f,
    };
    ///
    /// Ensure count of training labels = count of training vectors
    ///
    Assert.AreEqual(vectorsTrain.Length, labelsTrain.Length);
    ///
    /// Train the model
    ///
    xgb.Fit(vectorsTrain, labelsTrain);
    ///
    /// Test the model using test vectors
    ///
    float[][] vectorsTest = new float[][]
    {
        new[] {0.55f,0.55f},
        new[] {0.55f,0.45f},
        new[] {0.45f,0.55f},
        new[] {0.45f,0.45f},

        new[] {-0.55f,-0.55f},
        new[] {-0.55f,-0.45f},
        new[] {-0.45f,-0.55f},
        new[] {-0.45f,-0.45f},
    };
    var labelsTestExpected = new[]
    {
        1.0f,
        1.0f,
        1.0f,
        1.0f,

        0.0f,
        0.0f,
        0.0f,
        0.0f,
    };
    float[] labelsTestPredicted = xgb.Predict(vectorsTest);
    ///
    /// Verify that predicted labels match the expected labels
    ///
    CollectionAssert.AreEqual(labelsTestExpected, labelsTestPredicted);
}


Implementing XOR logic

The XOR logic is more complex than linear classification. The data points are not linearly separable.
Image 3

XOR Truth table

X | Y | OUTPUT
--------------
1 | 0 |   1
--------------
0 | 1 |   1
--------------
0 | 0 |   0
--------------
1 | 1 |   0
--------------
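The truth table above can be checked for linear separability by brute force: no line w1*x + w2*y + b puts (0,1) and (1,0) on one side and (0,0) and (1,1) on the other. A small stdlib-Python sketch searching a coarse grid of candidate lines (a grid search only illustrates the point; the full proof is a short algebraic argument):

```python
import itertools

XOR = {(0, 0): 0, (0, 1): 1, (1, 0): 1, (1, 1): 0}

def separates(w1, w2, b):
    """True if sign(w1*x + w2*y + b) reproduces XOR on all four points."""
    return all((w1 * x + w2 * y + b > 0) == bool(out)
               for (x, y), out in XOR.items())

grid = [i / 4 for i in range(-8, 9)]  # candidate coefficients -2.0 .. 2.0
found = any(separates(w1, w2, b)
            for w1, w2, b in itertools.product(grid, repeat=3))
print(found)  # False: no linear boundary classifies XOR correctly
```

This is exactly why XOR needs a non-linear model such as a boosted ensemble of trees.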

Sample code

C#
[TestMethod]
public void TestMethod1()
{
    var xgb = new XGBoost.XGBClassifier();
    ///
    /// Generate training vectors
    ///
    int countTrainingPoints = 50;
    entity.XGBArray trainClass_0_1 = Util.GenerateRandom2dPoints(countTrainingPoints / 2,
        0.0, 0.5,
        0.5, 1.0, 1.0);//0,1
    entity.XGBArray trainClass_1_0 = Util.GenerateRandom2dPoints(countTrainingPoints / 2,
        0.5, 1.0,
        0.0, 0.5, 1.0);//1,0
    entity.XGBArray trainClass_0_0 = Util.GenerateRandom2dPoints(countTrainingPoints / 2,
        0.0, 0.5,
        0.0, 0.5, 0.0);//0,0
    entity.XGBArray trainClass_1_1 = Util.GenerateRandom2dPoints(countTrainingPoints / 2,
        0.5, 1.0,
        0.5, 1.0, 0.0);//1,1
    ///
    /// Train the model
    ///
    entity.XGBArray allVectorsTraining = Util.UnionOfXGBArrays(trainClass_0_1,trainClass_1_0,trainClass_0_0,trainClass_1_1);
    xgb.Fit(allVectorsTraining.Vectors, allVectorsTraining.Labels);
    ///
    /// Test the model
    ///
    int countTestingPoints = 10;
    entity.XGBArray testClass_0_1 = Util.GenerateRandom2dPoints(countTestingPoints ,
        0.1, 0.4,
        0.6, 0.9, 1.0);//0,1
    entity.XGBArray testClass_1_0 = Util.GenerateRandom2dPoints(countTestingPoints,
        0.6, 0.9,
        0.1, 0.4, 1.0);//1,0
    entity.XGBArray testClass_0_0 = Util.GenerateRandom2dPoints(countTestingPoints,
        0.1, 0.4,
        0.1, 0.4, 0.0);//0,0
    entity.XGBArray testClass_1_1 = Util.GenerateRandom2dPoints(countTestingPoints,
        0.6, 0.9,
        0.6, 0.9, 0.0);//1,1
    entity.XGBArray allVectorsTest = Util.UnionOfXGBArrays(testClass_0_1, testClass_1_0,testClass_0_0,testClass_1_1);
    var resultsActual = xgb.Predict(allVectorsTest.Vectors);
    CollectionAssert.AreEqual(allVectorsTest.Labels, resultsActual);

}


Persisting a model to file

Once a model has been trained and found to produce satisfactory results, you will want to use it in production. The method SaveModelToFile persists the model to a binary file. The static method LoadClassifierFromFile rehydrates the saved model.

C#
var xgbTrainer = new XGBoost.XGBClassifier();
///
/// Train the model (Fit), then persist it
///
xgbTrainer.SaveModelToFile("SimpleLinearClassifier.dat");
///
/// Load the persisted model
///
var xgbProduction = XGBoost.XGBClassifier.LoadClassifierFromFile("SimpleLinearClassifier.dat");

Iris dataset

Overview

[Images: the three Iris species. Source: Wikipedia]
The data set contains 50 records from each of three species of the Iris flower and is a standard test case for many statistical classification techniques. Each record holds four measurements (sepal length, sepal width, petal length and petal width, in cm) plus the species name. The three species are:

  1. Iris-setosa
  2. Iris-versicolor
  3. Iris-virginica


Data structure

[Image: structure of the Iris data set. Source: Wikipedia]


Parsing IRIS records from CSV

C#
///
///The C# class Iris will be used for capturing a single data row
///
public class Iris
{
    public float Col1 { get; set; }
    public float Col2 { get; set; }
    public float Col3 { get; set; }
    public float Col4 { get; set; }
    public string Petal { get; set; }
}
///
///The function LoadIris will read the specified file line by line and create an instance of the Iris POCO
///The class TextFieldParser from the assembly Microsoft.VisualBasic is being used here
///
private Iris[] LoadIris(string filename)
{
    string pathFull = System.IO.Path.Combine(Util.GetProjectDir2(), filename);
    List<Iris> records = new List<Iris>();
    using (var parser = new TextFieldParser(pathFull))
    {
        parser.TextFieldType = FieldType.Delimited;
        parser.SetDelimiters(",");
        while (!parser.EndOfData)
        {
            var fields = parser.ReadFields();
            Iris oRecord = new Iris();
            oRecord.Col1 = float.Parse(fields[0]);
            oRecord.Col2 = float.Parse(fields[1]);
            oRecord.Col3 = float.Parse(fields[2]);
            oRecord.Col4 = float.Parse(fields[3]);
            oRecord.Petal = fields[4];
            records.Add(oRecord);
        }
    }
    return records.ToArray();
}


Creating a feature vector from CSV

C#
/// <summary>
/// Create XGBoost consumable feature vector from Iris POCO classes
/// </summary>
internal static XGVector<Iris>[] ConvertFromIrisToFeatureVectors(Iris[] records)
{
    List<XGVector<Iris>> vectors = new List<XGVector<Iris>>();
    foreach (var rec in records)
    {
        XGVector<Iris> newVector = new XGVector<Iris>();
        newVector.Original = rec;
        newVector.Features = new float[]
        {
            rec.Col1, rec.Col2,rec.Col3,rec.Col4
        };
        newVector.Label = ConvertLabelFromStringToNumeric(rec.Petal);
        vectors.Add(newVector);
    }
    return vectors.ToArray();
}


/// <summary>
/// Converts the string based name of the petal to a numeric representation
/// </summary>
internal static float ConvertLabelFromStringToNumeric(string petal)
{
    if (petal.Contains("setosa"))
    {
        return 0.0f;
    }
    else if (petal.Contains("versicolor"))
    {
        return 1.0f;
    }
    else if (petal.Contains("virginica"))
    {
        return 2.0f;
    }
    else
    {
        throw new NotImplementedException();
    }
}


Loading IRIS - putting it all together

C#
[TestMethod]
public void BasicLoadData()
{
    string filename = "Iris\\Iris.train.data";
    iris.Iris[] records = IrisUtils.LoadIris(filename);
    entity.XGVector<iris.Iris>[] vectors = IrisUtils.ConvertFromIrisToFeatureVectors(records);
    Assert.IsTrue(records.Length >= 140);
}


Training and testing IRIS

C#
[TestMethod]
public void TrainAndTestIris()
{
    ///
    /// Load training vectors
    ///
    string filenameTrain = "Iris\\Iris.train.data";
    iris.Iris[] recordsTrain = IrisUtils.LoadIris(filenameTrain);
    entity.XGVector<iris.Iris>[] vectorsTrain = IrisUtils.ConvertFromIrisToFeatureVectors(recordsTrain);
    ///
    /// Load testing vectors
    ///
    string filenameTest = "Iris\\Iris.test.data";
    iris.Iris[] recordsTest = IrisUtils.LoadIris(filenameTest);
    entity.XGVector<iris.Iris>[] vectorsTest = IrisUtils.ConvertFromIrisToFeatureVectors(recordsTest);

    int noOfClasses = 3;
    var xgbc = new XGBoost.XGBClassifier(objective: "multi:softprob", numClass: noOfClasses);
    entity.XGBArray arrTrain = Util.ConvertToXGBArray(vectorsTrain);
    entity.XGBArray arrTest = Util.ConvertToXGBArray(vectorsTest);
    xgbc.Fit(arrTrain.Vectors, arrTrain.Labels);
    var outcomeTest=xgbc.Predict(arrTest.Vectors);
    for(int index=0;index<arrTest.Vectors.Length;index++)
    {
        string sExpected = IrisUtils.ConvertLabelFromNumericToString(arrTest.Labels[index]);
        float[] arrResults = new float[]
        {
            outcomeTest[index*noOfClasses +0],
            outcomeTest[index*noOfClasses +1],
            outcomeTest[index*noOfClasses +2]
        };
        float max = arrResults.Max();
        int indexWithMaxValue = Util.GetIndexWithMaxValue(arrResults);
        string sActualClass = IrisUtils.ConvertLabelFromNumericToString((float)indexWithMaxValue);
        Trace.WriteLine($"{index}       Expected={sExpected}        Actual={sActualClass}");
        Assert.AreEqual(sActualClass, sExpected);
    }
    string pathFull = System.IO.Path.Combine(Util.GetProjectDir2(), _fileModelIris);
    xgbc.SaveModelToFile(pathFull);
}
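With the objective multi:softprob, Predict returns one flat array containing numClass probabilities per test row, and the predicted class is the argmax of each slice, which is what the loop above computes in C#. The same decoding, sketched in stdlib Python for clarity:

```python
def decode_softprob(flat, num_class):
    """Split a flat softprob array into per-row slices and take the argmax."""
    assert len(flat) % num_class == 0
    preds = []
    for i in range(0, len(flat), num_class):
        row = flat[i:i + num_class]
        preds.append(max(range(num_class), key=row.__getitem__))
    return preds

# Two rows, three classes each: argmax picks class 1, then class 2
probs = [0.1, 0.7, 0.2,   0.05, 0.15, 0.8]
print(decode_softprob(probs, 3))  # [1, 2]
```

Using multi:softmax instead would return the class index directly, but softprob keeps the per-class probabilities available for inspection.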


Using the code

Github

Solution structure

|
|-----XGBoost
|
|-----XGBoostTests
|           |
|           |---iris
|           |     |
|           |     |--Iris.data
|           |     |
|           |     |--Iris.test.data
|           |     |
|           |     |--Iris.train.data
|           |     |
|           |     |--Iris.cs
|           |     |
|           |
|           |---IrisUtils.cs
|           |
|           |---IrisUnitTest.cs
|           |
|           |---SimpleLinearClassifierTests.cs
|           |
|           |---XORClassifierTests.cs


License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)


About the Author

Sau002
Software Developer (Senior)
United Kingdom United Kingdom
I have over 22 years' experience in software development. My first job was to port C and Fortran code from UNIX to Windows NT. I have worn many hats since then: Windows Forms, web development, Windows Presentation Foundation, Silverlight, ASP.NET, SQL tuning, jQuery, Web API, SharePoint and now machine learning.

My book on Neural Network: http://amzn.eu/8G4erDQ

