|
#!/usr/bin/python
import fann
def print_callback(epochs, error):
    """Print a one-line training progress report.

    Args:
        epochs: number of epochs trained so far (formatted as an integer,
            right-aligned in 8 columns).
        error: current mean squared error (formatted with 10 decimals).

    Returns:
        Always 0. NOTE(review): presumably the "continue training" status a
        FANN-style callback is expected to return -- confirm against the
        bindings if this is ever hooked up to training.
    """
    # No trailing "\n" in the format string: print already ends the line,
    # and the extra newline produced a blank line after every report.
    print("Epochs %8d. Current MSE-Error: %.10f" % (epochs, error))
    return 0
# Network-construction parameters, passed to fann.create below.
connection_rate, learning_rate = 1, 0.7
num_neurons_hidden = 32

# Training parameters, passed to ann.train_on_data below
# (target error, epoch limit, and report interval in epochs).
desired_error = 1e-6
max_iterations, iterations_between_reports = 300, 1
# Create the training data and the network. Everything created from here on
# is released in the `finally` clause below, so the explicit destroy() calls
# still run if loading, training, testing or saving raises.
print("Creating network.")
train_data = fann.read_train_from_file("datasets/mushroom.train")
ann = None
test_data = None
try:
    # Layer-size tuple (inputs, hidden, outputs): the input and output counts
    # are taken from the training file, with num_neurons_hidden in between.
    ann = fann.create(connection_rate, learning_rate,
                      (train_data.get_num_input(),
                       num_neurons_hidden,
                       train_data.get_num_output()))

    # Configure activations and the training algorithm, then train.
    print("Training network")
    ann.set_activation_function_hidden(fann.FANN_SIGMOID_SYMMETRIC_STEPWISE)
    ann.set_activation_function_output(fann.FANN_SIGMOID_STEPWISE)
    ann.set_training_algorithm(fann.FANN_TRAIN_INCREMENTAL)
    # NOTE(review): print_callback is defined in this file but is never
    # passed to training here -- confirm whether it was meant to be used.
    ann.train_on_data(train_data, max_iterations, iterations_between_reports,
                      desired_error)

    # Evaluate on the test file: reset the MSE, run every test sample through
    # ann.test(), then report the MSE the network has recorded.
    print("Testing network")
    test_data = fann.read_train_from_file("datasets/mushroom.test")
    ann.reset_MSE()
    for i in range(test_data.get_num_data()):
        ann.test(test_data.get_input(i), test_data.get_output(i))
    print("MSE error on test data: %f" % ann.get_MSE())

    # Save the trained network to disk.
    print("Saving network")
    ann.save("mushroom_float.net")
finally:
    # Free every object that was actually created (network first, then test
    # data, then training data), whether or not a step above failed.
    print("Cleaning up.")
    for fann_obj in (ann, test_data, train_data):
        if fann_obj is not None:
            fann_obj.destroy()
|
By viewing downloads associated with this article you agree to the Terms of Service and the article's licence.
If a file you wish to view isn't highlighted, and is a text file (not binary), please
let us know and we'll add colourisation support for it.