Click here to Skip to main content
15,886,052 members
Articles / Artificial Intelligence / Machine Learning

Maximum Entropy Modeling Using SharpEntropy

Rate me:
Please Sign up or sign in to vote.
4.84/5 (42 votes)
9 May 2006 | 12 min read | 201.1K   6.4K   109  
Presents a Maximum Entropy modeling library, and discusses its usage, with the aid of two examples: a simple example of predicting outcomes, and an English language tokenizer.
//Copyright (C) 2005 Richard J. Northedge
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

//This file is based on the GISTrainer.java source file found in the
//original java implementation of MaxEnt.  That source file contains the following header:

// Copyright (C) 2001 Jason Baldridge and Gann Bierner
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

using System;
using System.Collections;
using System.Collections.Generic;

namespace SharpEntropy
{
	/// <summary>
	/// An implementation of Generalized Iterative Scaling.  The reference paper
	/// for this implementation was Adwait Ratnaparkhi's tech report at the
	/// University of Pennsylvania's Institute for Research in Cognitive Science,
	/// and is available at <a href ="ftp://ftp.cis.upenn.edu/pub/ircs/tr/97-08.ps.Z"><code>ftp://ftp.cis.upenn.edu/pub/ircs/tr/97-08.ps.Z</code></a>. 
	/// </summary>
	/// <author>
	/// Jason Baldridge
	/// </author>
	/// <author>
	///  Richard J. Northedge
	/// </author>
	/// <version>
	/// based on GISTrainer.java, $Revision: 1.15 $, $Date: 2004/06/14 20:52:41 $
	/// </version>
	public class GisTrainer : IO.IGisModelReader
	{
		private int mTokenCount; // number of event tokens (length of mContexts)
		private int mPredicateCount; // number of predicates
		private int mOutcomeCount; // number of outcomes
		private int mTokenID; // shared index variable used when looping over tokens
		private int mPredicateId; // shared index variable used when looping over predicates
		private int mOutcomeId; // shared index variable used when looping over outcomes
				
		// records the array of predicates seen in each event
		private int[][] mContexts;
		
		// records the outcome seen in each event
		private int[] mOutcomes;
				
		// records the number of times an event has been seen, paired to
		// int[][] mContexts
		private int[] mNumTimesEventsSeen;
		
		// stores the string names of the outcomes.  The GIS only tracks outcomes
		// as ints, and so this array is needed to save the model to disk and
		// thereby allow users to know what the outcome was in human
		// understandable terms.
		private string[] mOutcomeLabels;
		
		// stores the string names of the predicates. The GIS only tracks
		// predicates as ints, and so this array is needed to save the model to
		// disk and thereby allow users to know what the predicate was in human
		// understandable terms.
		private string[] mPredicateLabels;

		// stores the observed expectations of the features, indexed by
		// [predicate][active outcome]; observed counts are stored as logarithms
		private double[][] mObservedExpections;
		
		// stores the estimated parameter value of each predicate during iteration
		private double[][] mParameters;
		
		// stores the expected values of the features based on the current model
		private double[][] mModelExpections;
		
		// the maximum number of features fired in an event. Usually referred to as C.
		private int mMaximumFeatureCount;

		// stores the inverse of the constant, 1/C.
		private double mMaximumFeatureCountInverse;

		// the correction parameter of the model
		private double mCorrectionParameter;

		// observed expectation of the correction feature (stored as a logarithm)
		private double mCorrectionFeatureObservedExpectation;

		// a shared variable to help compute the amount by which to modify the
		// correction parameter
		private double mCorrectionFeatureModifier;
		
		// substitute for a zero correction feature sum, so that its log is defined
		private const double mNearZero = 0.01;
		// training stops once the log-likelihood gain between two iterations
		// falls below this value
		private const double mLLThreshold = 0.0001;
		
		// Stores the output of the current model on a single event during
		// training.  This will be reset for every event for every iteration.
		private double[] mModelDistribution;

		// Scratch counters filled in by Evaluate: for each outcome, the number
		// of active features contributed by the context being evaluated
		private int[] mFeatureCounts;

		// initial value for all outcomes: the log of the uniform probability 1/(number of outcomes)
		private double mInitialProbability;
		
		// the trained predicates keyed by predicate label; built by ConvertPredicates
		private Dictionary<string, PatternedPredicate> mPredicates;
		// during training: for each predicate, the ids of its active outcomes.
		// After ConvertPredicates: the distinct outcome patterns, each prefixed
		// with the number of predicates that share it.
		private int[][] mOutcomePatterns;

