#region File and License Information
/*
<File>
<Copyright>Copyright © 2007, Daniel Vaughan. All rights reserved.</Copyright>
<License see="prj:///Documentation/License.txt"/>
<Owner Name="Daniel Vaughan" Email="dbvaughan@gmail.com"/>
<CreationDate>2009-01-18 17:06:42Z</CreationDate>
<LastSubmissionDate>$Date: $</LastSubmissionDate>
<Version>$Revision: $</Version>
</File>
*/
#endregion
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
namespace DanielVaughan.AI.NeuralNetworking
{
public partial class NeuralNetwork
{
    /// <summary>
    /// Trains the network on boolean samples until the measured accuracy reaches
    /// <paramref name="minimumAccuracy"/> or roughly <paramref name="timeoutMilliseconds"/> elapses.
    /// </summary>
    /// <param name="input">Input samples; <c>input[i]</c> is paired with <c>expectedOutput[i]</c>.</param>
    /// <param name="expectedOutput">Expected output for each input sample.</param>
    /// <param name="minimumAccuracy">Accuracy at which training stops successfully; must be greater than zero.</param>
    /// <param name="timeoutMilliseconds">Approximate time budget in milliseconds; must be greater than zero.</param>
    /// <returns>The accuracy attained and whether the target was reached (see the <see cref="TrainingSet"/> overload).</returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="input"/> and <paramref name="expectedOutput"/> have different lengths.
    /// Null or non-positive arguments are rejected by <c>ArgumentValidator</c>.
    /// </exception>
    public TimedTrainingResult Train(bool[][] input, bool[][] expectedOutput, double minimumAccuracy, long timeoutMilliseconds)
    {
        ArgumentValidator.AssertNotNull(input, "input");
        ArgumentValidator.AssertNotNull(expectedOutput, "expectedOutput");
        ArgumentValidator.AssertGreaterThan(minimumAccuracy, 0, "minimumAccuracy");
        ArgumentValidator.AssertGreaterThan(timeoutMilliseconds, 0, "timeoutMilliseconds");
        if (input.Length != expectedOutput.Length)
        {
            throw new ArgumentException("input and expectedOutput length must be equal.", "expectedOutput");
        }

        double[][] inputDoubles = ConvertToDoubleArray(input);
        double[][] outputDoubles = ConvertToDoubleArray(expectedOutput);

        /* Pair each converted input with its expected output, preserving index order. */
        var mappings = new List<KeyValuePair<LayerStimulus, LayerStimulus>>(inputDoubles.Length);
        for (int i = 0; i < inputDoubles.Length; i++)
        {
            mappings.Add(new KeyValuePair<LayerStimulus, LayerStimulus>(
                new LayerStimulus(inputDoubles[i]),
                new LayerStimulus(outputDoubles[i])));
        }

        return Train(new TrainingSet(mappings), minimumAccuracy, timeoutMilliseconds);
    }

    /// <summary>
    /// Trains the network with back propagation in batches until the accuracy measured
    /// against <paramref name="trainingSet"/> reaches <paramref name="minimumAccuracy"/>
    /// or the time budget is exhausted.
    /// </summary>
    /// <remarks>
    /// A fixed sample of epochs is always run first to estimate epoch cost, and the
    /// deadline is only checked between batches. The call can therefore overrun
    /// <paramref name="timeoutMilliseconds"/> by up to one batch (or by the whole sample).
    /// </remarks>
    /// <param name="trainingSet">Samples to train on; merged into the network's cumulative training set.</param>
    /// <param name="minimumAccuracy">Accuracy at which training stops successfully; must be greater than zero.</param>
    /// <param name="timeoutMilliseconds">Approximate time budget in milliseconds; must be greater than zero.</param>
    /// <returns>
    /// <see cref="TrainingResult.Success"/> as soon as the measured accuracy is at least
    /// <paramref name="minimumAccuracy"/>, otherwise <see cref="TrainingResult.RanOutOfTime"/>;
    /// <c>AccuracyAttained</c> holds the last measured accuracy.
    /// </returns>
    public TimedTrainingResult Train(TrainingSet trainingSet, double minimumAccuracy, long timeoutMilliseconds)
    {
        ArgumentValidator.AssertNotNull(trainingSet, "trainingSet");
        ArgumentValidator.AssertGreaterThan(minimumAccuracy, 0, "minimumAccuracy");
        ArgumentValidator.AssertGreaterThan(timeoutMilliseconds, 0, "timeoutMilliseconds");

        /* Number of epochs run up-front to estimate how long one epoch takes. */
        const int sampleIterations = 1000;
        /* Desired duration of each training batch between accuracy checks. */
        const double targetBatchMilliseconds = 500;

        /* Stopwatch is monotonic and high resolution. DateTime.Now can jump on clock
         * adjustments or DST changes and often has a coarse (~10-16 ms) resolution. */
        Stopwatch elapsed = Stopwatch.StartNew();

        /* Flatten the training set so accuracy can be measured against exactly these samples. */
        var inputDoublesList = new List<double[]>();
        var outputDoublesList = new List<double[]>();
        foreach (var pair in trainingSet.InputOutputDictionary)
        {
            inputDoublesList.Add(pair.Key.Data);
            outputDoublesList.Add(pair.Value.Data);
        }
        double[][] inputDoubles = inputDoublesList.ToArray();
        double[][] outputDoubles = outputDoublesList.ToArray();

        /* Time only the sample epochs (not the setup above) to estimate the per-epoch cost. */
        TimeSpan sampleStart = elapsed.Elapsed;
        Train(trainingSet, TrainingType.BackPropagation, sampleIterations);
        double sampleMilliseconds = (elapsed.Elapsed - sampleStart).TotalMilliseconds;

        /* With a budget above one target batch, size batches to take about targetBatchMilliseconds each;
         * otherwise train a single epoch between accuracy checks. */
        int batchIterations = 1;
        if (timeoutMilliseconds > targetBatchMilliseconds)
        {
            double estimate = sampleMilliseconds > 0
                ? targetBatchMilliseconds / sampleMilliseconds * sampleIterations
                : sampleIterations; /* Sample too fast to time: reuse the sample size. */
            /* Clamp before casting: an out-of-range double-to-int cast is unspecified in
             * unchecked C# and previously collapsed very fast samples to a single epoch. */
            batchIterations = (int)Math.Max(1, Math.Min(estimate, int.MaxValue));
        }

        double accuracy = MeasureAccuracy(inputDoubles, outputDoubles);
        while (true)
        {
            /* Check accuracy before the deadline so an already-trained network reports success
             * even when the budget is spent. ">=" keeps a NaN accuracy from counting as success. */
            if (accuracy >= minimumAccuracy)
            {
                return new TimedTrainingResult { AccuracyAttained = accuracy, TrainingResult = TrainingResult.Success };
            }

            if (elapsed.Elapsed.TotalMilliseconds > timeoutMilliseconds)
            {
                return new TimedTrainingResult { AccuracyAttained = accuracy, TrainingResult = TrainingResult.RanOutOfTime };
            }

            Train(trainingSet, TrainingType.BackPropagation, batchIterations);
            accuracy = MeasureAccuracy(inputDoubles, outputDoubles);
        }
    }

