Click here to Skip to main content
15,885,875 members
Articles / Desktop Programming / MFC

Neural Network Classifier

Rate me:
Please Sign up or sign in to vote.
4.47/5 (23 votes)
29 Jan 2005 · CPOL · 2 min read · 165K · 8.4K · 82
A multilayer perceptron used to classify blue and red points.
#pragma once
#include <afxtempl.h>
#include "Layer.h"
#include "Synapse.h"
#include "Neuron.h"
#include <math.h>
class Perceptron
{
public:
	~Perceptron(void){};	
	CList<Layer*>	layers;
	CList<Samples>	inputSamples;//Samples are array of double
	CList<Samples>	outputSamples;
	Layer*	inputLayer;
	Layer*	outputLayer;
	double	error;
	//CMatrix<double> *iWeights[4];//initial wights (4 : input, hid1, hid2, hid3)
	Perceptron(int i,int o)
	{
		//iWeights = NULL;
		layers.RemoveAll();
		inputSamples.RemoveAll();
		outputSamples.RemoveAll();
		inputLayer      = new Layer(_T("I"),i+1); // plus the bias
		outputLayer     = new Layer(_T("O"),o);
		layers.AddTail(inputLayer);
		layers.AddTail(outputLayer);
		error = 0.0;
	}
	void addLayer(int n,CString name)
	{
		POSITION tailPos = layers.GetTailPosition();
		Layer* layer = new Layer(name,n);
		layers.InsertBefore(tailPos,layer);
	}
	Layer* getLayer(int i)
	{
		int      j=0;
		bool     found=false;
		POSITION pos = layers.GetHeadPosition();
		Layer* layer = layers.GetAt(pos);
		for (int j = 0; j < layers.GetCount(); j++) 
		{
			layer = layers.GetNext(pos);
			if (i==j)
			{
				found = true;
				break;
			} 
		}

		if (found == false) 
			layer = NULL;
		return layer;
	}

	void connect(int sourceLayer,int sourceNeuron,
		int destLayer,int destNeuron, double w = MAXWORD)
	{
		new Synapse(getLayer(sourceLayer)->getNeuron(sourceNeuron),
			getLayer(destLayer)->getNeuron(destNeuron), w);
	}
	void biasConnect(int destLayer,int destNeuron, double w = MAXWORD)
	{
		new Synapse(inputLayer->getNeuron(inputLayer->size-1),
			getLayer(destLayer)->getNeuron(destNeuron), w);
	}
	void removeSamples()
	{
		inputSamples.RemoveAll();
		outputSamples.RemoveAll();
	}
	void addSample(Samples inputs,Samples outputs)
	{
		ASSERT( (inputs.nLenght > 0) && (outputs.nLenght > 0) );
		ASSERT( (inputs.samples) && (outputs.samples) );
		inputSamples.AddTail(inputs);
		outputSamples.AddTail(outputs);
	}
	CString printSamples()
	{
		//System.out.println(inputSamples+"->"+outputSamples);
		CString str = _T("\ninputs : ");
		POSITION pos = inputSamples.GetHeadPosition();
		double* pSamples = inputSamples.GetAt(pos).samples;
		int len = inputSamples.GetAt(pos).nLenght;
		Samples S;
		for (int i = 0; i < inputSamples.GetCount(); i++) 
		{
			S = inputSamples.GetNext(pos);
			pSamples = S.samples;
			len = S.nLenght;
			str.AppendFormat(_T("\nInput Sample %03d\n"), i);
			for (int j = 0; j < len; j++) 
			{
				str.AppendFormat(_T("%.f ,"), pSamples[j]);					
			}
		}
		
		pos = outputSamples.GetHeadPosition();
		for (int i = 0; i < outputSamples.GetCount(); i++) 
		{
			S = outputSamples.GetNext(pos);
			pSamples = S.samples;
			len = S.nLenght;
			str.AppendFormat(_T("\nOutput Sample %03d\n"), i);
			for (int j = 0; j < len; j++) 
			{
				str.AppendFormat(_T("%.f ,"), pSamples[j]);					
			}
		}
		OutputDebugString(str);
	}