		#region smoothing algorithm (unused)

//		internal class UpdateParametersWithSmoothingProcedure : Trove.IIntDoubleProcedure
//		{

//			private double mdSigma = 2.0;

//			public UpdateParametersWithSmoothingProcedure(GisTrainer enclosingInstance)
//			{
//				moEnclosingInstance = enclosingInstance;
//			}
//		
//			private GisTrainer moEnclosingInstance;
//
//			public virtual bool Execute(int outcomeID, double input)
//			{
//				double x = 0.0;
//				double x0 = 0.0;
//				double tmp;
//				double f;
//				double fp;
//				for (int i = 0; i < 50; i++) 
//				{
//					// check what domain these parameters are in
//					tmp = moEnclosingInstance.maoModelExpections[moEnclosingInstance.miPredicateID][outcomeID] * System.Math.Exp(moEnclosingInstance.miConstant * x0);
//					f = tmp + (input + x0) / moEnclosingInstance.mdSigma - moEnclosingInstance.maoObservedExpections[moEnclosingInstance.miPredicateID][outcomeID];
//					fp = tmp * moEnclosingInstance.miConstant + 1 / moEnclosingInstance.mdSigma;
//					if (fp == 0) 
//					{
//						break;
//					}
//					x = x0 - f / fp;
//					if (System.Math.Abs(x - x0) < 0.000001) 
//					{
//						x0 = x;
//						break;
//					}
//					x0 = x;
//				}
//				moEnclosingInstance.maoParameters[moEnclosingInstance.miPredicateID].Put(outcomeID, input + x0);
//				return true;
//			}
//		}

		#endregion

		#region training progress event

		/// <summary>
		/// Used to provide informational messages regarding the
		/// progress of the training algorithm.
		/// </summary>
		public event TrainingProgressEventHandler TrainingProgress;

		/// <summary>
		/// Raises the <see cref="TrainingProgress"/> event, passing on a
		/// message that describes how far training has got.
		/// </summary>
		/// <param name="e">
		/// Event data carrying the progress message.
		/// </param>
		protected virtual void OnTrainingProgress(TrainingProgressEventArgs e) 
		{
			TrainingProgressEventHandler subscribers = TrainingProgress;
			if (subscribers == null)
			{
				// nobody is listening, so there is nothing to raise
				return;
			}

			subscribers(this, e);
		}

		/// <summary>
		/// Wraps a text message in event arguments and raises the
		/// training progress event with it.
		/// </summary>
		/// <param name="message">
		/// The progress message to publish.
		/// </param>
		private void NotifyProgress(string message)
		{
			TrainingProgressEventArgs progressArgs = new TrainingProgressEventArgs(message);
			OnTrainingProgress(progressArgs);
		}

		#endregion

		#region training options

		// whether unobserved predicate/outcome pairs are given a smoothing observation
		private bool mSimpleSmoothing = false;
		// whether a slack (correction) parameter is trained alongside the predicates
		private bool mUseSlackParameter = false;
		// the pseudo-count assigned to unobserved features when smoothing is on
		private double mSmoothingObservation = 0.1;

    	/// <summary>
		/// Gets or sets whether this trainer uses smoothing while training the
		/// model.  Smoothing can improve model accuracy, though training will
		/// potentially take longer and use more memory, and the model will be larger.
		/// </summary>
		/// <remarks>
		/// Initial testing indicates improvements for models built on small data sets and
		/// few outcomes, but performance degradation for those with large data
		/// sets and lots of outcomes.
		/// </remarks>
		public virtual bool Smoothing
		{
			get { return this.mSimpleSmoothing; }
			set { this.mSimpleSmoothing = value; }
		}

		/// <summary>
		/// Gets or sets whether this trainer uses a slack parameter while
		/// training the model.
		/// </summary>
		public virtual bool UseSlackParameter
		{
			get { return this.mUseSlackParameter; }
			set { this.mUseSlackParameter = value; }
		}

		/// <summary>
		/// Gets or sets the smoothing observation: when smoothing is in use,
		/// the "number" of times the trainer imagines that it saw a feature
		/// that it actually did not see.  Defaults to 0.1.
		/// </summary>
		public virtual double SmoothingObservation
		{
			get { return this.mSmoothingObservation; }
			set { this.mSmoothingObservation = value; }
		}
		
		/// <summary>
		/// Creates a new <code>GisTrainer</code> instance with smoothing and
		/// the slack parameter both switched off.
		/// </summary>
		public GisTrainer() : this(false)
		{
			// the single-argument constructor applies the remaining defaults
		}

