#!/usr/bin/python

# NeuralNetwork.py
# Copyright (C) Tobias Hermann 2012 <daiw@gmx.net>
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.


import random
import turtle
import time
import colorsys
import itertools
import math


def calc_color(i):
    """Map a counter to a fully saturated, fully bright rainbow color.

    The hue advances by 1/16 per step of i and wraps around after 16
    steps.  The result is the (r, g, b) tuple produced by
    colorsys.hsv_to_rgb, with components in the range 0..1.
    """
    hue = (i / 16) % 1
    saturation = value = 1
    return colorsys.hsv_to_rgb(hue, saturation, value)


def generate_neuron_index():
    """Yield 0, 1, 2, ... without end; used to number neurons."""
    for index in itertools.count():
        yield index


def generate_connection_index():
    """Yield 0, 1, 2, ... without end; used to number connections."""
    for index in itertools.count():
        yield index


# Module-wide index sources.  Every Neuron and every Connection draws its
# index from these shared generators when it is constructed, so indexes
# keep counting up across all networks created in this process.
_neuron_index_generator = generate_neuron_index()
_connection_index_generator = generate_connection_index()


def sigmoid(x):
    """Return the activation for a summed neuron input x.

    The hyperbolic tangent is used.  The logistic function
    1/(1+exp(-x)) was the earlier alternative; switching back would
    also require switching sigmoid_derivative.
    """
    activated = math.tanh(x)
    return activated


def sigmoid_derivative(x):
    """Return the derivative of the activation function.

    x is the already activated value y = tanh(z), not the raw input z:
    d/dz tanh(z) = 1 - tanh(z)**2 = 1 - y**2.  (For the logistic
    function the matching expression would be x*(1-x).)
    """
    squared = x ** 2
    return 1 - squared


def wait_for_input():
    """Block until the user submits a line on standard input.

    Whatever was typed is discarded; the call only serves as a pause.
    """
    input('Please enter anything to quit.')


class Neuron:
    """A single node of the network.

    Attributes:
    index           -- number drawn from the module-wide neuron counter
    connectionsFrom -- set of incoming Connection objects
    connectionsTo   -- set of outgoing Connection objects
    input           -- weighted sum collected from upstream neurons
    activation      -- sigmoid(input) + bias, set by Propagate()
    error, delta    -- working values for back propagation
    x, y            -- drawing position
    bias            -- constant added to the activation
    """

    def __init__(self, bias = 0):
        """Create an unconnected neuron with the given bias."""
        self.index = next(_neuron_index_generator)
        self.connectionsFrom = set()
        self.connectionsTo = set()
        self.activation = 0
        self.input = 0
        self.error = 0
        self.delta = 0
        self.x = 0
        self.y = 0
        self.bias = bias

    def __str__(self):
        """Return a CSV-like line: a header field followed by the values."""
        fields = (self.index, self.activation, self.input, self.bias)
        return 'Neuron(index,activation,input,bias),%i,%-.4f,%-.4f,%-.4f' % fields

    def Propagate(self):
        """Compute the activation and feed it forward.

        Every outgoing connection adds weight * activation to the input
        of its destination neuron.
        """
        signal = sigmoid(self.input) + self.bias
        self.activation = signal
        for outgoing in self.connectionsTo:
            outgoing.dest.input += outgoing.weight * signal

    def back_propagate(self):
        """Compute delta from the collected error and pass it backwards.

        Every incoming connection adds weight * delta to the error of
        its source neuron.
        """
        gradient = sigmoid_derivative(self.activation) * self.error
        self.delta = gradient
        for incoming in self.connectionsFrom:
            incoming.source.error += incoming.weight * gradient


class Connection:
    """A weighted, directed link from a source neuron to a dest neuron.

    Attributes:
    index  -- number drawn from the module-wide connection counter
    source -- neuron the signal comes from
    dest   -- neuron the signal goes to
    weight -- current weight, initialised uniformly at random in [-1, 1]
    change -- last weight change, kept by the network for its momentum term
    """

    def __init__(self, source, dest):
        """Create a connection from source to dest with a random weight."""
        self.index = next(_connection_index_generator)
        self.source = source
        self.dest = dest
        self.weight = random.uniform(-1, 1)
        self.change = 0

    def __str__(self):
        """Return a CSV-like line: a header field followed by the values."""
        # The trailing comma after the header separates it from the first
        # value, matching the layout of Neuron.__str__.
        return ('Connection(index,source.index,weight,change,dest.index),'
                '%i,%i,%-.4f,%-.4f,%i' % (self.index, self.source.index,
                self.weight, self.change, self.dest.index))

    def __repr__(self):
        """Use the same text as str() so containers print readably."""
        return str(self)

    def __eq__(self, other):
        """Two connections are equal when they link the same two neurons.

        Comparing with anything that is not a Connection yields
        NotImplemented, so `connection == 5` is simply False instead of
        raising AttributeError.
        """
        if not isinstance(other, Connection):
            return NotImplemented
        return self.source == other.source and self.dest == other.dest

    def __hash__(self):
        """Hash by the unique index.

        NOTE(review): two distinct Connection objects between the same
        pair of neurons compare equal but hash differently; the network
        never creates such duplicates itself.
        """
        return self.index