	Samples recognize(Samples iS)
	{
		initInputs(iS);
		propagate();
		Samples oS = getOutput();
		return oS;
	}
	void learn(int iterations)
	{
		int i=0,j=0;
		for(i = 0; i<iterations; i++)
		{
			POSITION pos1 = inputSamples.GetHeadPosition();
			POSITION pos2 = outputSamples.GetHeadPosition();
			Samples pattern = inputSamples.GetAt(pos1);
			Samples target = outputSamples.GetAt(pos2);
			// accumulate total error over each epoch
			error = 0.0;
			for (j = 0; j < inputSamples.GetCount(); j++) 
			{			
				pattern = inputSamples.GetNext(pos1);
				target = outputSamples.GetNext(pos2);
				learnPattern(pattern, target);
				error += computeError(target);
			}
			error /= (inputSamples.GetCount()*outputLayer->neurons.GetCount());
			error = sqrt(error);
		}
	}
	void learnPattern(Samples iS, Samples oS)
	{
		initInputs(iS);
		propagate();
		bpAdjustWeights(oS);
	}
	void initInputs(Samples iS)
	{
		POSITION pos = inputLayer->neurons.GetHeadPosition();
		Neuron* neuron = (Neuron*)inputLayer->neurons.GetAt(pos);
		for (int i = 0; i < iS.nLenght; i++) 
		{
			neuron = (Neuron*)inputLayer->neurons.GetNext(pos);
			neuron->output = iS.samples[i];
		}
		neuron = (Neuron*)inputLayer->neurons.GetNext(pos);// bias;
		neuron->output = 1.0;
	}
	void propagate()
	{
		POSITION pos = layers.GetHeadPosition();
		Layer* layer = (Layer*)layers.GetAt(pos);
		layer = (Layer*)layers.GetNext(pos);// skip the input layer
		for (int j = 0; j < layers.GetCount()-1; j++) 
		{
			layer = (Layer*)layers.GetNext(pos);
			layer->computeOutputs();
		}
	}
	Samples getOutput()
	{
		double sum=0.0;
		double tmp;
		POSITION pos = outputLayer->neurons.GetHeadPosition();
		Samples oS = {0,outputLayer->neurons.GetCount()};
		//oS.samples = new double[outputLayer->neurons.GetCount()];
		Neuron* neuron = (Neuron*)outputLayer->neurons.GetAt(pos);
		for (int j = 0; j < outputLayer->neurons.GetCount(); j++) 
		{
			neuron = (Neuron*)outputLayer->neurons.GetNext(pos);
			oS.samples[j] = neuron->getOutput();
		}
		return oS;
	}
	double computeError(Samples oS)
	{
		double sum=0.0;
		double tmp;
		POSITION pos = outputLayer->neurons.GetHeadPosition();
		Neuron* neuron = (Neuron*)outputLayer->neurons.GetAt(pos);
		for (int i = 0; i < outputLayer->neurons.GetCount(); i++) 
		{
			neuron = (Neuron*)outputLayer->neurons.GetNext(pos);
			tmp = oS.samples[i] - neuron->getOutput();
			sum += tmp * tmp;
		}
		return sum;//sum/2.0;
	}
	double currentError() {
		return error;
	}
	void bpAdjustWeights(Samples oS)
	{
		outputLayer->computeBackpropDeltas(oS);
		POSITION pos = layers.GetHeadPosition();
		
		for(int i=layers.GetCount()-2; i>=1; i--)
		{
			layers.GetNext(pos);
			((Layer*)layers.GetAt(pos))->computeBackpropDeltas();
		}
		outputLayer->computeWeights();
		pos = layers.GetHeadPosition();
		for(int i=layers.GetCount()-2; i>=1; i--)
		{
			layers.GetNext(pos);
			((Layer*)layers.GetAt(pos))->computeWeights();
		}
	}
	CString print()
	{
		CString str = _T("");
		POSITION pos = layers.GetHeadPosition();
		Layer* layer;
		for (int i = 0; i < layers.GetCount(); i++) 
		{
			layer = layers.GetNext(pos);
			str +=layer->print();
		}
		return str;
	}

