Matrix Flow

A GPU-accelerated Machine Learning Library in Rust and CUDA

Overview

Matrix Flow is a machine learning library written in Rust that uses CUDA for GPU acceleration. Built as a personal project to deepen my understanding of neural networks, this library provides tools for matrix manipulation, building optimized multi-layer perceptrons (MLPs), and performance profiling.
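Each MLP layer computes y = act(Wx + b). The following minimal CPU sketch shows a dense layer's forward pass for a single sample. It assumes row-major weights and uses a hypothetical `Activation` enum that mirrors the `ActivationType::Tanh` and `ActivationType::Linear` variants in the example below. It is not Matrix Flow's API, and the library presumably performs the equivalent matrix product on the GPU for whole batches.

```rust
/// Hypothetical activation enum mirroring the `ActivationType` variants
/// used in the README example (not the library's type).
#[derive(Clone, Copy)]
enum Activation {
    Tanh,
    Linear,
}

/// Computes y = act(W x + b) for one input vector.
/// `weights` is assumed row-major with shape (out_size, in_size),
/// where out_size = bias.len() and in_size = x.len().
fn dense_forward(weights: &[f32], bias: &[f32], x: &[f32], act: Activation) -> Vec<f32> {
    let in_size = x.len();
    assert_eq!(weights.len(), bias.len() * in_size, "weight shape mismatch");
    bias.iter()
        .enumerate()
        .map(|(row, &b)| {
            // Dot product of one weight row with the input, plus the bias.
            let z: f32 = weights[row * in_size..(row + 1) * in_size]
                .iter()
                .zip(x)
                .map(|(w, xi)| w * xi)
                .sum::<f32>()
                + b;
            // Apply the element-wise activation.
            match act {
                Activation::Tanh => z.tanh(),
                Activation::Linear => z,
            }
        })
        .collect()
}

fn main() {
    // A 2-input, 2-output layer.
    let w: [f32; 4] = [1.0, 0.0, 0.0, 2.0];
    let b: [f32; 2] = [0.0, 1.0];
    let linear = dense_forward(&w, &b, &[3.0, 4.0], Activation::Linear);
    let tanh = dense_forward(&w, &b, &[3.0, 4.0], Activation::Tanh);
    println!("linear: {linear:?}, tanh: {tanh:?}");
}
```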

Features

Prerequisites

Install Rust

Install Rust from the official Rust website; rustup is the recommended installer.

Install CUDA Toolkit

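Make sure the CUDA Toolkit is installed on your system. The library assumes Linux installation paths by default; if CUDA is installed elsewhere or you are on another operating system, adjust the paths for your setup.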

Get Sample Datasets

Matrix Flow expects training and test data as labeled CSV files. Prepare your own or download an existing labeled dataset; the example below uses MNIST in CSV form.
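The CSV-loading step can be illustrated by parsing a single row, assuming the common MNIST CSV layout (the label first, followed by the raw pixel values). This sketch scales each value by `max_value` and one-hot encodes the label into `output_size` slots, which is how the `10` and `255.0` arguments in the example below appear to be used. `parse_row` is a hypothetical helper, not part of Matrix Flow.

```rust
/// Hypothetical helper: parses one "label,v0,v1,..." CSV row into
/// (features scaled by 1/max_value, one-hot label of length output_size).
fn parse_row(line: &str, output_size: usize, max_value: f32) -> Result<(Vec<f32>, Vec<f32>), String> {
    let mut fields = line.trim().split(',');

    // First field: the integer class label.
    let label: usize = fields
        .next()
        .ok_or("empty row")?
        .trim()
        .parse()
        .map_err(|e| format!("bad label: {e}"))?;
    if label >= output_size {
        return Err(format!("label {label} out of range for {output_size} classes"));
    }

    // Remaining fields: raw values, scaled into [0, 1] by max_value.
    let features = fields
        .map(|f| {
            f.trim()
                .parse::<f32>()
                .map(|v| v / max_value)
                .map_err(|e| format!("bad value {f:?}: {e}"))
        })
        .collect::<Result<Vec<f32>, String>>()?;

    // One-hot encode the label.
    let mut one_hot = vec![0.0; output_size];
    one_hot[label] = 1.0;
    Ok((features, one_hot))
}

fn main() {
    let (x, y) = parse_row("3,0,255,51", 10, 255.0).expect("valid row");
    println!("features: {x:?}");
    println!("label:    {y:?}");
}
```

Grouping `batch_size` parsed rows into one matrix would then give one entry of the `(items, labels)` vectors that `read_labeled_data` returns.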

Example Usage

The following example trains a multi-layer perceptron (MLP) on the MNIST dataset:

use matrix_flow::prelude::*;

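fn read_labeled_data<P: AsRef<std::path::Path>>(
    path: P,
    output_size: usize,
    batch_size: usize,
    max_value: ValueType,
) -> Result<(Vec<Matrix>, Vec<Matrix>), Box<dyn std::error::Error>> {
    // Load the CSV at `path` in batches of `batch_size` rows and return
    // (items, labels) as matrices. `Matrix` stands in for the library's
    // matrix type; that name is assumed here.
    todo!("load batches into matrices (items, labels)")
}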


fn main() {
    // Parameters
    const EPOCHS: u32 = 100;
    const BATCH_SIZE: usize = 128;

    let layers = [
        Layer::new(28*28, 100, ActivationType::Tanh),
        Layer::new(100, 100, ActivationType::Tanh),
        Layer::new(100, 10, ActivationType::Linear),
    ];

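    let (input_data, output_data) = read_labeled_data(
        "data_sets/mnist_train.csv", // path to your dataset
        10,                          // output_size: number of label classes (digits 0-9)
        BATCH_SIZE,                  // batch_size
        255.0,                       // max_value: largest raw pixel value
    )
    .expect("Unable to read data");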

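    // Adam hyperparameters: beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8.
    let optim = Optimizer::adam(layers, 0.9, 0.999, 1e-8);
    // Batch size, learning rate (0.001), optimizer, and layer definitions.
    // `mut` because the forward and backward passes update the network.
    let mut network = MLP::new(BATCH_SIZE, 0.001, optim, layers);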

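    for e in 0..EPOCHS {
        let mut error = 0.;
        for (x, y) in input_data.iter().zip(&output_data) {
            // Forward pass on one batch.
            let output = network.forward(x);

            // Accumulate the loss and compute its gradient w.r.t. the output.
            error += mse(y, &output);
            let gradient = mse_prime(y, &output);

            // Backpropagate the gradient through the network.
            let _ = network.backward(&gradient);
        }
        // Mean loss per batch for this epoch.
        println!("{e}: {}", error / input_data.len() as f32);
    }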
}
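The training loop accumulates `mse(y, &output)` and passes `mse_prime(y, &output)` to `backward`. A minimal sketch of the standard formulas on plain slices follows, assuming the loss is averaged over all n elements: MSE = (1/n) * sum((y_i - yhat_i)^2) and dMSE/dyhat_i = 2 * (yhat_i - y_i) / n. These free functions are illustrative stand-ins; the library's own versions operate on its matrix type and may normalize differently (for example, per batch).

```rust
/// Mean squared error: (1/n) * sum over i of (y_i - yhat_i)^2.
fn mse(y_true: &[f32], y_pred: &[f32]) -> f32 {
    assert_eq!(y_true.len(), y_pred.len(), "length mismatch");
    let n = y_true.len() as f32;
    y_true
        .iter()
        .zip(y_pred)
        .map(|(t, p)| (t - p).powi(2))
        .sum::<f32>()
        / n
}

/// Gradient of the MSE with respect to each prediction: 2 * (yhat_i - y_i) / n.
fn mse_prime(y_true: &[f32], y_pred: &[f32]) -> Vec<f32> {
    assert_eq!(y_true.len(), y_pred.len(), "length mismatch");
    let n = y_true.len() as f32;
    y_true
        .iter()
        .zip(y_pred)
        .map(|(t, p)| 2.0 * (p - t) / n)
        .collect()
}

fn main() {
    // One-hot target for class 1 and a hypothetical network output.
    let target = [0.0_f32, 1.0, 0.0];
    let output = [0.1_f32, 0.8, 0.1];
    println!("loss = {}", mse(&target, &output));
    println!("grad = {:?}", mse_prime(&target, &output));
}
```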

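The arguments `0.9, 0.999, 1e-8` passed to `Optimizer::adam` match the usual Adam hyperparameters (beta1, beta2, epsilon), and `0.001` is presumably the learning rate. The sketch below shows one bias-corrected Adam update on a flat parameter vector under those assumptions. The `Adam` struct is hypothetical and CPU-only; it is not Matrix Flow's optimizer implementation.

```rust
/// Hypothetical Adam state for one flat parameter vector.
struct Adam {
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    m: Vec<f32>, // first-moment (running mean) estimate
    v: Vec<f32>, // second-moment (running uncentered variance) estimate
    t: i32,      // step counter used for bias correction
}

impl Adam {
    fn new(n: usize, lr: f32, beta1: f32, beta2: f32, eps: f32) -> Self {
        Adam { lr, beta1, beta2, eps, m: vec![0.0; n], v: vec![0.0; n], t: 0 }
    }

    /// Applies one Adam step to `params` in place using gradient `grad`.
    fn step(&mut self, params: &mut [f32], grad: &[f32]) {
        assert_eq!(params.len(), grad.len(), "length mismatch");
        self.t += 1;
        // Bias-correction factors for the zero-initialized moments.
        let bc1 = 1.0 - self.beta1.powi(self.t);
        let bc2 = 1.0 - self.beta2.powi(self.t);
        for i in 0..params.len() {
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * grad[i];
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * grad[i] * grad[i];
            let m_hat = self.m[i] / bc1;
            let v_hat = self.v[i] / bc2;
            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
    }
}

fn main() {
    // Same hyperparameters as in the README example.
    let mut opt = Adam::new(2, 0.001, 0.9, 0.999, 1e-8);
    let mut params = vec![1.0_f32, -1.0];
    opt.step(&mut params, &[0.5, -0.5]);
    println!("{params:?}");
}
```

Because of the bias correction, the first step moves each parameter by roughly `lr` in the direction opposite to its gradient's sign, regardless of the gradient's magnitude.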
Important Functions