// home / grad_rs / src / main.rs
mod data;
mod network;
mod train;
mod vec2;
mod csv_reader;
use std::{f64, usize};
use clap::Parser;
use crate::data::{Dataset};
use crate::network::{Network};
use crate::train::gradient_descent_training;
// Command-line options for the trainer.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Path to the training data file.
    #[arg(short, long)]
    training_data_path: String,

    /// Gradient-descent learning rate; must be a positive, finite number.
    #[arg(short, long, value_parser = parse_learning_rate)]
    learning_rate: f64,

    /// Loss value at which training stops; must be a finite, non-negative number.
    #[arg(short, long, value_parser = parse_stopping_loss)]
    stopping_loss_percentage: f64,

    /// Index of the column holding the class labels. When omitted, no index is
    /// passed to the dataset loader.
    // `Option` replaces the former `-1` sentinel, which was also awkward to pass on
    // the command line because clap parses a bare `-1` as a flag by default.
    #[arg(short, long)]
    class_column_index: Option<usize>,
}

/// Parses `raw` as a finite `f64`, naming the option (`what`) in any error message.
fn parse_finite_f64(raw: &str, what: &str) -> Result<f64, String> {
    let value: f64 = raw
        .parse()
        .map_err(|err| format!("{what} `{raw}` is not a number: {err}"))?;
    if value.is_finite() {
        Ok(value)
    } else {
        Err(format!("{what} must be finite, got {value}"))
    }
}

/// Clap value parser for `--learning-rate`: accepts only positive, finite values.
/// A zero step makes no progress, and a negative one moves against descent
/// (assuming the usual `weights -= rate * gradient` update in `train` — confirm).
fn parse_learning_rate(raw: &str) -> Result<f64, String> {
    let rate = parse_finite_f64(raw, "learning rate")?;
    if rate > 0.0 {
        Ok(rate)
    } else {
        Err(format!("learning rate must be positive, got {rate}"))
    }
}

/// Clap value parser for `--stopping-loss-percentage`: accepts only finite,
/// non-negative values.
fn parse_stopping_loss(raw: &str) -> Result<f64, String> {
    let loss = parse_finite_f64(raw, "stopping loss")?;
    if loss >= 0.0 {
        Ok(loss)
    } else {
        Err(format!("stopping loss must be non-negative, got {loss}"))
    }
}

/// Entry point.
///
/// Loads the dataset named on the command line, shuffles it, and scales it to
/// the range `(-1.0, 1.0)`. It then splits the data into training and test sets,
/// builds a randomly initialised network sized from the data, and trains it with
/// `gradient_descent_training`.
fn main() {
    let args = Args::parse();

    let data =
        Dataset::auto_load_dataset(args.training_data_path.as_str(), args.class_column_index);
    let scaled = data.shuffle().scale_to((-1.0, 1.0));
    // Presumably 75 is the percentage of rows kept for training, with the rest
    // forming the test set. Confirm against `Dataset::divide`.
    let (training, test) = scaled.divide(75);

    // Presumably these are layer widths: the input layer, a hidden layer as wide
    // as the input, a hidden layer halfway between input and output width, and
    // the output layer.
    let input_width = training.inputs.columns;
    let output_width = training.labels.columns;
    let net = Network::construct_random(&[
        input_width,
        input_width,
        (input_width + output_width) / 2,
        output_width,
    ]);

    // Nothing below uses the trained network yet. Binding it explicitly avoids
    // overwriting `net` with a value that is never read.
    let _trained = gradient_descent_training(
        net,
        &training,
        &test,
        args.learning_rate,
        args.stopping_loss_percentage,
    );
}