Connect++ 0.6.0
A fast, readable connection prover for first-order logic.
Loading...
Searching...
No Matches
EXP3.cpp
1/*
2
3Copyright © 2023-24 Sean Holden. All rights reserved.
4
5*/
6/*
7
8This file is part of Connect++.
9
10Connect++ is free software: you can redistribute it and/or modify it
11under the terms of the GNU General Public License as published by the
12Free Software Foundation, either version 3 of the License, or (at your
13option) any later version.
14
15Connect++ is distributed in the hope that it will be useful, but WITHOUT
16ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
17FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
18more details.
19
20You should have received a copy of the GNU General Public License along
21with Connect++. If not, see <https://www.gnu.org/licenses/>.
22
23*/
24
25#include "EXP3.hpp"
26
27boost::random::mt19937 EXP3::random_generator(params::boost_random_seed);
28
29//------------------------------------------------------------------
30//------------------------------------------------------------------
31//------------------------------------------------------------------
32// EXP3
33//------------------------------------------------------------------
34//------------------------------------------------------------------
35//------------------------------------------------------------------
36EXP3::EXP3(size_t _K, double _gamma)
37: K(_K)
38, gamma(_gamma)
39, omega(_K, 1.0)
40, p_values(_K, 0.0)
41, choose_next(true)
42, one_minus_gamma(1 - _gamma)
43, gamma_over_K(_gamma / static_cast<double>(_K))
44, choice(0)
45{}
46//------------------------------------------------------------------
47void EXP3::set_gamma(double _gamma) {
48 gamma = _gamma;
49 one_minus_gamma = 1 - _gamma;
50 gamma_over_K = _gamma / static_cast<double>(K);
51}
52//------------------------------------------------------------------
53size_t EXP3::choose() {
54 if (!choose_next) {
55 cerr << "STOP IT!! EXP3 should be receiving reward..." << endl;
56 }
57 choose_next = false;
58 double omega_sum = 0.0;
59 for (size_t i = 0; i < K; i++) {
60 omega_sum += omega[i];
61 p_values[i] = gamma_over_K;
62 }
63 for (size_t i = 0; i < K; i++) {
64 p_values[i] += (one_minus_gamma * (omega[i]/omega_sum));
65 }
66 boost::random::discrete_distribution<> p(p_values.begin(), p_values.end());
68 return choice;
69}
70//------------------------------------------------------------------
71void EXP3::reward(double r) {
72 if (choose_next) {
73 cerr << "STOP IT!! EXP3 should be choosing..." << endl;
74 }
75 choose_next = true;
76 if (r < 0.0 || r > 1.0) {
77 cerr << "STOP IT!! The reward needs to be in [0,1]." << endl;
78 }
79 double x = r / p_values[choice];
80 omega[choice] = omega[choice] * exp(x * gamma_over_K);
81}
82//------------------------------------------------------------------
83ostream& operator<<(ostream& out, const EXP3& exp3) {
84 out << "Omegas:" << endl;
85 for (size_t i = 0; i < exp3.K; i++)
86 out << exp3.omega[i] << " ";
87 out << endl << "p values:" << endl;
88 for (size_t i = 0; i < exp3.K; i++)
89 out << exp3.p_values[i] << " ";
90 out << endl;
91 return out;
92}
Implementation of the EXP3 algorithm for multiarmed bandits.
Definition EXP3.hpp:51
size_t choice
Store the last choice made.
Definition EXP3.hpp:74
bool choose_next
Belt-and-braces: warn if choose/reward happens in the wrong order.
Definition EXP3.hpp:70
void set_gamma(double)
Reset gamma and associated members to something different.
Definition EXP3.cpp:47
void reward(double)
Provide reward for the most recent choice.
Definition EXP3.cpp:71
size_t choose()
Choose using the current state.
Definition EXP3.cpp:53
static boost::random::mt19937 random_generator
Random source.
Definition EXP3.hpp:59