		/// <summary>
		/// Creates a new <code>GisTrainer</code> instance without smoothing.
		/// </summary>
		/// <param name="useSlackParameter">
		/// Whether this trainer will use slack parameters while training the model.
		/// </param>
		public GisTrainer(bool useSlackParameter)
		{
			this.mUseSlackParameter = useSlackParameter;

			// smoothing stays off; the observation keeps its default value
			this.mSimpleSmoothing = false;
			this.mSmoothingObservation = 0.1;
		}

		/// <summary>
		/// Creates a new <code>GisTrainer</code> instance that uses smoothing
		/// and no slack parameter.
		/// </summary>
		/// <param name="smoothingObservation">
		/// The "number" of times the trainer should imagine that it saw a
		/// feature that it actually did not see.
		/// </param>
		public GisTrainer(double smoothingObservation) : this(false, smoothingObservation)
		{
			// the two-argument constructor switches smoothing on
		}
		
		/// <summary>
		/// Creates a new <code>GisTrainer</code> instance that uses smoothing.
		/// </summary>
		/// <param name="useSlackParameter">
		/// Whether this trainer will use slack parameters while training the model.
		/// </param>
		/// <param name="smoothingObservation">
		/// The "number" of times the trainer should imagine that it saw a
		/// feature that it actually did not see.
		/// </param>
		public GisTrainer(bool useSlackParameter, double smoothingObservation)
		{
			this.mSmoothingObservation = smoothingObservation;
			this.mUseSlackParameter = useSlackParameter;

			// supplying an observation value implies that smoothing is wanted
			this.mSimpleSmoothing = true;
		}

		#endregion

		#region alternative TrainModel signatures

		/// <summary>
		/// Trains a model with the GIS algorithm using 100 iterations and a
		/// predicate cutoff of 0.
		/// </summary>
		/// <param name="eventReader">
		/// The ITrainingEventReader holding the data on which this model
		/// will be trained.
		/// </param>
		public virtual void TrainModel(ITrainingEventReader eventReader)
		{
			const int defaultIterations = 100;
			const int defaultCutoff = 0;

			TrainModel(eventReader, defaultIterations, defaultCutoff);
		}

		/// <summary>
		/// Trains a model with the GIS algorithm, indexing the events from
		/// the reader with a <code>OnePassDataIndexer</code> first.
		/// </summary>
		/// <param name="eventReader">
		/// The ITrainingEventReader holding the data on which this model
		/// will be trained.
		/// </param>
		/// <param name="iterations">
		/// The number of GIS iterations to perform.
		/// </param>
		/// <param name="cutoff">
		/// The number of times a predicate must be seen in order
		/// to be relevant for training.
		/// </param>
		public virtual void TrainModel(ITrainingEventReader eventReader, int iterations, int cutoff)
		{
			OnePassDataIndexer indexedEvents = new OnePassDataIndexer(eventReader, cutoff);
			TrainModel(iterations, indexedEvents);
		}
		
		#endregion

		#region training algorithm