	int SaveNetwork(CString FileName)
	{
		FILE *netFile = _wfopen(FileName, _T("wt"));
		CTime time = CTime::GetCurrentTime();
		CString str = _T(":::In the name of Allah:::\n:::MultiLayer Perceptron Information\n:::Producer : Hossein Khosravi\n");
		str += time.Format(_T(":::Creation Time : %A, %B %d, %Y (%H:%M)\n"));
		fprintf(netFile, "%S", str);
		int nLayers = layers.GetCount();
		POSITION pos = layers.GetHeadPosition();
		Layer *inLayer = layers.GetNext(pos);
		Layer *hid1,*hid2,*hid3;
		if (nLayers == 2)//no hidden layer
		{
			fprintf(netFile, "Inputs : %3d\nOutputs : %3d\nHid1 : %3d\nHid2 : %3d\nHid3 : %3d\n", 
				layers.GetNext(pos)->neurons.GetCount(),
				layers.GetTail()->neurons.GetCount(),0,0,0);
		}
		else if (nLayers == 3)//one hidden layer
		{
			hid1 = layers.GetNext(pos);
			fprintf(netFile, "Inputs : %3d\nOutputs : %3d\nHid1 : %3d\nHid2 : %3d\nHid3 : %3d\n", 
				inLayer->neurons.GetCount(),
				layers.GetTail()->neurons.GetCount(),
				hid1->neurons.GetCount(),0,0);
		}
		else if (nLayers == 4)//two hidden layer
		{
			hid1 = layers.GetNext(pos);
			hid2 = layers.GetNext(pos);
			fprintf(netFile, "Inputs : %3d\nOutputs : %3d\nHid1 : %3d\nHid2 : %3d\nHid3 : %3d\n", 
				inLayer->neurons.GetCount(),
				layers.GetTail()->neurons.GetCount(),
				hid1->neurons.GetCount(),
				hid2->neurons.GetCount(),0);
		}
		else if (nLayers == 5)//three hidden layer
		{
			hid1 = layers.GetNext(pos);
			hid2 = layers.GetNext(pos);
			hid3 = layers.GetNext(pos);
			fprintf(netFile, "Inputs : %3d\nOutputs : %3d\nHid1 : %3d\nHid2 : %3d\nHid3 : %3d\n", 
				inLayer->neurons.GetCount(),
				layers.GetTail()->neurons.GetCount(),
				hid1->neurons.GetCount(),
				hid2->neurons.GetCount(),
				hid3->neurons.GetCount());
		}
		fprintf(netFile, "%S", print());

		fclose(netFile);
		return 0;
	}

	int LoadNet(CString FileName)
	{
		int in = 0, out = 0, h1 = 0, h2 = 0, h3 = 0;
		int i = 0, j = 0;
		CStdioFile netFile(FileName, CFile::modeRead|CFile::typeText);
		CString str;
		netFile.ReadString(str);
		while (str[0] == ':')
		{
			netFile.ReadString(str);
		}
		in = _wtoi(str.Right(3));
		netFile.ReadString(str);
		out = _wtoi(str.Right(3));
		netFile.ReadString(str);
		h1 = _wtoi(str.Right(3));
		netFile.ReadString(str);
		h2 = _wtoi(str.Right(3));
		netFile.ReadString(str);
		h3 = _wtoi(str.Right(3));
		netFile.ReadString(str);//read blank line

		POSITION pos = layers.GetHeadPosition();
		int strPos = 0;
		CString strRes;
		Layer* layer = NULL;
		double weights[1024];
		int k = 0;
		for (i = 0; i < layers.GetCount()-1; i++) 
		{
			layer = layers.GetNext(pos);
			for (j = 0; j < layer->neurons.GetCount(); j++) 
			{
				strPos = 0;
				k = 0;
				netFile.ReadString(str);//skip header line
				netFile.ReadString(str);
				strRes = str.Tokenize(_T("| ()"), strPos);
				while (strRes != "")
				{
					strRes= str.Tokenize(_T("\ ()"),strPos);
					if(strRes.Find('.') != -1)//it is a weight
					{
						weights[k] = _wtof(strRes);
						k++;
					}
				};
				layer->getNeuron(j)->SetWeights(weights);
			}
		}
		return 0;
	}
};

By viewing downloads associated with this article you agree to the Terms of Service and the article's licence.

If a file you wish to view isn't highlighted, and is a text file (not binary), please let us know and we'll add colourisation support for it.

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)


Written By
Software Developer (Senior) https://shahaab-co.com
Iran (Islamic Republic of)
I currently work at the Department of Electrical Engineering, University of Shahrood.
Pattern recognition (especially OCR), neural networks, image processing and machine vision are my interests. However, I'm a PROGRAMMER as well.
BSc: Sharif University of Technology, 2002
MSc and PhD: Tarbiat Modarres University, 2006 and 2010 respectively

Personal Blog: Andisheh Online

Religious Blogs: Shia Muslims , Islamic Quotes

Company Site: Shahaab-co
My old Site: Farsi OCR

Comments and Discussions