Proyecto Final - Turinmachin
Recreación del minijuego de matemáticas de Brain-Age usando redes neuronales
Loading...
Searching...
No Matches
agent.h
Go to the documentation of this file.
1#ifndef INCLUDE_COMMON_AGENT_H
2#define INCLUDE_COMMON_AGENT_H
3
4#include <random>
5#include <vector>
7
8namespace common {
9
/// A single supervised training example: an input paired with its expected output.
///
/// Both members carry value-initializing default member initializers. A
/// default-constructed Sample therefore holds a value-initialized `input`
/// (e.g. an empty vector) and a zero `output` for arithmetic types, instead of
/// an indeterminate scalar. The type is still an aggregate (C++14 and later),
/// so `Sample<In, Out>{input, output}` keeps working.
///
/// @tparam Input  type of the value fed to an agent (e.g. a feature vector).
/// @tparam Output type of the expected result for that input (e.g. a label).
template <typename Input, typename Output>
struct Sample {
    Input input{};   ///< Value fed to the agent.
    Output output{}; ///< Expected (target) result for `input`.
};
15
16 template <typename Input, typename Output>
17 class IAgent {
18 public:
19 virtual ~IAgent() = default;
20
21 virtual auto predict(const Input& input) -> Output = 0;
22
23 virtual void train(const std::vector<Sample<Input, Output>>& samples,
24 std::size_t epochs,
25 double learning_rate,
26 std::mt19937& rng) = 0;
27 };
28
31
/// Digit-recognition agent. Implements IDigitAgent, an alias for
/// IAgent<std::vector<double>, int>, so it maps a feature vector to an int.
/// Declared `final`, so it cannot be derived from.
/// NOTE(review): the private section before `public:` is not fully shown in
/// this listing; presumably it holds the underlying neural network (see
/// neural_network.h) — confirm against agent.h.
32 class DigitReader final : public IDigitAgent {
34
35 public:
/// Feature vector accepted by predict(); matches IDigitAgent's input type.
36 using Input = std::vector<double>;
/// Result type of predict(); presumably the recognised digit — confirm in agent.cpp.
37 using Output = int;
38
/// Constructs a reader using the caller-supplied random engine `rng`.
/// NOTE(review): presumably used to randomly initialise the network weights
/// (definition lives in agent.cpp, not visible here) — confirm.
39 explicit DigitReader(std::mt19937& rng);
40
/// Constructs a reader from the stream `net_in`.
/// NOTE(review): presumably reads a previously saved network, in the format
/// read by load_net() — confirm in agent.cpp.
41 explicit DigitReader(std::istream& net_in);
42
/// Returns the reader's int output for one feature vector (overrides IAgent::predict).
/// @param features input feature vector.
/// @return presumably the predicted digit — confirm the value range in agent.cpp.
43 auto predict(const Input& features) -> int override;
44
/// Trains the reader on labelled samples (overrides IAgent::train).
/// DigitSample is an alias for Sample<std::vector<double>, int>.
/// @param samples       feature-vector / expected-int pairs.
/// @param epochs        training iteration count (presumably passes over `samples`).
/// @param learning_rate training step size (presumably the gradient-descent rate) — confirm.
/// @param rng           caller-owned random engine; how it is used is defined in agent.cpp.
45 void train(const std::vector<DigitSample>& samples,
46 std::size_t epochs,
47 double learning_rate,
48 std::mt19937& rng) override;
49
/// Evaluates the reader on `samples` and returns a double score.
/// NOTE(review): presumably the fraction (or percentage) of samples whose
/// prediction equals `output` — confirm the scale in agent.cpp. Not declared
/// const, presumably because it relies on the non-const predict().
50 auto test_accuracy(const std::vector<DigitSample>& samples) -> double;
51
/// Reads network data from `in` into this reader.
/// NOTE(review): presumably replaces the current network; the format is defined in agent.cpp.
52 void load_net(std::istream& in);
53
/// Writes this reader's network data to `out`. Declared const.
/// NOTE(review): presumably in the format that load_net() reads back — confirm.
54 void save_net(std::ostream& out) const;
55 };
56
57} // namespace common
58
59#endif
void load_net(std::istream &in)
Definition agent.cpp:82
DigitReader(std::mt19937 &rng)
Definition agent.cpp:14
std::vector< double > Input
Definition agent.h:36
int Output
Definition agent.h:37
auto test_accuracy(const std::vector< DigitSample > &samples) -> double
Definition agent.cpp:69
void save_net(std::ostream &out) const
Definition agent.cpp:86
auto predict(const Input &features) -> int override
Definition agent.cpp:32
void train(const std::vector< DigitSample > &samples, std::size_t epochs, double learning_rate, std::mt19937 &rng) override
Definition agent.cpp:48
Definition agent.h:17
virtual auto predict(const Input &input) -> Output=0
virtual void train(const std::vector< Sample< Input, Output > > &samples, std::size_t epochs, double learning_rate, std::mt19937 &rng)=0
virtual ~IAgent()=default
Clase que representa una red neuronal completamente conectada.
Definition neural_network.h:22
Definition agent.h:8
Sample< std::vector< double >, int > DigitSample
Definition agent.h:30
IAgent< std::vector< double >, int > IDigitAgent
Definition agent.h:29
Definition agent.h:11
std::vector< double > input
Definition agent.h:12