// grad_rs/src/main.rs



mod data;
mod network;
mod train;
mod vec2;
mod csv_reader;

use std::{f64, usize};
use clap::Parser;

use crate::data::{Dataset};
use crate::network::{Network};
use crate::train::gradient_descent_training;

/// Command-line options for training a network on a dataset file.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Path to the training data file.
    #[arg(short, long)]
    training_data_path: String,
    /// Gradient-descent learning rate; must be a finite number greater than zero.
    #[arg(short, long, value_parser = parse_learning_rate)]
    learning_rate: f64,
    /// Stopping-loss threshold (percentage) handed to the training loop; must be finite.
    #[arg(short, long, value_parser = parse_finite_f64)]
    stopping_loss_percentage: f64,
    /// Index of the column holding the class labels.
    /// If omitted, no column index is passed to the dataset loader.
    #[arg(short, long)]
    class_column_index: Option<usize>,
}

/// Parses a finite `f64`, rejecting `NaN` and the infinities.
///
/// Used as a clap value parser. A non-finite training threshold can only
/// produce nonsensical comparisons in the training loop.
fn parse_finite_f64(raw: &str) -> Result<f64, String> {
    let value: f64 = raw
        .parse()
        .map_err(|e| format!("`{raw}` is not a number: {e}"))?;
    if value.is_finite() {
        Ok(value)
    } else {
        Err(format!("`{raw}` must be a finite number"))
    }
}

/// Parses a learning rate: a finite `f64` strictly greater than zero.
///
/// A zero rate never moves the weights. A negative rate steps against the
/// descent direction.
fn parse_learning_rate(raw: &str) -> Result<f64, String> {
    let value = parse_finite_f64(raw)?;
    if value > 0.0 {
        Ok(value)
    } else {
        Err(format!("learning rate must be greater than zero, got {value}"))
    }
}

/// Loads the dataset, builds a randomly initialised network and trains it.
///
/// The pipeline is: load, shuffle, rescale to `(-1.0, 1.0)`, split into
/// training and test sets, then run gradient-descent training.
fn main() {
    let args = Args::parse();

    let data = Dataset::auto_load_dataset(
        args.training_data_path.as_str(),
        args.class_column_index,
    );
    let scaled = data.shuffle().scale_to((-1.0, 1.0));

    // NOTE(review): `75` presumably is the percentage of rows assigned to the
    // training set — confirm against `Dataset::divide`.
    let (training, test) = scaled.divide(75);

    let inputs = training.inputs.columns;
    let outputs = training.labels.columns;
    // Layer widths, in order:
    // - input width;
    // - a layer as wide as the input;
    // - a layer averaging the input and output widths;
    // - the output (label) width.
    let net = Network::construct_random(&[inputs, inputs, (inputs + outputs) / 2, outputs]);

    // The trained network is not used yet. Binding it to `_trained` keeps it
    // alive until the end of `main` without an unused-assignment warning.
    let _trained = gradient_descent_training(
        net,
        &training,
        &test,
        args.learning_rate,
        args.stopping_loss_percentage,
    );
}