		/// <summary>
		/// Train a model using the GIS algorithm.
		/// </summary>
		/// <param name="iterations">
		/// The maximum number of GIS iterations to perform.
		/// </param>
		/// <param name="dataIndexer">
		/// The data indexer used to compress events in memory.
		/// </param>
		/// <exception cref="ArgumentNullException">
		/// <paramref name="dataIndexer"/> is null.
		/// </exception>
		/// <exception cref="ArgumentException">
		/// <paramref name="dataIndexer"/> supplies no events.
		/// </exception>
		public virtual void TrainModel(int iterations, ITrainingDataIndexer dataIndexer)
		{
			if (dataIndexer == null)
			{
				throw new ArgumentNullException("dataIndexer");
			}

			//incorporate all of the needed info
			NotifyProgress("Incorporating indexed data for training...");
			mContexts = dataIndexer.GetContexts();
			mOutcomes = dataIndexer.GetOutcomeList();
			mNumTimesEventsSeen = dataIndexer.GetNumTimesEventsSeen();
			mTokenCount = mContexts.Length;

			// without at least one event there is no correction constant to
			// derive and nothing to train on
			if (mTokenCount == 0)
			{
				throw new ArgumentException("The data indexer contains no training events.", "dataIndexer");
			}
			
			// determine the correction constant and its inverse
			mMaximumFeatureCount = mContexts[0].Length;
			for (mTokenID = 1; mTokenID < mContexts.Length; mTokenID++)
			{
				if (mContexts[mTokenID].Length > mMaximumFeatureCount)
				{
					mMaximumFeatureCount = mContexts[mTokenID].Length;
				}
			}
			mMaximumFeatureCountInverse = 1.0 / mMaximumFeatureCount;
			
			NotifyProgress("done.");
			
			mOutcomeLabels = dataIndexer.GetOutcomeLabels();
			mOutcomeCount = mOutcomeLabels.Length;
			// every outcome starts from the log of the uniform probability
			mInitialProbability = System.Math.Log(1.0 / mOutcomeCount);
			
			mPredicateLabels = dataIndexer.GetPredicateLabels();
			mPredicateCount = mPredicateLabels.Length;
			
			NotifyProgress("\tNumber of Event Tokens: " + mTokenCount);
			NotifyProgress("\t    Number of Outcomes: " + mOutcomeCount);
			NotifyProgress("\t  Number of Predicates: " + mPredicateCount);
			
			// set up feature arrays: count how often each predicate was seen
			// together with each outcome
			int[][] predicateCounts = new int[mPredicateCount][];
			for (mPredicateId = 0; mPredicateId < mPredicateCount; mPredicateId++)
			{
				predicateCounts[mPredicateId] = new int[mOutcomeCount];
			}
			for (mTokenID = 0; mTokenID < mTokenCount; mTokenID++)
			{
				for (int currentContext = 0; currentContext < mContexts[mTokenID].Length; currentContext++)
				{
					predicateCounts[mContexts[mTokenID][currentContext]][mOutcomes[mTokenID]] += mNumTimesEventsSeen[mTokenID];
				}
			}
			
			dataIndexer = null; // don't need it anymore
			
			// A fake "observation" to cover features which are not detected in
			// the data.  The default is to assume that we observed "1/10th" of a
			// feature during training.  Observed expectations are held in log
			// space (see below), so the pseudo-count must be stored as a log too.
			double logSmoothingObservation = System.Math.Log(mSmoothingObservation);
			
			// Get the observed expectations of the features. Strictly speaking,
			// we should divide the counts by the number of Tokens, but because of
			// the way the model's expectations are approximated in the
			// implementation, this is cancelled out when we compute the next
			// iteration of a parameter, making the extra divisions wasteful.
			mOutcomePatterns = new int[mPredicateCount][];
			mParameters = new double[mPredicateCount][];
			mModelExpections = new double[mPredicateCount][];
			mObservedExpections = new double[mPredicateCount][];
			
			int activeOutcomeCount;
			int currentOutcome;

			for (mPredicateId = 0; mPredicateId < mPredicateCount; mPredicateId++)
			{
				// with smoothing every outcome is active for every predicate;
				// otherwise only the outcomes actually seen with the predicate
				if (mSimpleSmoothing)
				{
					activeOutcomeCount = mOutcomeCount;
				}
				else
				{
					activeOutcomeCount = 0;
					for (mOutcomeId = 0; mOutcomeId < mOutcomeCount; mOutcomeId++)
					{
						if (predicateCounts[mPredicateId][mOutcomeId] > 0)
						{
							activeOutcomeCount++;
						}
					}
				}

				mOutcomePatterns[mPredicateId] = new int[activeOutcomeCount];
				mParameters[mPredicateId] = new double[activeOutcomeCount];
				mModelExpections[mPredicateId] = new double[activeOutcomeCount];
				mObservedExpections[mPredicateId] = new double[activeOutcomeCount];

				// outcome ids are appended in ascending order, so each
				// outcome pattern is sorted
				currentOutcome = 0;
				for (mOutcomeId = 0; mOutcomeId < mOutcomeCount; mOutcomeId++)
				{
					if (predicateCounts[mPredicateId][mOutcomeId] > 0)
					{
						mOutcomePatterns[mPredicateId][currentOutcome] = mOutcomeId;
						mObservedExpections[mPredicateId][currentOutcome] = System.Math.Log(predicateCounts[mPredicateId][mOutcomeId]);
						currentOutcome++;
					}
					else if (mSimpleSmoothing)
					{
						mOutcomePatterns[mPredicateId][currentOutcome] = mOutcomeId;
						mObservedExpections[mPredicateId][currentOutcome] = logSmoothingObservation;
						currentOutcome++;
					}
				}
			}
			
			// compute the expected value of correction
			if (mUseSlackParameter) 
			{
				int correctionFeatureValueSum = 0;
				for (mTokenID = 0; mTokenID < mTokenCount; mTokenID++)
				{
					for (int currentContext = 0; currentContext < mContexts[mTokenID].Length; currentContext++)
					{
						mPredicateId = mContexts[mTokenID][currentContext];

						if ((!mSimpleSmoothing) && predicateCounts[mPredicateId][mOutcomes[mTokenID]] == 0)
						{
							correctionFeatureValueSum += mNumTimesEventsSeen[mTokenID];
						}
					}
					// events with fewer than C features are padded up to C
					correctionFeatureValueSum += (mMaximumFeatureCount - mContexts[mTokenID].Length) * mNumTimesEventsSeen[mTokenID];
				}
				if (correctionFeatureValueSum == 0)
				{
					mCorrectionFeatureObservedExpectation = System.Math.Log(mNearZero); //nearly zero so log is defined
				}
				else
				{
					mCorrectionFeatureObservedExpectation = System.Math.Log(correctionFeatureValueSum);
				}
			
				mCorrectionParameter = 0.0;
			}

			predicateCounts = null; // don't need it anymore
			
			NotifyProgress("...done.");
			
			mModelDistribution = new double[mOutcomeCount];
			mFeatureCounts = new int[mOutcomeCount];
			
			//Find the parameters
			NotifyProgress("Computing model parameters...");
			FindParameters(iterations);
			
			NotifyProgress("Converting to new predicate format...");
			ConvertPredicates();

		}
		