class NeuralNetwork:
    """A layered feed-forward network trained by back propagation.

    Attributes:
    layers       -- list of layers, each a list of Neuron objects; the
                    first layer is the input layer, the last the output
    biasNeuronOn -- extra neuron with bias 1 appended to the input layer
    """

    def __init__(self, layerSizes):
        """Create the neurons and fully connect each layer to the next.

        layerSizes -- number of neurons per layer, input layer first.
        """
        self.biasNeuronOn = Neuron(1)
        self.layers = [[Neuron() for i in range(layerSize)]
                        for layerSize in layerSizes]
        # The bias neuron lives in the input layer, so it is connected to
        # every neuron of the second layer by the loop below.
        self.layers[0].append(self.biasNeuronOn)
        for n, layer in enumerate(self.layers[:-1]):
            self.connect_all(layer, self.layers[n+1])

    def _iter_connections(self):
        """Yield every outgoing connection, layer by layer, neuron by neuron."""
        for layer in self.layers:
            for neuron in layer:
                for connection in neuron.connectionsTo:
                    yield connection

    def get_connection_weights(self):
        """Return the weights of all connections as a flat list."""
        return [connection.weight for connection in self._iter_connections()]

    def _get_connection_changes(self):
        """Return the last weight change of all connections as a flat list.

        The traversal is the same as in get_connection_weights().
        """
        return [connection.change for connection in self._iter_connections()]

    def clear_neurons_for_update(self):
        """Reset activation and input of every neuron behind the first layer.

        The first layer is skipped because its inputs are provided by
        set_input().
        """
        for layer in self.layers[1:]:
            for neuron in layer:
                neuron.activation = neuron.input = 0

    def clear_neurons_for_back_propagation(self):
        """Reset error and delta of every neuron."""
        for layer in reversed(self.layers):
            for neuron in layer:
                neuron.error = neuron.delta = 0

    def connect(self, sourceNeuron, destNeuron):
        """Create a Connection and register it with both of its neurons."""
        connection = Connection(sourceNeuron, destNeuron)
        sourceNeuron.connectionsTo.add(connection)
        destNeuron.connectionsFrom.add(connection)

    def connect_all(self, sourceLayer, destLayer):
        """Connect every neuron of sourceLayer to every neuron of destLayer."""
        for sourceNeuron in sourceLayer:
            for destNeuron in destLayer:
                self.connect(sourceNeuron, destNeuron)

    def draw(self):
        """Draw the network with turtle graphics.

        The picture is also written to NeuralNetwork.eps.  The method
        then waits for user input and closes the turtle window.
        """
        turtle.speed('fastest')
        turtle.hideturtle()
        turtle.bgcolor('black')
        counter = 0
        dotSize = 16
        xDist = 128
        yDist = 32

        # Positions are assigned, not accumulated, so drawing a network
        # more than once does not shift its neurons.
        for i, layer in enumerate(self.layers):
            for j, neuron in enumerate(layer):
                # Neurons with a positive bias are drawn half a column
                # to the right of their layer.
                neuron.x = i * xDist + (xDist/2 if neuron.bias > 0 else 0)
                neuron.y = (-j+len(layer)/2)*yDist

        for layer in self.layers:
            for neuron in layer:
                for connectionTo in neuron.connectionsTo:
                    counter += 1
                    turtle.up()
                    turtle.setpos(neuron.x, neuron.y)
                    turtle.down()
                    turtle.pencolor(calc_color(counter))
                    turtle.setpos(connectionTo.dest.x, connectionTo.dest.y)

        # Dots are drawn after the lines so they end up on top of them.
        for layer in self.layers:
            for neuron in layer:
                counter += 1
                turtle.up()
                turtle.setpos(neuron.x, neuron.y)
                turtle.down()
                turtle.pencolor(calc_color(counter))
                turtle.dot(dotSize)
        turtleScreen = turtle.getscreen()
        turtleScreen.getcanvas().postscript(file="NeuralNetwork.eps")
        wait_for_input()
        turtle.bye()

    def set_input(self, inputVector):
        """Store inputVector in the input attribute of the first layer.

        Values and neurons are paired with zip(), so surplus neurons
        (such as the appended bias neuron) keep their current input.
        """
        for inputValue, inputNeuron in zip(inputVector, self.layers[0]):
            inputNeuron.input = inputValue

    def get_output(self):
        """Return the activations of the last layer as a list."""
        return [outputNeuron.activation for outputNeuron in self.layers[-1]]

    def __str__(self):
        """Return one line per layer header, neuron and outgoing connection."""
        lines = []
        for n, layer in enumerate(self.layers):
            lines.append('Layer ' + str(n))
            for neuron in layer:
                lines.append(str(neuron))
                lines.extend(str(connection)
                             for connection in neuron.connectionsTo)
        return '\n'.join(lines)

    def update(self):
        """Run one forward pass from the first layer to the last."""
        self.clear_neurons_for_update()
        for layer in self.layers:
            for neuron in layer:
                neuron.Propagate()

    def back_propagate(self, goals, N, M):
        """Run one backward pass and adjust all weights.

        goals -- desired activation for each neuron of the last layer
        N     -- factor applied to the current weight change
        M     -- factor applied to the previous weight change (momentum)

        Returns the sum of squared differences between the goals and the
        output activations of the preceding update(); activations are
        not recomputed here.
        """
        self.clear_neurons_for_back_propagation()

        for neuron, goal in zip(self.layers[-1], goals):
            neuron.error = goal - neuron.activation

        # All deltas are computed with the old weights before any weight
        # is modified in the loop further down.
        for layer in reversed(self.layers):
            for neuron in layer:
                neuron.back_propagate()

        for layer in reversed(self.layers):
            for neuron in layer:
                for connection in neuron.connectionsFrom:
                    change = neuron.delta * connection.source.activation
                    connection.weight += N * change + M * connection.change
                    connection.change = change

        return sum((out.activation - goal)**2
                   for out, goal in zip(self.layers[-1], goals))

    def learn(self, trainingSet, maxError=0.07, changeSpeed=0.4,
                changeMomentumFactor=0.15, maxIterations=None):
        """Train on trainingSet until the summed error drops below maxError.

        trainingSet          -- sequence of (inputVector, goals) pairs
        maxError             -- stop once the error summed over one pass
                                through trainingSet is below this value
        changeSpeed          -- passed to back_propagate() as N
        changeMomentumFactor -- passed to back_propagate() as M
        maxIterations        -- optional upper bound on the number of
                                passes; None (the default) means no bound

        Progress is printed as CSV: a header line, then one line per
        pass holding the iteration number, the error, all weights and
        all last weight changes.
        """
        print('Learning...')
        counter = 0
        connectionCount = len(self.get_connection_weights())
        print('iteration,error' +
              ''.join(',w' + str(i) for i in range(connectionCount)) +
              ''.join(',c' + str(i) for i in range(connectionCount)))
        while True:
            error = 0
            for data in trainingSet:
                self.set_input(data[0])
                self.update()
                error += self.back_propagate(data[1], changeSpeed,
                                            changeMomentumFactor)
            counter += 1
            # Emit the change values too, so every row has as many
            # columns as the header announces.
            values = (self.get_connection_weights() +
                      self._get_connection_changes())
            print('%i,%-.4f' % (counter, error) +
                  ''.join(',%-.4f' % value for value in values))
            if error < maxError:
                break
            if maxIterations is not None and counter >= maxIterations:
                break

    def test(self, testSet):
        """Run every sample of testSet through the network and report.

        testSet -- sequence of (inputVector, goals) pairs.  Outputs are
        binarized at 0.5 and compared with the goals; mismatching
        samples are collected and printed at the end.
        """
        print('Testing...')
        wrongs = []
        for data in testSet:
            self.set_input(data[0])
            self.update()
            results = self.get_output()
            binarizedResults = [0 if result < 0.5 else 1
                                for result in results]
            print('input:', data[0])
            print('output:', ['%-.4f' % result for result in results])
            print('binarizedResults:', binarizedResults)
            print('correctResults:', data[1])
            print('---')
            if binarizedResults != data[1]:
                wrongs.append([data[0], data[1], binarizedResults])
        if not wrongs:
            print('All correct.')
        else:
            print('wrongs:', wrongs)


def neural_network_demo():
    """Train and test a small multilayer perceptron on 4-bit patterns.

    For every combination of four bits the network has to learn two
    outputs: whether at most half of the bits are set, and whether at
    least half of the bits are set.
    """
    bitPatterns = list(itertools.product([0, 1], repeat=4))

    targets = []
    for pattern in bitPatterns:
        ones = sum(pattern)
        half = len(pattern) / 2
        targets.append([int(ones <= half), int(ones >= half)])

    samples = list(zip(bitPatterns, targets))

    # One hidden layer whose size is derived from the input and output
    # sizes.
    inputCount = len(bitPatterns[0])
    outputCount = len(targets[0])
    hiddenCount = int(inputCount * outputCount / 2 - 1)

    network = NeuralNetwork([inputCount, hiddenCount, outputCount])

    print(network)
    network.draw()
    network.learn(samples, maxError=0.3)
    print(network)
    network.test(samples)


# Run the demo only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    neural_network_demo()

