Proyecto Final - Turinmachin
Recreación del minijuego de matemáticas de Brain-Age usando redes neuronales
|
Clase que representa una red neuronal completamente conectada. More...
#include <neural_network.h>
Public Member Functions | |
NeuralNetwork () | |
Constructor por defecto. | |
template<typename L, typename... Args> | |
void | add_layer (Args &&... args) |
Agrega una nueva capa a la red. | |
template<template< typename... > class LossType, template< typename... > class OptimizerType = SGD> | |
void | train (const algebra::Tensor< T, 2 > &x, const algebra::Tensor< T, 2 > &y, const size_t epochs, const size_t batch_size, T learning_rate, std::mt19937 &rng) |
Entrena la red neuronal usando descenso por lotes. | |
auto | predict (const algebra::Tensor< T, 2 > &X) -> algebra::Tensor< T, 2 > |
Realiza una predicción sobre un conjunto de datos. | |
void | save (std::ostream &out) const |
Guarda el modelo en un flujo de salida binario. |
Static Public Member Functions | |
static auto | load (std::istream &in) -> NeuralNetwork< T > |
Carga una red neuronal desde un flujo de entrada binario. |
Clase que representa una red neuronal completamente conectada.
T | Tipo de dato para los pesos y cálculos (float, double, etc.). |
|
inline |
Constructor por defecto.
|
inline |
|
inlinestatic |
Carga una red neuronal desde un flujo de entrada binario.
in | Flujo de entrada desde el cual se carga la red. |
|
inline |
Realiza una predicción sobre un conjunto de datos.
X | Datos de entrada. |
|
inline |
Guarda el modelo en un flujo de salida binario.
out | Flujo de salida donde se guarda la red. @complexity O(L*s), donde s es el tamaño serializado de cada capa. |
|
inline |
Entrena la red neuronal usando descenso por lotes.
LossType | Tipo de función de pérdida (ej. MSELoss, BCELoss, ...). |
OptimizerType | Tipo de optimizador (ej. SGD, Adam, ...). Por defecto es SGD. |
x | Datos de entrada. |
y | Etiquetas esperadas. |
epochs | Número de épocas de entrenamiento. |
batch_size | Tamaño del batch. |
learning_rate | Tasa de aprendizaje. |
rng | Generador aleatorio para mezclar los datos. @complexity O(e*(n/b)*L*(f+b+u))), donde:
|