		/// <summary>
		/// Estimates the model parameters by running GIS iterations until the
		/// iteration limit is reached or the log-likelihood stops improving
		/// by at least the convergence threshold.
		/// </summary>
		/// <param name="iterations">
		/// Maximum number of iterations to run through.
		/// </param>
		/// <exception cref="InvalidOperationException">
		/// The log-likelihood decreased between two iterations, i.e. the
		/// model is diverging.
		/// </exception>
		private void FindParameters(int iterations)
		{
			double previousLogLikelihood = 0.0;
			NotifyProgress("Performing " + iterations + " iterations.");
			for (int currentIteration = 1; currentIteration <= iterations; currentIteration++)
			{
				// right-align the iteration number in a three character column
				NotifyProgress(currentIteration.ToString().PadLeft(3) + ":  ");

				double currentLogLikelihood = NextIteration();

				// the first iteration has nothing to be compared with
				if (currentIteration > 1)
				{
					if (previousLogLikelihood > currentLogLikelihood)
					{
						// InvalidOperationException derives from SystemException,
						// so handlers written for the latter still catch this
						throw new InvalidOperationException("Model Diverging: loglikelihood decreased");
					}
					if (currentLogLikelihood - previousLogLikelihood < mLLThreshold)
					{
						// converged: further iterations would gain too little
						break;
					}
				}
				previousLogLikelihood = currentLogLikelihood;
			}
			
			// kill a bunch of these big objects now that we don't need them
			mObservedExpections = null;
			mModelExpections = null;
			mNumTimesEventsSeen = null;
			mContexts = null;
		}
		
		/// <summary>
		/// Evaluates a context against the parameters trained so far and fills
		/// in the likelihood of each outcome given that context.
		/// </summary>
		/// <param name="context">
		/// The ids of the predicates which have been observed at the
		/// present decision point.
		/// </param>
		/// <param name="outcomeSums">
		/// Receives the normalized probabilities for the outcomes given the
		/// context. The indexes of the double[] are the outcome ids.
		/// </param>
		protected virtual void Evaluate(int[] context, double[] outcomeSums)
		{
			// every outcome starts from the initial log probability with no features counted
			for (int outcome = 0; outcome < mOutcomeCount; outcome++)
			{
				outcomeSums[outcome] = mInitialProbability;
				mFeatureCounts[outcome] = 0;
			}

			// add the scaled parameter of every (predicate, active outcome) pair in the context
			foreach (int predicate in context)
			{
				int[] pattern = mOutcomePatterns[predicate];
				for (int slot = 0; slot < pattern.Length; slot++)
				{
					int outcome = pattern[slot];
					mFeatureCounts[outcome]++;
					outcomeSums[outcome] += mMaximumFeatureCountInverse * mParameters[predicate][slot];
				}
			}

			// leave log space, applying the slack correction when it is enabled
			double total = 0.0;
			for (int outcome = 0; outcome < mOutcomeCount; outcome++)
			{
				outcomeSums[outcome] = System.Math.Exp(outcomeSums[outcome]);
				if (mUseSlackParameter) 
				{
					outcomeSums[outcome] += ((1.0 - ((double) mFeatureCounts[outcome] / mMaximumFeatureCount)) * mCorrectionParameter);
				}
				total += outcomeSums[outcome];
			}

			// normalize so that the values sum to one
			for (int outcome = 0; outcome < mOutcomeCount; outcome++)
			{
				outcomeSums[outcome] = outcomeSums[outcome] / total;
			}
		}
				
