home / grad_rs / src / data.rs

data.rs



use std::collections::HashSet;
use std::{f64};
use rand::seq::SliceRandom;
use rand::thread_rng;
use crate::csv_reader::CsvReader;
use crate::vec2::Vec2;

/// A supervised-learning dataset: a matrix of input rows paired with a matrix
/// of label rows.
///
/// Row `i` of `inputs` belongs with row `i` of `labels`; `Dataset::shuffle`
/// permutes both matrices with the same permutation to keep them paired.
#[derive(Debug)]
pub struct Dataset {
    /// Input (feature) matrix, one row per sample.
    pub inputs: Vec2<f64>,
    /// Label (target) matrix, one row per sample, aligned with `inputs`.
    pub labels: Vec2<f64>,
}


impl Dataset {
    /// Loads a comma-separated file and turns every column into numeric values.
    ///
    /// Each column is inspected independently:
    /// * if every cell parses as `f64`, the column is used as-is (one value per row);
    /// * otherwise it is treated as categorical and encoded by `convert_to_cat`,
    ///   which produces one or more `1.0`/`-1.0` values per row.
    ///
    /// The column selected by `class_col_index` (the last column when `None`)
    /// becomes the labels. All other columns are concatenated, in column order,
    /// into the inputs.
    ///
    /// # Panics
    /// Panics if the file has no columns, if `class_col_index` is not a valid
    /// column index, or if `convert_to_cat` rejects a categorical column.
    pub fn auto_load_dataset(file_path: &str, class_col_index: Option<usize>) -> Dataset {
        println!("loading dataset {} ..", file_path);
        let mut csv = CsvReader::new(file_path, ',');
        let count = csv.get_col_count();

        // mapped[col][row] holds the encoded value(s) of a single cell.
        let mut mapped: Vec<Vec<Vec<f64>>> = Vec::with_capacity(count);
        for col in 0..count {
            let parsed_values: Result<Vec<f64>, _> =
                csv.vertical_iter(col).map(|s| s.parse::<f64>()).collect();
            if let Ok(real) = parsed_values {
                println!("col {} is real values", col);
                mapped.push(real.into_iter().map(|c| vec![c]).collect());
            } else {
                println!("col {} is categorical", col);
                mapped.push(convert_to_cat(&mut csv, col));
            }
        }

        let mut inputs: Vec2<f64> = Vec2::create(0);
        let mut labels: Vec2<f64> = Vec2::create(0);

        // Without this guard `mapped.len() - 1` below would underflow.
        assert!(!mapped.is_empty(), "dataset {} has no columns", file_path);
        let class_index = class_col_index.unwrap_or(mapped.len() - 1);
        // An out-of-range index would otherwise silently yield empty label rows.
        assert!(
            class_index < mapped.len(),
            "class col {} is out of range: dataset {} has {} columns",
            class_index,
            file_path,
            mapped.len()
        );

        println!("class col = {}", class_index);

        let row_count = mapped[0].len();
        for row in 0..row_count {
            let mut input: Vec<f64> = Vec::new();
            let mut label: Vec<f64> = Vec::new();

            for (col, column) in mapped.iter().enumerate() {
                let target = if col == class_index { &mut label } else { &mut input };
                target.extend_from_slice(&column[row]);
            }
            inputs.push(&input);
            labels.push(&label);
        }

        Dataset { inputs, labels }
    }

    /// Returns a copy of the dataset with its rows in random order.
    ///
    /// Input rows and label rows are shuffled together, so row `i` of the
    /// result's `inputs` still belongs with row `i` of its `labels`.
    pub fn shuffle(&self) -> Dataset {
        let mut joined: Vec<(&[f64], &[f64])> =
            self.inputs.iter().zip(self.labels.iter()).collect();
        joined.shuffle(&mut thread_rng());

        let new_inputs = Vec2 {
            columns: self.inputs.columns,
            flat: joined
                .iter()
                .flat_map(|(input, _)| input.iter().copied())
                .collect(),
        };
        let new_labels = Vec2 {
            columns: self.labels.columns,
            flat: joined
                .iter()
                .flat_map(|(_, label)| label.iter().copied())
                .collect(),
        };

        Dataset {
            inputs: new_inputs,
            labels: new_labels,
        }
    }

    /// Splits the dataset in two by calling `Vec2::divide` with the same
    /// `first_percentage` on both `inputs` and `labels`.
    ///
    /// NOTE(review): presumably `first_percentage` is the share (0-100) of rows
    /// placed in the first part — confirm against `Vec2::divide`.
    pub fn divide(&self, first_percentage: usize) -> (Dataset, Dataset) {
        let (first_inputs, second_inputs) = self.inputs.divide(first_percentage);
        let (first_labels, second_labels) = self.labels.divide(first_percentage);
        (
            Dataset {
                inputs: first_inputs,
                labels: first_labels,
            },
            Dataset {
                inputs: second_inputs,
                labels: second_labels,
            },
        )
    }

    /// Returns a copy with both `inputs` and `labels` rescaled by
    /// `Vec2::scale_to` into `range`.
    ///
    /// NOTE(review): presumably `range` is `(min, max)` — confirm against
    /// `Vec2::scale_to`.
    pub fn scale_to(&self, range: (f64, f64)) -> Dataset {
        Dataset {
            inputs: self.inputs.scale_to(range),
            labels: self.labels.scale_to(range),
        }
    }
}

/// Encodes the categorical column `col` of `reader` as `1.0`/`-1.0` values.
///
/// The distinct values of the column are collected and **sorted**, so the
/// encoding is deterministic. The same set of categories always maps to the
/// same positions, across runs and across separately loaded files. (Iterating
/// the `HashSet` directly would give a randomized order per process.)
///
/// * With more than two categories, each row becomes a one-hot vector with one
///   entry per category: `1.0` at the row's category, `-1.0` everywhere else.
/// * With exactly two categories, each row becomes a single value: `1.0` for
///   the lexicographically smallest category and `-1.0` for the other.
///
/// # Panics
/// Panics if the column contains fewer than two distinct values.
fn convert_to_cat(reader: &mut CsvReader, col: usize) -> Vec<Vec<f64>> {
    let set: HashSet<String> = reader
        .vertical_iter(col)
        .map(|item| item.to_string())
        .collect();
    let mut mapping: Vec<String> = set.into_iter().collect();
    // HashSet iteration order is randomized; sort so category positions are stable.
    mapping.sort_unstable();
    assert!(
        mapping.len() >= 2,
        "categorical col {} must contain at least two distinct values",
        col
    );

    reader
        .vertical_iter(col)
        .map(|item| {
            if mapping.len() > 2 {
                mapping
                    .iter()
                    .map(|category| if category.eq(&item) { 1.0 } else { -1.0 })
                    .collect::<Vec<f64>>()
            } else if mapping[0].eq(&item) {
                vec![1.0]
            } else {
                vec![-1.0]
            }
        })
        .collect()
}