// grad_rs/src/train.rs
use std::time::Instant;
use crate::data::Dataset;
use crate::network::Network;
use crate::vec2::Vec2;
/// Trains `net` with gradient descent over the whole of `training_dataset` per pass.
///
/// Each pass does the following for every training sample, in order:
/// - feeds the sample forward with `output_training`;
/// - seeds the output gradients via `set_grads` with [`squared_error_derivatives`];
/// - copies the network outputs into a per-sample prediction buffer;
/// - calls `backward`.
///
/// After the pass, `adjust_weights(learning_rate)` is called once and then `zero_grad`.
/// This presumably applies gradients accumulated over the whole pass; confirm against
/// `Network::backward`.
///
/// Every 100th pass (starting with pass 0) it prints:
/// - training loss and misclassification figures;
/// - validation loss and misclassification figures;
/// - the elapsed time since the previous report.
///
/// Training stops at such a report once the training misclassification percentage is
/// below `stopping_loss_score_percentage`. There is no other exit: if that threshold is
/// never reached, the loop does not terminate.
///
/// Returns the trained network.
pub fn gradient_descent_training(
    mut net: Network,
    training_dataset: &Dataset,
    validation_dataset: &Dataset,
    learning_rate: f64,
    stopping_loss_score_percentage: f64,
) -> Network {
    let mut epoch = 0;
    let mut report_timer = Instant::now();
    let mut layer_state = net.give_layer_state();
    // Same shape as the training labels; every row is overwritten on each pass.
    let mut predictions = Vec2 {
        flat: vec![0.0; training_dataset.labels.flat.len()],
        columns: training_dataset.labels.columns,
    };

    loop {
        for sample in 0..training_dataset.inputs.len() {
            layer_state.set_ins(&training_dataset.inputs[sample]);
            net.output_training(&mut layer_state);
            layer_state.set_grads(&training_dataset.labels[sample], squared_error_derivatives);
            for (column, &value) in layer_state.get_outs().iter().enumerate() {
                predictions[sample][column] = value;
            }
            net.backward(&mut layer_state);
        }
        net.adjust_weights(learning_rate);
        net.zero_grad();

        if epoch % 100 == 0 {
            // `predictions` was filled before `adjust_weights`, so the training figures
            // reflect the weights used during this pass; validation uses the updated ones.
            let training_loss = get_loss(&training_dataset.labels, &predictions);
            let (training_loss_count, training_loss_percentage) =
                get_loss_count_and_percentage(&training_dataset.labels, &predictions);

            let mut validation_predictions = Vec2::create(validation_dataset.labels.flat.len());
            for input in validation_dataset.inputs.iter() {
                validation_predictions.push(&net.output(input));
            }
            let validation_loss = get_loss(&validation_dataset.labels, &validation_predictions);
            let (validation_loss_count, validation_loss_percentage) =
                get_loss_count_and_percentage(&validation_dataset.labels, &validation_predictions);

            let elapsed = report_timer.elapsed();
            println!("{}: training loss = {:.6?} - {}/{} - {:.3?}%, validation loss = {:.6?} - {}/{} - {:.3?}%, duration = {:.2?}",
                epoch,
                training_loss,
                training_loss_count,
                training_dataset.labels.len(),
                training_loss_percentage,
                validation_loss,
                validation_loss_count,
                validation_dataset.labels.len(),
                validation_loss_percentage,
                elapsed
            );
            report_timer = Instant::now();

            if training_loss_percentage.abs() < stopping_loss_score_percentage {
                break;
            }
        }
        epoch += 1;
    }

    net
}
/// Squared difference between a target value and a prediction: `(actual - prediction)^2`.
fn squared_error(actual: f64, prediction: f64) -> f64 {
    let residual = actual - prediction;
    residual.powf(2.0)
}
/// Derivative of [`squared_error`] with respect to `prediction`: `-2 * (actual - prediction)`.
///
/// Used as the output-gradient function handed to `set_grads` during training.
fn squared_error_derivatives(actual: f64, prediction: f64) -> f64 {
    let residual = actual - prediction;
    -2.0 * residual
}
/// Total (summed, not averaged) squared error between `labels` and `predictions`.
///
/// Compares the flat buffers element by element. If the buffers differ in length, the
/// extra trailing elements of the longer one are ignored.
fn get_loss(labels: &Vec2<f64>, predictions: &Vec2<f64>) -> f64 {
    let pairs = labels.flat.iter().zip(&predictions.flat);
    pairs
        .map(|(&actual, &predicted)| squared_error(actual, predicted))
        .sum()
}
/// Counts the rows of `predictions` that [`is_correct_label`] rejects against the matching
/// rows of `labels`, and expresses that count as a percentage of `labels.len()`.
///
/// Returns `(misclassified_count, misclassified_percentage)`.
///
/// An empty `labels` yields `(0, 0.0)` instead of the `NaN` that `0 / 0` would produce.
/// A NaN percentage never compares below a stopping threshold, so a training loop gated
/// on it would never terminate.
fn get_loss_count_and_percentage(labels: &Vec2<f64>, predictions: &Vec2<f64>) -> (usize, f64) {
    let total = labels.len();
    if total == 0 {
        return (0, 0.0);
    }
    let loss_count = labels
        .iter()
        .zip(predictions.iter())
        .filter(|(l, p)| !is_correct_label(l, p))
        .count();
    // Multiply before dividing: `count / (total / 100)` rounds `total / 100` first,
    // which adds an unnecessary rounding error.
    let percentage = loss_count as f64 * 100.0 / total as f64;
    (loss_count, percentage)
}
/// Returns `true` when every element of `pred`, thresholded at zero, equals the
/// corresponding element of `label`.
///
/// Each prediction becomes `1.0` if it is `>= 0.0` and `-1.0` otherwise (NaN included).
/// It is then compared for exact equality with the label. As a result:
/// - a label value other than exactly `1.0` or `-1.0` never counts as correct;
/// - empty slices are vacuously correct.
///
/// # Panics
///
/// Panics if `label` and `pred` have different lengths.
pub fn is_correct_label(label: &[f64], pred: &[f64]) -> bool {
    assert_eq!(label.len(), pred.len());
    // Every output must match its label, not only the last one.
    label.iter().zip(pred).all(|(&expected, &raw)| {
        let rounded_val = if raw >= 0.0 { 1.0 } else { -1.0 };
        rounded_val == expected
    })
}