home / grad_rs / src / network.rs

network.rs



use rand::{Rng, thread_rng};

/// A feed-forward network: an ordered stack of fully connected layers.
#[derive(Debug)]
pub struct Network {
    /// Layers in forward order; layer `i + 1` consumes the outputs of layer `i`.
    layers: Vec<Layer>,
}

/// Per-layer scratch buffers for one training step (forward pass plus backward pass).
///
/// Every field holds one vector per layer, indexed by layer position `i`:
/// - `outs[i]`: activations produced by layer `i`;
/// - `ins[i]`: inputs fed into layer `i`;
/// - `out_grads[i]`: gradient with respect to `outs[i]`;
/// - `in_grads[i]`: gradient with respect to `ins[i]`.
pub struct LayerState {
    pub outs: Vec<Vec<f64>>,
    pub ins: Vec<Vec<f64>>,
    pub out_grads: Vec<Vec<f64>>,
    pub in_grads: Vec<Vec<f64>>,
}

impl LayerState {
    /// Returns the activations of the final layer.
    ///
    /// # Panics
    /// Panics if the state holds no layers.
    pub fn get_outs(&self) -> &[f64] {
        self.outs.last().expect("LayerState holds no layers")
    }

    /// Copies `ins` into the first layer's input buffer.
    ///
    /// # Panics
    /// Panics if the state holds no layers, or if `ins.len()` differs from the length of
    /// that buffer.
    pub fn set_ins(&mut self, ins: &[f64]) {
        self.ins[0].copy_from_slice(ins);
    }

    /// Seeds the backward pass: for every index `k`, writes
    /// `derivative_function(labels[k], final_outs[k])` into the final layer's
    /// output-gradient buffer.
    ///
    /// # Panics
    /// Panics if the state holds no layers, or if `labels`, the final layer's activations
    /// and the final layer's output-gradient buffer do not all have the same length.
    pub fn set_grads(
        &mut self,
        labels: &[f64],
        derivative_function: fn(label: f64, pred: f64) -> f64,
    ) {
        let preds = self.outs.last().expect("LayerState holds no layers");
        let grads = self.out_grads.last_mut().expect("LayerState holds no layers");
        // Reject mismatches up front: silently truncating would leave gradients from an
        // earlier sample in the slots that were not overwritten.
        assert_eq!(labels.len(), preds.len(), "one label is required per network output");
        assert_eq!(grads.len(), preds.len(), "output gradient buffer does not match output count");
        for (grad, (&label, &pred)) in grads.iter_mut().zip(labels.iter().zip(preds.iter())) {
            *grad = derivative_function(label, pred);
        }
    }
}

impl Network {
    /// Builds a fully connected network with randomly initialised weights.
    ///
    /// `layer_def` lists every width starting with the input width: `&[3, 4, 2]` yields a
    /// layer of 4 neurons taking 3 inputs, followed by a layer of 2 neurons taking 4 inputs.
    /// Fewer than two entries yield a network with no layers.
    ///
    /// # Panics
    /// Panics if any entry of `layer_def` is zero.
    pub fn construct_random(layer_def: &[usize]) -> Network {
        // Validate every width before building anything.
        for &width in layer_def {
            assert_ne!(width, 0);
        }
        let layers = layer_def
            .windows(2)
            .map(|pair| Layer::construct(pair[0], pair[1]))
            .collect();
        Network { layers }
    }

    /// Allocates zeroed scratch buffers shaped for this network.
    ///
    /// For layer `i` with `n` neurons of `w` weights each, `outs[i]` and `out_grads[i]` have
    /// length `n`, while `ins[i]` and `in_grads[i]` have length `w`.
    pub fn give_layer_state(&self) -> LayerState {
        let layer_count = self.layers.len();
        let mut outs: Vec<Vec<f64>> = Vec::with_capacity(layer_count);
        let mut ins: Vec<Vec<f64>> = Vec::with_capacity(layer_count);
        let mut out_grads: Vec<Vec<f64>> = Vec::with_capacity(layer_count);
        let mut in_grads: Vec<Vec<f64>> = Vec::with_capacity(layer_count);
        for layer in &self.layers {
            let output_width = layer.neurons.len();
            let input_width = layer.neurons[0].weights.len();
            outs.push(vec![0.0; output_width]);
            ins.push(vec![0.0; input_width]);
            out_grads.push(vec![0.0; output_width]);
            in_grads.push(vec![0.0; input_width]);
        }
        LayerState { outs, ins, out_grads, in_grads }
    }