		/// <summary>
		/// Computes one iteration of GIS: accumulates the model expectations
		/// over all event tokens, updates every parameter (and the correction
		/// parameter when the slack parameter is in use) and reports progress.
		/// </summary>
		/// <returns>The log-likelihood of the training data under the model
		/// as it was before this iteration's parameter update.</returns>
		private double NextIteration()
		{
			// compute contribution of p(a|b_i) for each feature and the new
			// correction parameter
			double logLikelihood = 0.0;
			mCorrectionFeatureModifier = 0.0;
			int eventCount = 0;
			int numCorrect = 0;

			for (mTokenID = 0; mTokenID < mTokenCount; mTokenID++)
			{
				Evaluate(mContexts[mTokenID], mModelDistribution);
				for (int currentContext = 0; currentContext < mContexts[mTokenID].Length; currentContext++)
				{
					mPredicateId = mContexts[mTokenID][currentContext];
					int[] activeOutcomes = mOutcomePatterns[mPredicateId];

					for (int currentActiveOutcome = 0; currentActiveOutcome < activeOutcomes.Length; currentActiveOutcome++)
					{
						int outcomeId = activeOutcomes[currentActiveOutcome];
						mModelExpections[mPredicateId][currentActiveOutcome] += (mModelDistribution[outcomeId] * mNumTimesEventsSeen[mTokenID]);
					}

					if (mUseSlackParameter)
					{
						// The correction feature fires for every outcome that is
						// NOT active for this predicate.  Outcome patterns hold
						// their outcome ids in ascending order (that is how
						// TrainModel builds them), so a single merge pass finds
						// the inactive outcomes.
						int nextActive = 0;
						for (int candidateOutcome = 0; candidateOutcome < mOutcomeCount; candidateOutcome++)
						{
							if (nextActive < activeOutcomes.Length && activeOutcomes[nextActive] == candidateOutcome)
							{
								nextActive++;
							}
							else
							{
								mCorrectionFeatureModifier += mModelDistribution[candidateOutcome] * mNumTimesEventsSeen[mTokenID];
							}
						}
					}
				}

				if (mUseSlackParameter)
				{
					// events with fewer than C features are padded up to C
					mCorrectionFeatureModifier += (mMaximumFeatureCount - mContexts[mTokenID].Length) * mNumTimesEventsSeen[mTokenID];
				}

				logLikelihood += System.Math.Log(mModelDistribution[mOutcomes[mTokenID]]) * mNumTimesEventsSeen[mTokenID];
				eventCount += mNumTimesEventsSeen[mTokenID];
				
				//calculation solely for the information messages: does the most
				//probable outcome match the outcome recorded for this event?
				int max = 0;
				for (mOutcomeId = 1; mOutcomeId < mOutcomeCount; mOutcomeId++)
				{
					if (mModelDistribution[mOutcomeId] > mModelDistribution[max])
					{
						max = mOutcomeId;
					}
				}
				if (max == mOutcomes[mTokenID])
				{
					numCorrect += mNumTimesEventsSeen[mTokenID];
				}
			}
			NotifyProgress(".");
			
			// compute the new parameter values; both expectations are compared
			// in log space
			for (mPredicateId = 0; mPredicateId < mPredicateCount; mPredicateId++)
			{
				for (int currentActiveOutcome = 0; currentActiveOutcome < mOutcomePatterns[mPredicateId].Length; currentActiveOutcome++)
				{
					mParameters[mPredicateId][currentActiveOutcome] += (mObservedExpections[mPredicateId][currentActiveOutcome] - System.Math.Log(mModelExpections[mPredicateId][currentActiveOutcome]));
					mModelExpections[mPredicateId][currentActiveOutcome] = 0.0;// re-initialize to 0.0's
				}
			}

			// the log of a non-positive modifier is undefined, so skip the update then
			if (mCorrectionFeatureModifier > 0.0 && mUseSlackParameter)
			{
				mCorrectionParameter += (mCorrectionFeatureObservedExpectation - System.Math.Log(mCorrectionFeatureModifier));
			}

			NotifyProgress(". logLikelihood=" + logLikelihood + "\t" + ((double) numCorrect / eventCount));
			return (logLikelihood);
		}
		
