1#ifndef INCLUDE_COMMON_AGENT_H
2#define INCLUDE_COMMON_AGENT_H
10 template <
typename Input,
typename Output>
16 template <
typename Input,
typename Output>
21 virtual auto predict(
const Input& input) -> Output = 0;
26 std::mt19937& rng) = 0;
36 using Input = std::vector<double>;
45 void train(
const std::vector<DigitSample>& samples,
48 std::mt19937& rng)
override;
50 auto test_accuracy(
const std::vector<DigitSample>& samples) -> double;
54 void save_net(std::ostream& out)
const;
void load_net(std::istream &in)
Definition agent.cpp:82
DigitReader(std::mt19937 &rng)
Definition agent.cpp:14
std::vector< double > Input
Definition agent.h:36
int Output
Definition agent.h:37
auto test_accuracy(const std::vector< DigitSample > &samples) -> double
Definition agent.cpp:69
void save_net(std::ostream &out) const
Definition agent.cpp:86
auto predict(const Input &features) -> int override
Definition agent.cpp:32
void train(const std::vector< DigitSample > &samples, std::size_t epochs, double learning_rate, std::mt19937 &rng) override
Definition agent.cpp:48
virtual auto predict(const Input &input) -> Output=0
virtual void train(const std::vector< Sample< Input, Output > > &samples, std::size_t epochs, double learning_rate, std::mt19937 &rng)=0
virtual ~IAgent()=default
Clase que representa una red neuronal completamente conectada.
Definition neural_network.h:22
Sample< std::vector< double >, int > DigitSample
Definition agent.h:30
IAgent< std::vector< double >, int > IDigitAgent
Definition agent.h:29
std::vector< double > input
Definition agent.h:12
int output
Definition agent.h:13