Posted 11 Dec 2017
Licenced CPOL

ReInventing Neural Networks - Part 2

Byte-Master-101, 19 Feb 2018
In Part 2, the neural network built in Part 1 is tested in an environment made in Unity, so that we can see how well it performs.

The Full Series:

  • Part 1: We create the whole NeuralNetwork class from scratch.
  • Part 2: We create an environment in Unity in order to test the neural network within that environment.
  • Part 3: We make a great improvement to the neural network already created by adding a new type of mutation to the code.

Introduction

Hello fellas! A few days ago, I posted an article explaining how you can implement a neural network from scratch in C#. However, in that article the neural network was only trained on an XOR function. As promised, we're gonna train simple cars in Unity to drive! Here's what we're aiming for:

 

 

After I finished the video, it felt like some creepy '90s clip, but it does the job...

Background

To follow along with this article, you'll need basic C# and Unity programming knowledge. You'll also need to have read my previous article, where I first implemented the NeuralNetwork class.

Pre-Programming Resources

In case you're new to C#, you can always search the MSDN docs for anything you're not familiar with. For anything Unity-specific, search Unity's Scripting Reference or Unity's Manual instead.

Using the code

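First off, here are all the classes used in the project: NeuralNetwork, Car, Wall, Checkpoint, EvolutionManager and CameraFollow.

Here's how it all fits together: the EvolutionManager spawns a generation of cars, each driven by its own NeuralNetwork that reads six distance sensors. Passing through a Checkpoint raises a car's fitness, while hitting a Wall (or going too long without any progress) kills it. Once every car in the generation is dead, the best network found so far is cloned and mutated to create the next generation, and the CameraFollow script keeps the camera on the car with the highest fitness.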

Now, we're gonna go over each script and explain it in a little more detail.

NeuralNetwork

A whole article was devoted to that one ...

Car

First, we gotta have a few variables defined:

[SerializeField] bool UseUserInput = false; // Defines whether the car uses a NeuralNetwork or user input
[SerializeField] LayerMask SensorMask; // Defines the layer of the walls ("Wall")
[SerializeField] float FitnessUnchangedDie = 5; // The number of seconds to wait before checking if the fitness didn't increase

public static NeuralNetwork NextNetwork = new NeuralNetwork(new uint[] { 6, 4, 3, 2 }, null); // public NeuralNetwork that refers to the next neural network to be set to the next instantiated car

public string TheGuid { get; private set; } // The Unique ID of the current car

public int Fitness { get; private set; } // The fitness/score of the current car. Represents the number of checkpoints that this car hit.

public NeuralNetwork TheNetwork { get; private set; } // The NeuralNetwork of the current car

Rigidbody TheRigidbody; // The Rigidbody of the current car
LineRenderer TheLineRenderer; // The LineRenderer of the current car

That's what we should do whenever a new car is created:

private void Awake()
{
    TheGuid = Guid.NewGuid().ToString(); // Assigns a new Unique ID for the current car

    TheNetwork = NextNetwork; // Sets the current network to the Next Network
    NextNetwork = new NeuralNetwork(NextNetwork.Topology, null); // Make sure the Next Network is reassigned to avoid having another car use the same network

    TheRigidbody = GetComponent<Rigidbody>(); // Assign Rigidbody
    TheLineRenderer = GetComponent<LineRenderer>(); // Assign LineRenderer

    StartCoroutine(IsNotImproving()); // Start checking if the score stayed the same for a lot of time

    TheLineRenderer.positionCount = 17; // Make sure the line is long enough
}
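The static NextNetwork field acts as a hand-off slot: EvolutionManager fills it right before instantiating a car, Awake takes it, and then immediately puts a fresh network in its place so two cars never share one instance. Here's a minimal Unity-free sketch of that pattern, handy if you want to test it outside the editor. `Brain` and `Driver` are hypothetical stand-ins for NeuralNetwork and Car, not classes from the project:

```csharp
using System;

// Hypothetical stand-in for NeuralNetwork: only identity and topology matter here.
public class Brain
{
    public uint[] Topology { get; }
    public Brain(uint[] topology) { Topology = topology; }
}

// Hypothetical stand-in for Car, reproducing the Awake() hand-off.
public class Driver
{
    // Shared slot read by the next Driver that gets created.
    public static Brain NextBrain = new Brain(new uint[] { 6, 4, 3, 2 });

    public string Id { get; }
    public Brain MyBrain { get; }

    public Driver()
    {
        Id = Guid.NewGuid().ToString();             // unique ID, like TheGuid
        MyBrain = NextBrain;                        // take whatever was queued
        NextBrain = new Brain(NextBrain.Topology);  // queue a fresh one so the next driver can't reuse it
    }
}
```

Because the slot is static, this relies on cars being created one at a time, which is what happens when Instantiate is called in a loop from a MonoBehaviour.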

This is the IsNotImproving function:

// Every FitnessUnchangedDie seconds, checks whether the car made any improvement
IEnumerator IsNotImproving ()
{
    while(true)
    {
        int OldFitness = Fitness; // Save the initial fitness
        yield return new WaitForSeconds(FitnessUnchangedDie); // Wait for some time
        if (OldFitness == Fitness) // Check if the fitness didn't change yet
            WallHit(); // Kill this car
    }
}
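Without the coroutine, IsNotImproving boils down to comparing the fitness at the start and at the end of each FitnessUnchangedDie-second window. The sketch below expresses the same rule as a plain class you poll with the current time, which makes it easy to unit-test; `StagnationCheck` and its `Update` method are hypothetical names:

```csharp
// Hypothetical Unity-free version of the IsNotImproving rule.
public class StagnationCheck
{
    private readonly double window; // seconds per window (FitnessUnchangedDie)
    private double windowStart;     // time at which the current window began
    private int fitnessAtStart;     // fitness recorded when the window began

    public StagnationCheck(double windowSeconds, double startTime, int startFitness)
    {
        window = windowSeconds;
        windowStart = startTime;
        fitnessAtStart = startFitness;
    }

    // Call with the current time and fitness. Returns true when a full window
    // has passed without the fitness changing (i.e. the car should be killed).
    public bool Update(double now, int fitness)
    {
        if (now - windowStart < window) return false; // window not over yet
        bool stuck = fitness == fitnessAtStart;
        windowStart = now;         // start the next window
        fitnessAtStart = fitness;  // and remember the fitness at its start
        return stuck;
    }
}
```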

This is the Move function that (wait for it...) "moves" the car:

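// The main function that moves the car.
// v: throttle in [-1, 1], h: steering in [-1, 1].
// The car model faces along its local X axis, so transform.right is "forward" here.
public void Move (float v, float h)
{
    TheRigidbody.velocity = transform.right * v * 4;     // 4 units/s at full throttle
    TheRigidbody.angularVelocity = transform.up * h * 3; // 3 rad/s at full steering
}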

Then comes the CastRay function, which casts a ray and visualises it. It'll be used later on:

// Casts a ray and makes it visible through the line renderer
double CastRay (Vector3 RayDirection, Vector3 LineDirection, int LinePositionIndex)
{
    float Length = 4; // Maximum length of each ray

    RaycastHit Hit;
    if (Physics.Raycast(transform.position, RayDirection, out Hit, Length, SensorMask)) // Cast a ray
    {
        float Dist = Vector3.Distance(Hit.point, transform.position); // Get the distance of the hit in the line
        TheLineRenderer.SetPosition(LinePositionIndex, Dist * LineDirection); // Set the position of the line

        return Dist; // Return the distance
    }
    else
    {
        TheLineRenderer.SetPosition(LinePositionIndex, LineDirection * Length); // Set the distance of the hit in the line to the maximum distance

        return Length; // Return the maximum distance
    }
}
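CastRay returns the distance to the nearest wall, or the maximum length of 4 when nothing is hit, and GetNeuralInputAxis then divides each value by 4. So every sensor input ends up in [0, 1]: small values mean a wall is close, 1 means nothing within range. Here's a Unity-free sketch of that mapping, using `System.Numerics.Vector3` in place of Unity's `Vector3`; the `SensorMath` class and the nullable hit point (null meaning "no hit") are assumptions for illustration:

```csharp
using System;
using System.Numerics;

// Hypothetical helper mirroring CastRay + the "/ 4" in GetNeuralInputAxis.
public static class SensorMath
{
    public const float MaxLength = 4f; // same maximum ray length as CastRay

    // hitPoint is null when the ray hit nothing within MaxLength (an assumption of this sketch).
    public static double Normalize(Vector3 origin, Vector3? hitPoint)
    {
        float dist = hitPoint.HasValue
            ? Vector3.Distance(origin, hitPoint.Value) // distance to the wall
            : MaxLength;                               // no hit: report the maximum
        return Math.Min(dist, MaxLength) / MaxLength;  // scale into [0, 1]
    }
}
```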

Next comes the GetNeuralInputAxis function, which does a lot of the work for us:

// Casts all the rays, puts them through the NeuralNetwork and outputs the Move Axis
void GetNeuralInputAxis (out float Vertical, out float Horizontal)
{
    double[] NeuralInput = new double[TheNetwork.Topology[0]]; // One input per sensor ray (the input layer has 6 neurons)

    // Cast forward, back, right and left
    NeuralInput[0] = CastRay(transform.forward, Vector3.forward, 1) / 4;
    NeuralInput[1] = CastRay(-transform.forward, -Vector3.forward, 3) / 4;
    NeuralInput[2] = CastRay(transform.right, Vector3.right, 5) / 4;
    NeuralInput[3] = CastRay(-transform.right, -Vector3.right, 7) / 4;

    // Cast forward-right and forward-left
    float SqrtHalf = Mathf.Sqrt(0.5f);
    NeuralInput[4] = CastRay(transform.right * SqrtHalf + transform.forward * SqrtHalf, Vector3.right * SqrtHalf + Vector3.forward * SqrtHalf, 9) / 4;
    NeuralInput[5] = CastRay(transform.right * SqrtHalf + -transform.forward * SqrtHalf, Vector3.right * SqrtHalf + -Vector3.forward * SqrtHalf, 13) / 4;

    // Feed through the network
    double[] NeuralOutput = TheNetwork.FeedForward(NeuralInput);

    // Get Vertical Value
    if (NeuralOutput[0] <= 0.25f)
        Vertical = -1;
    else if (NeuralOutput[0] >= 0.75f)
        Vertical = 1;
    else
        Vertical = 0;

    // Get Horizontal Value
    if (NeuralOutput[1] <= 0.25f)
        Horizontal = -1;
    else if (NeuralOutput[1] >= 0.75f)
        Horizontal = 1;
    else
        Horizontal = 0;

    // If the output is just standing still, then move the car forward
    if (Vertical == 0 && Horizontal == 0)
        Vertical = 1;
}
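The last part of GetNeuralInputAxis turns two continuous outputs into discrete axis values. Isolated from Unity, that decoding looks like the sketch below. It assumes the network's outputs lie roughly in [0, 1]; the 0.25/0.75 thresholds and the "don't stand still" rule are copied from the code above, while `ControlDecoder` is a hypothetical name:

```csharp
// Hypothetical Unity-free version of the output decoding in GetNeuralInputAxis.
public static class ControlDecoder
{
    // Maps one network output to a discrete axis value -1, 0 or 1.
    public static float ToAxis(double output)
    {
        if (output <= 0.25) return -1f;
        if (output >= 0.75) return 1f;
        return 0f; // dead zone between the two thresholds
    }

    // Decodes both outputs; if the car would stand still, drive forward instead.
    public static void Decode(double[] outputs, out float vertical, out float horizontal)
    {
        vertical = ToAxis(outputs[0]);
        horizontal = ToAxis(outputs[1]);
        if (vertical == 0f && horizontal == 0f)
            vertical = 1f;
    }
}
```

The dead zone between 0.25 and 0.75 means an output has to be fairly decisive before the car reverses or turns.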

And here's what we do 50 times per second (FixedUpdate runs once per physics step, and Unity's default fixed timestep is 0.02 s):

private void FixedUpdate()
{
    if (UseUserInput) // If we're gonna use user input
        Move(Input.GetAxisRaw("Vertical"), Input.GetAxisRaw("Horizontal")); // Moves the car according to the input
    else // if we're gonna use a neural network
    {
        float Vertical;
        float Horizontal;

        GetNeuralInputAxis(out Vertical, out Horizontal);

        Move(Vertical, Horizontal); // Moves the car
    }
}

We also need a few functions that are called from other scripts (Checkpoint and Wall):

// Called by a Checkpoint the first time this car passes through it.
public void CheckpointHit ()
{
    Fitness++; // Increase Fitness/Score
}

// Called by walls when hit by the car
public void WallHit()
{
    EvolutionManager.Singleton.CarDead(this, Fitness); // Tell the Evolution Manager that the car is dead
    gameObject.SetActive(false); // Make sure the car is inactive
}

Wall

The Wall script simply notifies any car that hits it:

using UnityEngine;

public class Wall : MonoBehaviour
{
    [SerializeField] string LayerHitName = "CarCollider"; // The name of the layer set on each car

    private void OnCollisionEnter(Collision collision) // Once anything hits the wall
    {
        if (collision.gameObject.layer == LayerMask.NameToLayer(LayerHitName)) // Make sure it's a car
        {
            collision.transform.GetComponent<Car>().WallHit(); // If it is a car, tell it that it just hit a wall
        }
    }
}

Checkpoint

The Checkpoint does almost the same thing as the Wall, but with a twist: checkpoints use a trigger instead of a solid collider, and each checkpoint makes sure it increases a car's fitness only once. That's why each Car has a unique ID: every Checkpoint simply stores the Guids of all the cars it has already credited.

using System.Collections.Generic;
using UnityEngine;

public class Checkpoint : MonoBehaviour
{
    [SerializeField] string LayerHitName = "CarCollider"; // The name of the layer set on each car

    List<string> AllGuids = new List<string>(); // The Guids of all the cars this checkpoint has already credited

    private void OnTriggerEnter(Collider other) // Once anything goes through the checkpoint
    {
        if(other.gameObject.layer == LayerMask.NameToLayer(LayerHitName)) // If this object is a car
        {
            Car CarComponent = other.transform.parent.GetComponent<Car>(); // The collider sits on a child object, so get the Car component from its parent
            string CarGuid = CarComponent.TheGuid; // Get the Unique ID of the car

            if (!AllGuids.Contains(CarGuid)) // If we didn't increase the car before
            {
                AllGuids.Add(CarGuid); // Make sure we don't increase it again
                CarComponent.CheckpointHit(); // Increase the car's fitness
            }
        }
    }
}
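One thing worth knowing about AllGuids: it's never cleared, and every new car gets a fresh Guid, so with the default CarCount of 100 each checkpoint's list can grow by up to 100 entries per generation while List.Contains scans it linearly. A HashSet<string> performs the same once-per-car check in roughly constant time, and its Add method already reports whether the ID was new. Here's a sketch of that bookkeeping outside Unity; `CheckpointLedger` and `TryCredit` are hypothetical names:

```csharp
using System.Collections.Generic;

// Hypothetical Unity-free version of the Checkpoint bookkeeping.
public class CheckpointLedger
{
    // IDs of every car this checkpoint has already credited.
    private readonly HashSet<string> credited = new HashSet<string>();

    // HashSet.Add returns false if the ID is already present, so the check and
    // the insert happen in a single call. True only on a car's first visit.
    public bool TryCredit(string carId) => credited.Add(carId);
}
```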

EvolutionManager

You can't write a script without variables:

public static EvolutionManager Singleton = null; // The current EvolutionManager Instance

[SerializeField] int CarCount = 100; // The number of cars per generation
[SerializeField] GameObject CarPrefab; // The Prefab of the car to be created for each instance
[SerializeField] Text GenerationNumberText; // Some text to write the generation number (UnityEngine.UI.Text: requires using UnityEngine.UI;)

int GenerationCount = 0; // The current generation number

List<Car> Cars = new List<Car>(); // The list of cars currently alive

NeuralNetwork BestNeuralNetwork = null; // The best NeuralNetwork currently available
int BestFitness = -1; // The fitness of the best NeuralNetwork ever created

When the program starts:

// On Start
private void Start()
{
    if (Singleton == null) // If no other instances were created
        Singleton = this; // Make the only instance this one
    else
    {
        gameObject.SetActive(false); // There is another instance already in place. Make this one inactive.
        return; // Don't let the duplicate start its own generations
    }

    BestNeuralNetwork = new NeuralNetwork(Car.NextNetwork); // Set the BestNeuralNetwork to a random new network

    StartGeneration();
}

That's how a new generation is created:

// Starts a whole new generation
void StartGeneration ()
{
    GenerationCount++; // Increment the generation count
    GenerationNumberText.text = "Generation: " + GenerationCount; // Update generation text

    for (int i = 0; i < CarCount; i++)
    {
        if (i == 0)
            Car.NextNetwork = BestNeuralNetwork; // Make sure one car uses the best network
        else
        {
            Car.NextNetwork = new NeuralNetwork(BestNeuralNetwork); // Clone the best neural network and set it to be for the next car
            Car.NextNetwork.Mutate(); // Mutate it
        }

        Cars.Add(Instantiate(CarPrefab, transform.position, Quaternion.identity, transform).GetComponent<Car>()); // Instantiate a new car and add it to the list of cars
    }
}
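StartGeneration is a simple form of elitism: car 0 gets the best network unchanged, and every other car gets a mutated clone of it. The sketch below reproduces that loop without Unity. `Genome` is a hypothetical stand-in for NeuralNetwork (a flat weight array), and its Mutate method, which nudges one random weight, is an assumption; the actual Mutate from Part 1 may work differently:

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for NeuralNetwork: a flat weight vector.
public class Genome
{
    public double[] Weights { get; }
    public Genome(double[] weights) { Weights = weights; }

    // Deep copy, playing the role of the NeuralNetwork copy constructor.
    public Genome Clone() => new Genome((double[])Weights.Clone());

    // Assumed mutation for this sketch: add a random value in [-1, 1) to one weight.
    public void Mutate(Random rng)
    {
        int i = rng.Next(Weights.Length);
        Weights[i] += rng.NextDouble() * 2 - 1;
    }
}

public static class Generation
{
    // Slot 0 keeps the best genome as-is (elitism);
    // every other slot gets a mutated clone of it.
    public static List<Genome> Spawn(Genome best, int count, Random rng)
    {
        var population = new List<Genome>(count);
        for (int i = 0; i < count; i++)
        {
            if (i == 0)
            {
                population.Add(best);
            }
            else
            {
                Genome child = best.Clone(); // clone first so the parent is never mutated
                child.Mutate(rng);
                population.Add(child);
            }
        }
        return population;
    }
}
```

Keeping one unmutated copy means a generation can never lose the best behaviour found so far; the trade-off is that every other car explores only the neighbourhood of that single parent.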

And here's what gets called by the cars:

// Gets called by cars when they die
public void CarDead (Car DeadCar, int Fitness)
{
    Cars.Remove(DeadCar); // Remove the car from the list
    Destroy(DeadCar.gameObject); // Destroy the dead car

    if (Fitness > BestFitness) // If it is better than the current best car
    {
        BestNeuralNetwork = DeadCar.TheNetwork; // Make sure it becomes the best car
        BestFitness = Fitness; // And also set the best fitness
    }

    if (Cars.Count <= 0) // If there are no cars left
        StartGeneration(); // Create a new generation
}
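CarDead only replaces the best network when a car does strictly better, and BestFitness is never reset between generations. So a record is only beaten by a higher score, and when several cars tie on a new record, the first one to die keeps the title. A Unity-free sketch of that bookkeeping; `BestTracker` is a hypothetical name, and the generic parameter stands in for NeuralNetwork:

```csharp
// Hypothetical Unity-free version of CarDead's best-network bookkeeping.
public class BestTracker<T>
{
    public T Best { get; private set; }
    public int BestFitness { get; private set; } = -1; // -1 so the first report always wins

    // Strictly greater: a later candidate that only ties the record does not replace it.
    public void Report(T candidate, int fitness)
    {
        if (fitness > BestFitness)
        {
            Best = candidate;
            BestFitness = fitness;
        }
    }
}
```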

CameraFollow

Just another simple all-in-one script that does the job:

using UnityEngine;

public class CameraFollow : MonoBehaviour
{
    Vector3 SmoothPosVelocity; // Velocity of Position Smoothing
    Vector3 SmoothRotVelocity; // Velocity of rotation smoothing (not used: the rotation below uses Quaternion.Lerp)

    void FixedUpdate ()
    {
        Car BestCar = transform.GetChild(0).GetComponent<Car>(); // Start by assuming the first car is the best

        for (int i = 1; i < transform.childCount; i++) // Loop over all the cars
        {
            Car CurrentCar = transform.GetChild(i).GetComponent<Car>(); // Get the component of the current car

            if (CurrentCar.Fitness > BestCar.Fitness) // If the current car is better than the best car
            {
                BestCar = CurrentCar; // Then, the best car is the current car
            }
        }

        Transform BestCarCamPos = BestCar.transform.GetChild(0); // The target position of the camera relative to the best car

        Camera.main.transform.position = Vector3.SmoothDamp(Camera.main.transform.position, BestCarCamPos.position, ref SmoothPosVelocity, 0.7f); // Smoothly set the position

        Camera.main.transform.rotation = Quaternion.Lerp(Camera.main.transform.rotation,
                                                         Quaternion.LookRotation(BestCar.transform.position - Camera.main.transform.position),
                                                         0.1f); // Smoothly set the rotation
    }
}
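The car-picking part of CameraFollow is just an arg-max over fitness in which the earliest car wins ties (because of the strict >). Note that `transform.GetChild(0)` throws if the manager has no child cars; in this project StartGeneration refills the cars right away, but a standalone version should guard against it. Here's a sketch with that guard, where `Leader` and `IndexOfBest` are hypothetical names and a list of fitness values stands in for the child cars:

```csharp
using System.Collections.Generic;

// Hypothetical Unity-free version of CameraFollow's best-car search.
public static class Leader
{
    // Index of the highest fitness; the earliest entry wins ties.
    // Returns -1 for an empty list instead of throwing.
    public static int IndexOfBest(IReadOnlyList<int> fitnesses)
    {
        if (fitnesses.Count == 0) return -1;
        int best = 0;
        for (int i = 1; i < fitnesses.Count; i++)
            if (fitnesses[i] > fitnesses[best]) best = i; // strict '>' keeps the earlier car on ties
        return best;
    }
}
```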

Points of Interest

Now that all the scripts have been explained in detail, you can sleep well knowing that the NeuralNetwork class implemented in Part 1 works and was not a waste of time. It felt really good to watch those cars learn, step by step, how to drive through the track. Also, since each car only perceives the world through its built-in distance sensors, it can drive on tracks it was never trained on. Once I got that done, I felt like my binary children were learning how to drive!

I tried as hard as I could to keep this implementation as simple as possible for people who don't want to dig deep into Unity's stuff.

And... don't think for a second that we're done here. My current target is to implement 3 crossover operators to make evolution a bit more efficient and offer the developer more diversity. After that, backpropagation is the target.

Update on February 20th 2018:

Part 3 is up and running! It shows a substantial improvement over the system discussed in Parts 1 and 2. Tell me what you think!

History

Version 1.0: Main Implementation (December 11th 2017)

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)


About the Author

Byte-Master-101
Student
Egypt
Hello, I'm a programmer by day, and a Jedi Ninja by night! I mean... I like programming neural networks. I'm fairly experienced with C#, JavaScript, MySQL, network programming, Unity game programming and a bit of C++. I also tried out freelancing for a few months, and it worked well. If anybody wants to talk to me, you can contact me by email (mokhtar.mohammed.red@gmail.com) or on Skype (mokhtar.mohammed.red@gmail.com).

Last Updated 20 Feb 2018
Article Copyright 2017 by Byte-Master-101
Everything else Copyright © CodeProject, 1999-2018