		/// <summary>
		/// Re-packages the trained data into the outcome pattern / patterned
		/// predicate format used by the GIS models: predicates that share the
		/// same set of active outcomes point at a single shared outcome pattern.
		/// </summary>
		private void ConvertPredicates()
		{
			// pair every predicate label with its trained parameters
			PatternedPredicate[] patterned = new PatternedPredicate[mParameters.Length];
			for (mPredicateId = 0; mPredicateId < mPredicateCount; mPredicateId++)
			{
				patterned[mPredicateId] = new PatternedPredicate(mPredicateLabels[mPredicateId], mParameters[mPredicateId]);
			}

			// sort the predicates by outcome pattern so that equal patterns end up adjacent
			OutcomePatternComparer patternComparer = new OutcomePatternComparer();
			Array.Sort(mOutcomePatterns, patterned, patternComparer);

			List<int[]> distinctPatterns = new List<int[]>();
			int[] activePattern = mOutcomePatterns[0];
			int patternIndex = 0;
			int membersOfPattern = 0;

			for (mPredicateId = 0; mPredicateId < mPredicateCount; mPredicateId++)
			{
				int[] candidate = mOutcomePatterns[mPredicateId];
				if (patternComparer.Compare(activePattern, candidate) != 0)
				{
					// a new pattern begins: store the finished one and start counting afresh
					distinctPatterns.Add(PrefixWithPredicateCount(membersOfPattern, activePattern));
					activePattern = candidate;
					patternIndex++;
					membersOfPattern = 0;
				}

				patterned[mPredicateId].OutcomePattern = patternIndex;
				membersOfPattern++;
			}
			// the pattern still being counted when the loop ends has to be stored too
			distinctPatterns.Add(PrefixWithPredicateCount(membersOfPattern, activePattern));

			mOutcomePatterns = distinctPatterns.ToArray();

			// index the predicates by label for lookup
			mPredicates = new Dictionary<string, PatternedPredicate>(patterned.Length);
			foreach (PatternedPredicate predicate in patterned)
			{
				mPredicates.Add(predicate.Name, predicate);
			}
		}

		/// <summary>
		/// Builds the stored form of an outcome pattern: the number of
		/// predicates sharing the pattern, followed by the pattern's outcome ids.
		/// </summary>
		private static int[] PrefixWithPredicateCount(int predicateCount, int[] outcomePattern)
		{
			int[] stored = new int[outcomePattern.Length + 1];
			stored[0] = predicateCount;
			outcomePattern.CopyTo(stored, 1);
			return stored;
		}

		#endregion

		#region IGisModelReader implementation
		
		/// <summary>
		/// Gets the correction constant of the trained model, i.e. the maximum
		/// number of features fired in any single event.
		/// </summary>
		public int CorrectionConstant
		{
			get { return this.mMaximumFeatureCount; }
		}
	
		/// <summary>
		/// Gets the correction parameter of the model produced as a result of training.
		/// </summary>
		public double CorrectionParameter
		{
			get { return this.mCorrectionParameter; }
		}
	
		/// <summary>
		/// Returns the outcome labels of the model produced as a result of training.
		/// </summary>
		/// <returns>
		/// The array of outcome labels held by the trainer (not a copy).
		/// </returns>
		public string[] GetOutcomeLabels()
		{
			string[] labels = this.mOutcomeLabels;
			return labels;
		}
	
		/// <summary>
		/// Returns the outcome patterns of the model produced as a result of training.
		/// </summary>
		/// <returns>
		/// The array of outcome patterns held by the trainer (not a copy).
		/// </returns>
		public int[][] GetOutcomePatterns()
		{
			int[][] patterns = this.mOutcomePatterns;
			return patterns;
		}

		/// <summary>
		/// Returns the predicate data of the model produced as a result of training.
		/// </summary>
		/// <returns>
		/// The dictionary of PatternedPredicate objects, keyed by predicate
		/// label, that is held by the trainer (not a copy).
		/// </returns>
		public Dictionary<string, PatternedPredicate> GetPredicates()
		{
			Dictionary<string, PatternedPredicate> predicates = this.mPredicates;
			return predicates;
		}