    /// Runs a forward pass on `input` and returns the final layer's activations.
    ///
    /// A network with no layers returns a copy of `input`.
    pub fn output(&self, input: &[f64]) -> Vec<f64> {
        let mut activations = input.to_vec();
        for layer in &self.layers {
            activations = layer.output(&activations).collect();
        }
        activations
    }

    /// Forward pass that records intermediate values in `state`.
    ///
    /// Reads the network input from `state.ins[0]` (see `LayerState::set_ins`). Layer `i`'s
    /// activations are written to `state.outs[i]` and, except for the final layer, copied
    /// into `state.ins[i + 1]` as the next layer's input.
    pub fn output_training(&self, state: &mut LayerState) {
        let layer_count = self.layers.len();
        for (ix, layer) in self.layers.iter().enumerate() {
            for (oix, value) in layer.output(&state.ins[ix]).enumerate() {
                state.outs[ix][oix] = value;
            }
            if ix + 1 < layer_count {
                state.ins[ix + 1].copy_from_slice(&state.outs[ix]);
            }
        }
    }

    /// Back-propagates the gradients seeded in the final layer's `out_grads` through every
    /// layer, adding to each neuron's weight and bias gradients.
    ///
    /// Expects `layer_state.ins`/`outs` to hold the values recorded by `output_training`.
    /// A network with no layers is left untouched.
    pub fn backward(&mut self, layer_state: &mut LayerState) {
        if self.layers.is_empty() {
            return;
        }
        for ix in (1..self.layers.len()).rev() {
            self.layers[ix].backward(
                &layer_state.out_grads[ix],
                &layer_state.ins[ix],
                &layer_state.outs[ix],
                &mut layer_state.in_grads[ix],
            );
            // Layer ix's inputs are layer ix-1's outputs, so their gradients coincide.
            layer_state.out_grads[ix - 1].copy_from_slice(&layer_state.in_grads[ix]);
        }
        // The first layer uses `backward_last`, which does not compute input gradients.
        self.layers[0].backward_last(
            &layer_state.out_grads[0],
            &layer_state.ins[0],
            &layer_state.outs[0],
        );
    }

    /// Applies one gradient-descent step with `learning_rate` to every neuron.
    pub fn adjust_weights(&mut self, learning_rate: f64) {
        for layer in &mut self.layers {
            for neuron in &mut layer.neurons {
                neuron.adjust_weights(learning_rate);
            }
        }
    }

    /// Resets the accumulated gradients of every neuron to zero.
    pub fn zero_grad(&mut self) {
        for layer in &mut self.layers {
            for neuron in &mut layer.neurons {
                neuron.zero_grad();
            }
        }
    }
}

/// A fully connected layer: every neuron receives the whole input vector.
#[derive(Debug)]
struct Layer {
    neurons: Vec<Neuron>,
}

impl Layer {
    /// Builds `layer_size` neurons, each taking `input_count` inputs with random weights.
    fn construct(input_count: usize, layer_size: usize) -> Layer {
        let neurons = (0..layer_size)
            .map(|_| Neuron::with_random_weights(input_count))
            .collect();
        Layer { neurons }
    }

    /// Lazily yields every neuron's activation for `input`, in neuron order.
    fn output<'a>(&'a self, input: &'a [f64]) -> impl Iterator<Item = f64> + 'a {
        // `move` is required: without it the closure borrows the local `input` binding,
        // which does not live as long as the returned iterator (E0373).
        self.neurons.iter().map(move |neuron| neuron.output(input))
    }

    /// Back-propagates through this layer; used for every layer except the first.
    ///
    /// `chained_grads[i]` is the gradient with respect to neuron `i`'s output, and
    /// `inputs`/`outputs` are the values recorded during the forward pass. Each neuron adds
    /// to its own weight and bias gradients. `input_chained_grads` is overwritten with the
    /// gradient with respect to each input, summed over all neurons.
    ///
    /// # Panics
    /// Panics if the slice lengths do not match the layer's shape.
    fn backward(
        &mut self,
        chained_grads: &[f64],
        inputs: &[f64],
        outputs: &[f64],
        input_chained_grads: &mut [f64],
    ) {
        assert_eq!(chained_grads.len(), self.neurons.len());
        assert_eq!(outputs.len(), self.neurons.len());
        assert_eq!(input_chained_grads.len(), self.neurons[0].weights.len());
        input_chained_grads.fill(0.0);
        for ((neuron, &grad), &out) in self.neurons.iter_mut().zip(chained_grads).zip(outputs) {
            let chained = neuron.backward(grad, inputs, out);
            for (acc, g) in input_chained_grads.iter_mut().zip(neuron.input_grads(chained)) {
                *acc += g;
            }
        }
    }

