ReInventing Neural Networks - Part 2





In Part 2, the neural network built in Part 1 is tested in a Unity environment so that we can see how well it performs.
The Full Series
- Part 1: We create the whole NeuralNetwork class from scratch.
- Part 2: We create an environment in Unity in order to test the neural network within that environment.
- Part 3: We make a great improvement to the neural network already created by adding a new type of mutation to the code.
Introduction
A few days ago, I posted this article explaining how you can implement a neural network from scratch in C#. In that article, however, the neural network was only trained on an XOR function. As promised, this time we're going to train simple cars to drive in Unity! Here's what we're aiming for:
After I finished the video, I felt like it was some creepy 90s clip, but it does the job ...
Background
To follow along with this article, you'll need basic C# and Unity programming knowledge. You'll also need to have read my previous article, where I first implemented the NeuralNetwork class.
Pre-Programming Resources
In case you're new to C#, you can always search the MSDN docs for the stuff you're not familiar with, but in case you want to look for something Unity-specific, you may want to search Unity's Scripting Reference or Unity's Manual instead.
Using the Code
First off, you need to know all the classes that are going to be used in the project:
- Car: The main script that controls the movement of the car object (controlled by a NeuralNetwork or by the user).
- Wall: A simple script that is attached to every wall. It sends a "Die" message to any car that hits an object with this script on it.
- Checkpoint: A simple script that increases the fitness (score) of a car once it is hit.
- EvolutionManager: The script that waits for all the cars to die, then makes a new generation from the best car.
- CameraFollow: The script that changes the position of the camera to look at the best car.
Here's how it's all going to work:
- There is going to be a track with a series of checkpoints along its path.
- Once a car hits a checkpoint, its fitness increases.
- If a car hits a wall, it gets destroyed.
- If all the cars are destroyed, a new generation is created from the best car seen so far.
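The steps above can be sketched in plain C#, with no Unity involved. This is only a toy under stated assumptions: GenerationLoopSketch, Evaluate and Mutate are hypothetical stand-ins for the car's NeuralNetwork, its checkpoint count and NeuralNetwork.Mutate(), and the fitness function is made up so that the loop has something to optimise.

```csharp
using System;
using System.Linq;

// Toy version of the generation loop. All names here are hypothetical.
public static class GenerationLoopSketch
{
    // Made-up fitness: the closer every weight is to 0.5, the higher the score.
    public static double Evaluate(double[] genome) =>
        -genome.Sum(w => (w - 0.5) * (w - 0.5));

    // Returns a mutated copy; the original array is left untouched.
    public static double[] Mutate(double[] genome, Random rng) =>
        genome.Select(w => w + (rng.NextDouble() - 0.5) * 0.2).ToArray();

    // Runs the loop and returns the best fitness known after each generation.
    public static double[] Run(int generations, int populationSize, int seed)
    {
        var rng = new Random(seed);
        double[] best = { 0.0, 0.0, 0.0 };
        double bestFitness = Evaluate(best);
        var history = new double[generations];

        for (int g = 0; g < generations; g++)
        {
            // Every car in this generation derives from the same parent.
            double[] parent = best;

            // Car 0 drives the parent unchanged, so only cars 1..N-1 are mutated.
            for (int i = 1; i < populationSize; i++)
            {
                double[] candidate = Mutate(parent, rng);
                double fitness = Evaluate(candidate);
                if (fitness > bestFitness) // Strictly better cars replace the best
                {
                    best = candidate;
                    bestFitness = fitness;
                }
            }
            history[g] = bestFitness;
        }
        return history;
    }
}
```

Because the best network is only ever replaced by a strictly better one, the recorded best fitness can stay flat for a few generations but it never goes down.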
Now, we're going to go over each script and explain it in a little more detail.
NeuralNetwork
An entire article was devoted to that one ...
Car
First, we need a few variables defined:
// Needs: using System; using System.Collections; using UnityEngine;
[SerializeField] bool UseUserInput = false;     // Whether the car uses user input instead of a NeuralNetwork
[SerializeField] LayerMask SensorMask;          // The layer of the walls ("Wall")
[SerializeField] float FitnessUnchangedDie = 5; // Seconds to wait before checking whether the fitness has increased

// The neural network that will be given to the next instantiated car
public static NeuralNetwork NextNetwork = new NeuralNetwork(new uint[] { 6, 4, 3, 2 }, null);

public string TheGuid { get; private set; }           // The unique ID of the current car
public int Fitness { get; private set; }              // The fitness/score: the number of checkpoints this car has hit
public NeuralNetwork TheNetwork { get; private set; } // The NeuralNetwork of the current car

Rigidbody TheRigidbody;       // The Rigidbody of the current car
LineRenderer TheLineRenderer; // The LineRenderer of the current car
Here's what we do whenever a new car is created:
private void Awake()
{
    TheGuid = Guid.NewGuid().ToString(); // Assign a new unique ID to the current car

    TheNetwork = NextNetwork;            // Take the network that was prepared for this car
    // Reassign NextNetwork so that no other car ends up using the same network
    NextNetwork = new NeuralNetwork(NextNetwork.Topology, null);

    TheRigidbody = GetComponent<Rigidbody>();       // Assign Rigidbody
    TheLineRenderer = GetComponent<LineRenderer>(); // Assign LineRenderer

    StartCoroutine(IsNotImproving());    // Start checking whether the score has stayed the same for too long

    TheLineRenderer.positionCount = 17;  // Make sure the line is long enough
}
This is the IsNotImproving function:
// Checks every few seconds whether the car has made any improvement
IEnumerator IsNotImproving()
{
    while (true)
    {
        int OldFitness = Fitness;                             // Save the fitness at the start of the wait
        yield return new WaitForSeconds(FitnessUnchangedDie); // Wait for some time
        if (OldFitness == Fitness)                            // If the fitness hasn't changed in the meantime
            WallHit();                                        // Kill this car
    }
}
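One detail of this check is easy to miss: the fitness is compared only once per window, not continuously. The plain C# sketch below simulates that timing. StagnationSketch and TimeOfDeath are hypothetical names, and checkpoint hits that land exactly on a window boundary are counted in the later window, which is a simplification rather than a claim about Unity's scheduling.

```csharp
using System;
using System.Linq;

// Simulates the windowed check done by IsNotImproving. Names are hypothetical.
public static class StagnationSketch
{
    // hitTimes: the moments (in seconds) at which the car reached a checkpoint.
    // interval: the length of each window, like FitnessUnchangedDie.
    // Returns the time at which the car would be killed.
    public static double TimeOfDeath(double[] hitTimes, double interval)
    {
        for (int k = 1; ; k++)
        {
            double windowStart = (k - 1) * interval;
            double windowEnd = k * interval;

            // The car survives a window only if it scored at least once inside it.
            bool improved = hitTimes.Any(t => t >= windowStart && t < windowEnd);
            if (!improved)
                return windowEnd;
        }
    }
}
```

With a 5 second window, a car that scores at 4.9 s and 5.1 s is only killed at 15 s, so it can go almost two full windows without a new checkpoint. A car that never scores is killed at 5 s.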
This is the Move function that (wait for it...) "moves" the car:
// The main function that moves the car.
public void Move(float v, float h)
{
    TheRigidbody.velocity = transform.right * v * 4;     // Drive along the car's local right axis
    TheRigidbody.angularVelocity = transform.up * h * 3; // Turn around the car's local up axis
}
Then comes the CastRay function, which casts a ray and visualises it. It'll be used later on:
// Casts a ray and makes it visible through the line renderer
double CastRay(Vector3 RayDirection, Vector3 LineDirection, int LinePositionIndex)
{
    float Length = 4; // Maximum length of each ray

    RaycastHit Hit;
    if (Physics.Raycast(transform.position, RayDirection, out Hit, Length, SensorMask)) // Cast a ray
    {
        float Dist = Vector3.Distance(Hit.point, transform.position);         // Distance from the car to the hit point
        TheLineRenderer.SetPosition(LinePositionIndex, Dist * LineDirection); // End the line at the hit point
        return Dist;                                                          // Return the distance
    }
    else
    {
        TheLineRenderer.SetPosition(LinePositionIndex, LineDirection * Length); // Draw the line at its maximum length
        return Length;                                                          // Return the maximum distance
    }
}
Next comes the GetNeuralInputAxis function, which does a lot of the work for us:
// Casts all the rays, puts them through the NeuralNetwork and outputs the Move Axis
void GetNeuralInputAxis(out float Vertical, out float Horizontal)
{
    double[] NeuralInput = new double[NextNetwork.Topology[0]];

    // Cast along +forward, -forward, +right and -right.
    // Dividing by 4 (the maximum ray length) scales each reading down to the 0..1 range.
    NeuralInput[0] = CastRay(transform.forward, Vector3.forward, 1) / 4;
    NeuralInput[1] = CastRay(-transform.forward, -Vector3.forward, 3) / 4;
    NeuralInput[2] = CastRay(transform.right, Vector3.right, 5) / 4;
    NeuralInput[3] = CastRay(-transform.right, -Vector3.right, 7) / 4;

    // Cast the two diagonals between +right and +forward / -forward
    float SqrtHalf = Mathf.Sqrt(0.5f);
    NeuralInput[4] = CastRay(transform.right * SqrtHalf + transform.forward * SqrtHalf,
        Vector3.right * SqrtHalf + Vector3.forward * SqrtHalf, 9) / 4;
    NeuralInput[5] = CastRay(transform.right * SqrtHalf + -transform.forward * SqrtHalf,
        Vector3.right * SqrtHalf + -Vector3.forward * SqrtHalf, 13) / 4;

    // Feed through the network
    double[] NeuralOutput = TheNetwork.FeedForward(NeuralInput);

    // Get Vertical Value
    if (NeuralOutput[0] <= 0.25f)
        Vertical = -1;
    else if (NeuralOutput[0] >= 0.75f)
        Vertical = 1;
    else
        Vertical = 0;

    // Get Horizontal Value
    if (NeuralOutput[1] <= 0.25f)
        Horizontal = -1;
    else if (NeuralOutput[1] >= 0.75f)
        Horizontal = 1;
    else
        Horizontal = 0;

    // If the output is just standing still, then move the car forward
    if (Vertical == 0 && Horizontal == 0)
        Vertical = 1;
}
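To make the number handling in GetNeuralInputAxis easier to follow, here is a small plain C# sketch of just that part, without any Unity types. DriveSketch, Normalize, ToAxis and Resolve are hypothetical names that don't exist in the project, and the sketch assumes the network's outputs fall in the 0 to 1 range, which is what the 0.25 and 0.75 thresholds suggest.

```csharp
using System;

// Sketch of the input scaling and output thresholds. Names are hypothetical.
public static class DriveSketch
{
    public const double MaxRayLength = 4.0; // Same value as Length in CastRay

    // Scales a ray reading to 0..1: 1 means no wall in range,
    // values near 0 mean a wall is very close.
    public static double Normalize(double distance) =>
        Math.Min(Math.Max(distance, 0.0), MaxRayLength) / MaxRayLength;

    // Maps one network output to an axis value of -1, 0 or 1.
    public static int ToAxis(double output)
    {
        if (output <= 0.25) return -1;
        if (output >= 0.75) return 1;
        return 0;
    }

    // Applies the thresholds to both outputs, then the "never stand still" rule.
    public static (int Vertical, int Horizontal) Resolve(double verticalOutput, double horizontalOutput)
    {
        int vertical = ToAxis(verticalOutput);
        int horizontal = ToAxis(horizontalOutput);
        if (vertical == 0 && horizontal == 0)
            vertical = 1; // An idle car is pushed forward instead
        return (vertical, horizontal);
    }
}
```

For example, DriveSketch.Resolve(0.5, 0.5) gives (1, 0): both outputs land in the dead zone, so the car is told to drive straight ahead. The wide dead zone between 0.25 and 0.75 means a network has to be fairly "sure" before the car reverses or turns.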
And here's what we do 50 times per second (with Unity's default fixed timestep of 0.02 seconds):
private void FixedUpdate()
{
    if (UseUserInput) // If we're going to use user input
    {
        Move(Input.GetAxisRaw("Vertical"), Input.GetAxisRaw("Horizontal")); // Move the car according to the input
    }
    else // If we're going to use a neural network
    {
        float Vertical;
        float Horizontal;

        GetNeuralInputAxis(out Vertical, out Horizontal);

        Move(Vertical, Horizontal); // Move the car
    }
}
We also need a few functions that are going to be called from other scripts (Checkpoint and Wall):
// Called by a checkpoint when the car passes through it
public void CheckpointHit()
{
    Fitness++; // Increase Fitness/Score
}

// Called by walls when hit by the car
public void WallHit()
{
    EvolutionManager.Singleton.CarDead(this, Fitness); // Tell the Evolution Manager that the car is dead
    gameObject.SetActive(false);                       // Make sure the car is inactive
}
Wall
The Wall script simply notifies any car that hits it:
using UnityEngine;

public class Wall : MonoBehaviour
{
    [SerializeField] string LayerHitName = "CarCollider"; // The name of the layer set on each car

    private void OnCollisionEnter(Collision collision) // Once anything hits the wall
    {
        if (collision.gameObject.layer == LayerMask.NameToLayer(LayerHitName)) // Make sure it's a car
        {
            collision.transform.GetComponent<Car>().WallHit(); // If it is a car, tell it that it just hit a wall
        }
    }
}
Checkpoint
The Checkpoint does almost the same thing as the Wall, but with a twist. Checkpoints use a Trigger instead of a Collider, and they also make sure they increase the fitness of each car only once. This is why each Car has a Unique ID. Each Checkpoint simply saves the Guids of all the Cars whose fitness it has already increased:
using System.Collections.Generic;
using UnityEngine;

public class Checkpoint : MonoBehaviour
{
    [SerializeField] string LayerHitName = "CarCollider"; // The name of the layer set on each car

    List<string> AllGuids = new List<string>(); // The Guids of all the cars already rewarded by this checkpoint

    private void OnTriggerEnter(Collider other) // Once anything goes through the checkpoint
    {
        if (other.gameObject.layer == LayerMask.NameToLayer(LayerHitName)) // If this object is a car
        {
            Car CarComponent = other.transform.parent.GetComponent<Car>(); // Get the Car component from the collider's parent
            string CarGuid = CarComponent.TheGuid;                         // Get the unique ID of the car

            if (!AllGuids.Contains(CarGuid)) // If we haven't rewarded this car before
            {
                AllGuids.Add(CarGuid);        // Make sure we don't reward it again
                CarComponent.CheckpointHit(); // Increase the car's fitness
            }
        }
    }
}
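As a side note, the "only once per car" bookkeeping can also be done with a HashSet<string>, whose Add method returns false when the value is already present. The sketch below is a possible alternative, not what the project uses, and CheckpointLedger and TryScore are hypothetical names.

```csharp
using System.Collections.Generic;

// Alternative bookkeeping for a checkpoint. Names are hypothetical.
public class CheckpointLedger
{
    // The IDs of all the cars that were already rewarded.
    readonly HashSet<string> scoredGuids = new HashSet<string>();

    // Returns true only the first time a given car ID is seen.
    public bool TryScore(string carGuid) => scoredGuids.Add(carGuid);
}
```

Inside OnTriggerEnter, this would collapse the Contains and Add pair into a single call: the car's fitness is increased only when TryScore returns true. With 100 cars per generation, the List version works fine too. The set mainly avoids scanning the whole list on every trigger.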
EvolutionManager
You can't write a script without variables:
// Needs: using System.Collections.Generic; using UnityEngine; using UnityEngine.UI;
public static EvolutionManager Singleton = null; // The current EvolutionManager instance

[SerializeField] int CarCount = 100;        // The number of cars per generation
[SerializeField] GameObject CarPrefab;      // The prefab of the car to be created for each instance
[SerializeField] Text GenerationNumberText; // Some text to write the generation number

int GenerationCount = 0;                // The current generation number
List<Car> Cars = new List<Car>();       // The list of cars currently alive
NeuralNetwork BestNeuralNetwork = null; // The best NeuralNetwork currently available
int BestFitness = -1;                   // The fitness of the best NeuralNetwork ever created
At the start of the program:
// On Start
private void Start()
{
    if (Singleton == null) // If no other instances were created
    {
        Singleton = this;  // Make the only instance this one
    }
    else
    {
        gameObject.SetActive(false); // There is another instance already in place. Make this one inactive
        return;                      // and don't let it start a generation of its own
    }

    BestNeuralNetwork = new NeuralNetwork(Car.NextNetwork); // Set the BestNeuralNetwork to a random new network

    StartGeneration();
}
That's how a new generation is created:
// Starts a whole new generation
void StartGeneration()
{
    GenerationCount++;                                            // Increment the generation count
    GenerationNumberText.text = "Generation: " + GenerationCount; // Update generation text

    for (int i = 0; i < CarCount; i++)
    {
        if (i == 0)
        {
            Car.NextNetwork = BestNeuralNetwork; // Make sure one car uses the best network
        }
        else
        {
            Car.NextNetwork = new NeuralNetwork(BestNeuralNetwork); // Clone the best neural network for the next car
            Car.NextNetwork.Mutate();                               // Mutate it
        }

        // Instantiate a new car and add it to the list of cars
        Cars.Add(Instantiate(CarPrefab, transform.position, Quaternion.identity, transform).GetComponent<Car>());
    }
}
Functions called by the Cars:
// Gets called by cars when they die
public void CarDead(Car DeadCar, int Fitness)
{
    Cars.Remove(DeadCar);        // Remove the car from the list
    Destroy(DeadCar.gameObject); // Destroy the dead car

    if (Fitness > BestFitness) // If it is better than the current best car
    {
        BestNeuralNetwork = DeadCar.TheNetwork; // Make sure it becomes the best car
        BestFitness = Fitness;                  // And also set the best fitness
    }

    if (Cars.Count <= 0)   // If there are no cars left
        StartGeneration(); // Create a new generation
}
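Two consequences of this comparison are worth spelling out: BestFitness is never reset between generations, and the test is a strict "greater than", so a car that merely ties the record does not replace the stored network. The plain C# sketch below isolates that rule. BestTracker and Report are hypothetical names, with a string ID standing in for the NeuralNetwork.

```csharp
// Isolates the "best so far" rule from CarDead. Names are hypothetical.
public class BestTracker
{
    public string BestId { get; private set; } = null; // Stand-in for BestNeuralNetwork
    public int BestFitness { get; private set; } = -1; // Same starting value as in EvolutionManager

    // Mirrors the comparison in CarDead: only a strictly higher fitness wins.
    public void Report(string id, int fitness)
    {
        if (fitness > BestFitness)
        {
            BestId = id;
            BestFitness = fitness;
        }
    }
}
```

Reporting cars "a", "b", "c" and "d" with fitness 3, 5, 5 and 2 leaves "b" as the best: "c" only ties the record, and "d" is worse. Starting at -1 also means that the very first car to die becomes the best, even with a fitness of 0.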
CameraFollow
Just another simple all-in-one script that does the job:
using UnityEngine;

public class CameraFollow : MonoBehaviour
{
    Vector3 SmoothPosVelocity; // Velocity of Position Smoothing
    Vector3 SmoothRotVelocity; // Velocity of Rotation Smoothing

    void FixedUpdate()
    {
        Car BestCar = transform.GetChild(0).GetComponent<Car>(); // Start by assuming the first car is the best

        for (int i = 1; i < transform.childCount; i++) // Loop over all the other cars
        {
            Car CurrentCar = transform.GetChild(i).GetComponent<Car>(); // Get the component of the current car

            if (CurrentCar.Fitness > BestCar.Fitness) // If the current car is better than the best car
            {
                BestCar = CurrentCar; // Then, the best car is the current car
            }
        }

        Transform BestCarCamPos = BestCar.transform.GetChild(0); // The target position of the camera relative to the best car

        // Smoothly set the position
        Camera.main.transform.position = Vector3.SmoothDamp(Camera.main.transform.position,
            BestCarCamPos.position, ref SmoothPosVelocity, 0.7f);

        // Smoothly set the rotation
        Camera.main.transform.rotation = Quaternion.Lerp(Camera.main.transform.rotation,
            Quaternion.LookRotation(BestCar.transform.position - Camera.main.transform.position), 0.1f);
    }
}
Points of Interest
Now that all the scripts have been explained in detail, you can sleep well knowing that the NeuralNetwork class previously implemented works well and was not a waste of time. It felt really good to watch those cars learn, step by step, how to drive around the track. Also, each car relies only on its own built-in sensors, which means it can drive on tracks it was never trained on. Once I got that done, I felt like my binary children were learning how to drive! I tried as hard as I could to keep this implementation as simple as possible for people who don't want to dig deeply into Unity. And... don't think for a second that we're done here. My current target is to implement 3 Crossover operators to make evolution a bit more efficient and offer the developer more diversity. After that, Backpropagation is the target.
Update on 20th February, 2018
Part 3 is up and running! It shows a substantial improvement over the system discussed in Parts 1 and 2. Tell me what you think!
History
- 11th December 2017: Version 1.0: Main Implementation