		/// <summary>
		/// Returns trained model information for a predicate, given the predicate label.
		/// A label that is not part of the trained model is ignored: both arrays
		/// are left untouched.
		/// </summary>
		/// <param name="predicateLabel">
		/// The predicate label to fetch information for.
		/// </param>
		/// <param name="featureCounts">
		/// Array to be passed in to the method; it should have a length equal to the number of outcomes
		/// in the model.  The method increments the count of each outcome that is active in the specified
		/// predicate.
		/// </param>
		/// <param name="outcomeSums">
		/// Array to be passed in to the method; it should have a length equal to the number of outcomes
		/// in the model.  The method adds the parameter values for each of the active outcomes in the
		/// predicate.
		/// </param>
		public void GetPredicateData(string predicateLabel, int[] featureCounts, double[] outcomeSums)
		{
			// The dictionary indexer throws KeyNotFoundException for an unknown
			// label, so it can never yield the null that signals "not found";
			// TryGetValue lets unknown predicates be skipped instead.
			PatternedPredicate predicate;
			if (!mPredicates.TryGetValue(predicateLabel, out predicate) || predicate == null)
			{
				return;
			}

			int[] activeOutcomes = mOutcomePatterns[predicate.OutcomePattern];

			// slot 0 of a stored pattern holds the number of predicates that
			// share it; the outcome ids start at slot 1
			for (int currentActiveOutcome = 1; currentActiveOutcome < activeOutcomes.Length; currentActiveOutcome++)
			{
				int outcomeIndex = activeOutcomes[currentActiveOutcome];
				featureCounts[outcomeIndex]++;
				outcomeSums[outcomeIndex] += predicate.GetParameter(currentActiveOutcome - 1);
			}
		}
		#endregion

		private class OutcomePatternComparer : IComparer<int[]>
		{

			internal OutcomePatternComparer()
			{
			}

			/// <summary>
			/// Orders two outcome patterns element by element, lower outcome
			/// ids first.  When one pattern is a prefix of the other, the
			/// shorter pattern comes first.
			/// </summary>
			/// <param name="firstPattern">
			/// First outcome pattern to compare.
			/// </param>
			/// <param name="secondPattern">
			/// Second outcome pattern to compare.
			/// </param>
			/// <returns>
			/// -1 if the first pattern sorts before the second, 1 if it sorts
			/// after it, and 0 if the two patterns are identical.
			/// </returns>
			public virtual int Compare(int[] firstPattern, int[] secondPattern)
			{
				int sharedLength = Math.Min(firstPattern.Length, secondPattern.Length);

				// the first differing outcome id decides the order
				for (int position = 0; position < sharedLength; position++)
				{
					int left = firstPattern[position];
					int right = secondPattern[position];
					if (left != right)
					{
						return left < right ? -1 : 1;
					}
				}

				// identical over the shared prefix: the length decides
				if (firstPattern.Length == secondPattern.Length)
				{
					return 0;
				}

				return firstPattern.Length < secondPattern.Length ? -1 : 1;
			}
		}
	}

	/// <summary>
	/// Event data for training progress events: carries a single
	/// informational message.
	/// </summary>
	public class TrainingProgressEventArgs : EventArgs
	{
		private readonly string mProgressMessage;

		/// <summary>
		/// Creates the event data for a training progress event.
		/// </summary>
		/// <param name="message">
		/// Information message about the progress of training.
		/// </param>
		public TrainingProgressEventArgs(string message)
		{
			this.mProgressMessage = message;
		}

		/// <summary>
		/// Gets the information message about the progress of training
		/// that was supplied to the constructor.
		/// </summary>
		public string Message 
		{
			get { return this.mProgressMessage; }
		}
	}

	/// <summary>
	/// Event handler delegate for the training progress event.
	/// </summary>
	/// <param name="sender">
	/// The object raising the event.
	/// </param>
	/// <param name="e">
	/// Event data containing the progress message.
	/// </param>
	public delegate void TrainingProgressEventHandler(object sender, TrainingProgressEventArgs e);


}

By viewing downloads associated with this article you agree to the Terms of Service and the article's licence.

If a file you wish to view isn't highlighted, and is a text file (not binary), please let us know and we'll add colourisation support for it.

License

This article has no explicit license attached to it but may contain usage terms in the article text or the download files themselves. If in doubt please contact the author via the discussion board below.

A list of licenses authors might use can be found here


Written By
Web Developer
United Kingdom United Kingdom
Richard Northedge is a senior developer with a UK Microsoft Gold Partner company. He has a postgraduate degree in English Literature, has been programming professionally since 1998 and has been an MCSD since 2000.

Comments and Discussions