    /// <summary>
    /// Merges <paramref name="trainingSet"/> into the cumulative training set. Then it runs
    /// <paramref name="iterations"/> batch epochs of the requested algorithm over the whole
    /// cumulative set while holding <c>networkLock</c>.
    /// </summary>
    /// <param name="trainingSet">Samples to add to the cumulative training set.</param>
    /// <param name="trainingType">Training algorithm; only <see cref="TrainingType.BackPropagation"/> is supported.</param>
    /// <param name="iterations">Number of epochs to run.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="trainingType"/> is not supported.</exception>
    void Train(TrainingSet trainingSet, TrainingType trainingType, int iterations)
    {
        switch (trainingType)
        {
            case TrainingType.BackPropagation:
                lock (networkLock)
                {
                    /* Accumulate every sample seen so far. Duplicates are removed by the default
                     * equality of the key/value pairs. NOTE(review): this re-merges on every call,
                     * including each batch of the timed loop. */
                    if (totalTrainingSet != null)
                    {
                        var unionResult = totalTrainingSet.InputOutputDictionary.Union(trainingSet.InputOutputDictionary);
                        totalTrainingSet = new TrainingSet(unionResult);
                    }
                    else
                    {
                        totalTrainingSet = new TrainingSet(trainingSet);
                    }

                    for (int i = 0; i < iterations; i++)
                    {
                        InitializeLearning(); /* Set all weight changes to zero. */
                        foreach (var pair in totalTrainingSet.InputOutputDictionary)
                        {
                            TrainUsingBackPropogation(pair.Key.Data, pair.Value.Data);
                        }
                        ApplyLearning(); /* Apply the batch of cumulative weight changes. */
                    }
                }
                break;
            default:
                throw new ArgumentOutOfRangeException("trainingType", trainingType, "Unexpected TrainingType");
        }
    }

    /* Disabled code, kept for reference. It would copy each hidden neuron's bias, error values
     * and matching input weights into the corresponding context-layer neuron. */
    // void SaveContext()
    // {
    // for (int i = 0; i < HiddenLayer.Count; i++)
    // {
    // var contextNeuron = ContextLayer[i];
    // var hiddenNeuron = HiddenLayer[i];
    // contextNeuron.Bias.Weight = hiddenNeuron.Bias.Weight;
    // contextNeuron.Bias.WeightDelta = hiddenNeuron.Bias.WeightDelta;
    // contextNeuron.Error = hiddenNeuron.Error;
    // contextNeuron.LastError = hiddenNeuron.LastError;
    // foreach (var pair in hiddenNeuron.Inputs)
    // {
    // NeuralBias bias;
    // if (contextNeuron.Inputs.TryGetValue(pair.Key, out bias))
    // {
    // bias.Weight = pair.Value.Weight;
    // bias.WeightDelta = pair.Value.WeightDelta;
    // }
    // }
    // //contextNeuron.Output
    // }
    // }
}
}