Connect++ 0.6.0
A fast, readable connection prover for first-order logic.
Loading...
Searching...
No Matches
ERWA.cpp
1/*
2
3Copyright © 2023-24 Sean Holden. All rights reserved.
4
5*/
6/*
7
8This file is part of Connect++.
9
10Connect++ is free software: you can redistribute it and/or modify it
11under the terms of the GNU General Public License as published by the
12Free Software Foundation, either version 3 of the License, or (at your
13option) any later version.
14
15Connect++ is distributed in the hope that it will be useful, but WITHOUT
16ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
17FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
18more details.
19
20You should have received a copy of the GNU General Public License along
21with Connect++. If not, see <https://www.gnu.org/licenses/>.
22
23*/
24
25#include "ERWA.hpp"
26
// Out-of-class definition of the static random source declared in
// ERWA.hpp. It is constructed from params::boost_random_seed.
27boost::random::mt19937 ERWA::random_generator(params::boost_random_seed);
28
29//------------------------------------------------------------------
30//------------------------------------------------------------------
31//------------------------------------------------------------------
32// ERWA
33//------------------------------------------------------------------
34//------------------------------------------------------------------
35//------------------------------------------------------------------
36ERWA::ERWA(size_t _K, bool _eg, bool _a)
37: K(_K)
38, epsilon(0.0)
39, alpha(0.0)
40, p()
41, p2(0, _K - 1)
42, r_hat(_K, 0.0)
43, choose_next(true)
44, epsilon_greedy(_eg)
45, alpha_is_1_over_n(_a)
46, choice(0)
47, n(0)
48{}
49//------------------------------------------------------------------
50size_t ERWA::find_max() const {
51 double r = std::numeric_limits<double>::min();
52 size_t result = 0;
53 for (int i = 0; i < r_hat.size(); i++) {
54 double d = r_hat[i];
55 if (d > r) {
56 r = d;
57 result = i;
58 }
59 }
60 return result;
61}
62//------------------------------------------------------------------
63size_t ERWA::choose() {
64 if (!choose_next) {
65 cerr << "STOP IT!! EXP3 should be receiving reward..." << endl;
66 }
67 choose_next = false;
68 if (epsilon_greedy && p(random_generator)) {
70 }
71 else {
72 choice = find_max();
73 }
74 n++;
75 return choice;
76}
77//------------------------------------------------------------------
78void ERWA::reward(double reward) {
79 if (choose_next) {
80 cerr << "STOP IT!! EXP3 should be choosing..." << endl;
81 }
82 choose_next = true;
83 double r = r_hat[choice];
84 if (alpha_is_1_over_n) {
85 r_hat[choice] = r + ((1 / static_cast<double>(n)) * (reward - r));
86 }
87 else {
88 r_hat[choice] = r + (alpha * (reward - r));
89 }
90}
91//------------------------------------------------------------------
92ostream& operator<<(ostream& out, const ERWA& erwa) {
93 out << "r_hats:" << endl;
94 for (size_t i = 0; i < erwa.K; i++)
95 out << erwa.r_hat[i] << " ";
96 out << endl;
97 return out;
98}
Implementation of the ERWA algorithm for multiarmed bandits.
Definition ERWA.hpp:53
static boost::random::mt19937 random_generator
Random source.
Definition ERWA.hpp:61
size_t choice
Store the last choice made.
Definition ERWA.hpp:79
size_t choose()
Choose using the current state.
Definition ERWA.cpp:63
bool choose_next
Belt-and-braces: warn if choose/reward happens in the wrong order.
Definition ERWA.hpp:75
size_t find_max() const
Find the index of the currently maximum r_hat.
Definition ERWA.cpp:50
void reward(double)
Provide reward for the most recent choice.
Definition ERWA.cpp:78