//! Gradient-descent training loop and loss helpers for `grad_rs` networks.



use std::time::Instant;
use crate::data::Dataset;
use crate::network::Network;
use crate::vec2::Vec2;

/// Runs full-batch gradient descent on `net` until the training
/// misclassification percentage drops below `stopping_loss_score_percentage`.
///
/// Each pass over `training_dataset` accumulates gradients sample by sample,
/// then applies a single weight update (`adjust_weights` + `zero_grad`).
/// Progress (losses, misclassification counts, elapsed time) is printed every
/// 100 epochs, which is also the only point where the stop condition is checked.
///
/// Returns the trained network by value.
pub fn gradient_descent_training(mut net: Network,
                                 training_dataset: &Dataset,
                                 validation_dataset: &Dataset,
                                 learning_rate: f64,
                                 stopping_loss_score_percentage: f64) -> Network {
    // Epoch counter; drives the every-100-epochs report and stop check.
    let mut i = 0;
    let mut now = Instant::now();

    // Reusable forward/backward scratch state.
    // NOTE(review): assumes give_layer_state() sizes buffers for `net` — confirm in network.rs.
    let mut layer_state = net.give_layer_state();
    // Per-sample network outputs from the latest epoch, shaped like the training labels.
    let mut predictions = Vec2 {
        flat: vec![0.0; training_dataset.labels.flat.len()],
        columns: training_dataset.labels.columns
    };

    loop {
        for ix in 0..training_dataset.inputs.len() {
            layer_state.set_ins(&training_dataset.inputs[ix]);
            net.output_training(&mut layer_state);


            // Seed output-layer gradients with d(squared error)/d(prediction).
            layer_state
                .set_grads(&training_dataset.labels[ix], squared_error_derivatives);
            // Snapshot this sample's outputs so the epoch loss can be computed later.
            layer_state.get_outs().iter().enumerate()
                .for_each(|(oix, &o)| predictions[ix][oix] = o);

            // Accumulate gradients; weights are only stepped after the full pass.
            net.backward( &mut layer_state)
        }

        // Full-batch update: one weight step per epoch, then clear accumulators.
        net.adjust_weights(learning_rate);
        net.zero_grad();

        if i % 100 == 0 {
            // `predictions` was filled before this epoch's weight update, so the
            // training numbers below describe the pre-update weights.
            let training_loss = get_loss(&training_dataset.labels, &predictions);
            let (training_loss_count, training_loss_percentage) = get_loss_count_and_percentage(&training_dataset.labels, &predictions);


            // Validation outputs are recomputed with the freshly updated weights.
            // (The loop's `i` shadows the epoch counter here.)
            let mut validation_predictions = Vec2::create(validation_dataset.labels.flat.len());
            for i in validation_dataset.inputs.iter() { validation_predictions.push(&net.output(i)) }
            let validation_loss = get_loss(&validation_dataset.labels, &validation_predictions);

            let (validation_loss_count, validation_loss_percentage) = get_loss_count_and_percentage(&validation_dataset.labels, &validation_predictions);

            let elapsed = now.elapsed();
            println!("{}: training loss = {:.6?} - {}/{} - {:.3?}%, validation loss = {:.6?} - {}/{} - {:.3?}%, duration = {:.2?}",
                     i,
                     training_loss,
                     training_loss_count,
                     training_dataset.labels.len(),
                     training_loss_percentage,
                     validation_loss,
                     validation_loss_count,
                     validation_dataset.labels.len(),
                     validation_loss_percentage,
                     elapsed
            );
            // Restart the timer so `elapsed` covers the next reporting window.
            now = Instant::now();
            // Stop once the training misclassification rate is low enough.
            if training_loss_percentage.abs() < stopping_loss_score_percentage {
                break;
            }
        }
        i += 1;
    }
    return net
}

/// Squared error between a target value and a prediction: `(actual - prediction)^2`.
fn squared_error(actual: f64, prediction: f64) -> f64 {
    // powi(2) is the integer-power form: exact self-multiplication and cheaper
    // than the general powf(2.0) path; also drops the non-idiomatic `return`.
    (actual - prediction).powi(2)
}

/// Derivative of the squared error with respect to the prediction:
/// `d/dp (a - p)^2 = -2(a - p) = 2(p - a)`.
fn squared_error_derivatives(actual: f64, prediction: f64) -> f64 {
    2.0 * (prediction - actual)
}

/// Total (summed) squared error over every scalar in `labels` vs `predictions`.
fn get_loss(labels: &Vec2<f64>, predictions: &Vec2<f64>) -> f64 {
    let paired = labels.flat.iter().zip(predictions.flat.iter());
    paired.fold(0.0, |acc, (&actual, &pred)| acc + squared_error(actual, pred))
}

/// Counts the misclassified rows and also expresses that count as a
/// percentage of all rows in `labels`.
fn get_loss_count_and_percentage(labels: &Vec2<f64>, predictions: &Vec2<f64>) -> (usize, f64) {
    let mut loss_count = 0usize;
    for (label_row, pred_row) in labels.iter().zip(predictions.iter()) {
        if !is_correct_label(label_row, pred_row) {
            loss_count += 1;
        }
    }
    // Same expression as `count / len * 100`, kept in this exact form so the
    // floating-point result matches the previous implementation bit-for-bit.
    let percentage = loss_count as f64 / ((labels.len() as f64) / 100.0);
    (loss_count, percentage)
}

/// Returns `true` when every component of `pred`, thresholded at 0.0 to
/// `1.0`/`-1.0`, matches the corresponding component of `label`.
///
/// # Panics
/// Panics if `label` and `pred` differ in length.
pub fn is_correct_label(label: &[f64], pred: &[f64]) -> bool {
    assert_eq!(label.len(), pred.len());
    // Bug fix: the previous loop overwrote `correct` on every iteration, so
    // only the LAST component decided the result. All components must match.
    label.iter().zip(pred.iter()).all(|(&l, &p)| {
        let rounded = if p >= 0.0 { 1.0 } else { -1.0 };
        rounded == l
    })
}