    /// Back-propagates through the first layer.
    ///
    /// This variant only accumulates weight and bias gradients in the neurons and does not
    /// compute gradients with respect to the inputs.
    ///
    /// # Panics
    /// Panics if `chained_grads` or `outputs` do not have one entry per neuron.
    fn backward_last(
        &mut self,
        chained_grads: &[f64],
        inputs: &[f64],
        outputs: &[f64],
    ) {
        assert_eq!(chained_grads.len(), self.neurons.len());
        assert_eq!(outputs.len(), self.neurons.len());
        for ((neuron, &grad), &out) in self.neurons.iter_mut().zip(chained_grads).zip(outputs) {
            neuron.backward(grad, inputs, out);
        }
    }
}

/// A single tanh unit computing `tanh(weights · input + bias)`, together with gradient
/// accumulators for its parameters.
#[derive(Debug)]
struct Neuron {
    /// One weight per input.
    weights: Vec<f64>,
    /// Accumulated gradient for each entry of `weights` (same length).
    weights_grad: Vec<f64>,
    bias: f64,
    /// Accumulated gradient for `bias`.
    bias_grad: f64,
}

impl Neuron {
    /// Creates a neuron with `weight_count` weights, each drawn with
    /// `thread_rng().gen_range(-1.0..1.0)`, and a zero bias.
    pub fn with_random_weights(weight_count: usize) -> Neuron {
        let mut rng = thread_rng();
        let mut weights: Vec<f64> = Vec::with_capacity(weight_count);
        for _ in 0..weight_count {
            weights.push(rng.gen_range(-1.0..1.0));
        }
        Neuron::with_defined_weights(weights)
    }

    /// Creates a neuron with the given weights, a zero bias and zeroed gradients.
    pub fn with_defined_weights(weights: Vec<f64>) -> Neuron {
        Neuron {
            weights_grad: vec![0.0; weights.len()],
            weights,
            bias: 0.0,
            bias_grad: 0.0,
        }
    }

    /// Dot product of `input` with the weights; the bias is not included.
    ///
    /// # Panics
    /// Panics if `input.len()` differs from the number of weights.
    fn weighted_sum(&self, input: &[f64]) -> f64 {
        assert_eq!(input.len(), self.weights.len());
        input.iter().zip(&self.weights).map(|(x, w)| x * w).sum()
    }

    /// The activation function: hyperbolic tangent.
    fn activation(&self, sum: f64) -> f64 {
        f64::tanh(sum)
    }

    /// Derivative of tanh expressed through its output `y`: `1 - y^2`.
    fn activation_derivative(&self, output: f64) -> f64 {
        let squared = output.powf(2.0);
        1.0 - squared
    }

    /// Forward pass: `tanh(weights · input + bias)`.
    fn output(&self, input: &[f64]) -> f64 {
        self.activation(self.weighted_sum(input) + self.bias)
    }

    /// Adds this sample's contribution to the weight and bias gradients.
    ///
    /// `chained_grad` is the gradient with respect to this neuron's output, and `output` is
    /// the activation produced for `input`. Returns
    /// `activation_derivative(output) * chained_grad`, the gradient with respect to the
    /// pre-activation sum.
    ///
    /// # Panics
    /// Panics if `input.len()` differs from the number of weights.
    fn backward(
        &mut self,
        chained_grad: f64,
        input: &[f64],
        output: f64,
    ) -> f64 {
        assert_eq!(input.len(), self.weights.len());
        let local_grad = self.activation_derivative(output) * chained_grad;
        self.bias_grad += local_grad;
        for (slot, x) in self.weights_grad.iter_mut().zip(input) {
            *slot += x * local_grad;
        }
        local_grad
    }

    /// Lazily yields `weight * chained_grad` for each weight, in weight order.
    fn input_grads<'a>(&'a self, chained_grad: f64) -> impl Iterator<Item = f64> + 'a {
        self.weights.iter().map(move |&w| w * chained_grad)
    }

    /// Gradient-descent step: subtracts `learning_rate * grad` from the bias and from every
    /// weight.
    fn adjust_weights(&mut self, learning_rate: f64) {
        self.bias -= learning_rate * self.bias_grad;
        for (w, g) in self.weights.iter_mut().zip(&self.weights_grad) {
            *w -= learning_rate * g;
        }
    }

    /// Resets the bias and weight gradient accumulators to zero.
    fn zero_grad(&mut self) {
        self.bias_grad = 0.0;
        self.weights_grad.fill(0.0);
    }
}