SCIP Doxygen Documentation
 
Loading...
Searching...
No Matches
bandit_ucb.c
Go to the documentation of this file.
1/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2/* */
3/* This file is part of the program and library */
4/* SCIP --- Solving Constraint Integer Programs */
5/* */
6/* Copyright (c) 2002-2025 Zuse Institute Berlin (ZIB) */
7/* */
8/* Licensed under the Apache License, Version 2.0 (the "License"); */
9/* you may not use this file except in compliance with the License. */
10/* You may obtain a copy of the License at */
11/* */
12/* http://www.apache.org/licenses/LICENSE-2.0 */
13/* */
14/* Unless required by applicable law or agreed to in writing, software */
15/* distributed under the License is distributed on an "AS IS" BASIS, */
16/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
17/* See the License for the specific language governing permissions and */
18/* limitations under the License. */
19/* */
20/* You should have received a copy of the Apache-2.0 license */
21/* along with SCIP; see the file LICENSE. If not visit scipopt.org. */
22/* */
23/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
24
25/**@file bandit_ucb.c
26 * @ingroup OTHER_CFILES
27 * @brief methods for UCB bandit selection
28 * @author Gregor Hendel
29 */
30
31/*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
32
#include "scip/bandit.h"
#include "scip/bandit_ucb.h"
#include "scip/pub_bandit.h"
#include "scip/pub_message.h"
#include "scip/pub_misc.h"
#include "scip/pub_misc_sort.h"
#include "scip/scip_bandit.h"
#include "scip/scip_mem.h"
#include "scip/scip_randnumgen.h"
42
43
44#define BANDIT_NAME "ucb"
45#define NUMEPS 1e-6
46
47/*
48 * Data structures
49 */
50
51/** implementation specific data of UCB bandit algorithm */
52struct SCIP_BanditData
53{
54 int nselections; /**< counter for the number of selections */
55 int* counter; /**< array of counters how often every action has been chosen */
56 int* startperm; /**< indices for starting permutation */
57 SCIP_Real* meanscores; /**< array of average scores for the actions */
58 SCIP_Real alpha; /**< parameter to increase confidence width */
59};
60
61
62/*
63 * Local methods
64 */
65
66/** data reset method */
67static
69 BMS_BUFMEM* bufmem, /**< buffer memory */
70 SCIP_BANDIT* ucb, /**< ucb bandit algorithm */
71 SCIP_BANDITDATA* banditdata, /**< UCB bandit data structure */
72 SCIP_Real* priorities, /**< priorities for start permutation, or NULL */
73 int nactions /**< number of actions */
74 )
75{
76 int i;
77 SCIP_RANDNUMGEN* rng;
78
79 assert(bufmem != NULL);
80 assert(ucb != NULL);
81 assert(nactions > 0);
82
83 /* clear counters and scores */
84 BMSclearMemoryArray(banditdata->counter, nactions);
85 BMSclearMemoryArray(banditdata->meanscores, nactions);
86 banditdata->nselections = 0;
87
88 rng = SCIPbanditGetRandnumgen(ucb);
89 assert(rng != NULL);
90
91 /* initialize start permutation as identity */
92 for( i = 0; i < nactions; ++i )
93 banditdata->startperm[i] = i;
94
95 /* prepare the start permutation in decreasing order of priority */
96 if( priorities != NULL )
97 {
98 SCIP_Real* prioritycopy;
99
100 SCIP_ALLOC( BMSduplicateBufferMemoryArray(bufmem, &prioritycopy, priorities, nactions) );
101
102 /* randomly wiggle priorities a little bit to make them unique */
103 for( i = 0; i < nactions; ++i )
104 prioritycopy[i] += SCIPrandomGetReal(rng, -NUMEPS, NUMEPS);
105
106 SCIPsortDownRealInt(prioritycopy, banditdata->startperm, nactions);
107
108 BMSfreeBufferMemoryArray(bufmem, &prioritycopy);
109 }
110 else
111 {
112 /* use a random start permutation */
113 SCIPrandomPermuteIntArray(rng, banditdata->startperm, 0, nactions);
114 }
115
116 return SCIP_OKAY;
117}
118
119
120/*
121 * Callback methods of bandit algorithm
122 */
123
124/** callback to free bandit specific data structures */
125SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
126{ /*lint --e{715}*/
127 SCIP_BANDITDATA* banditdata;
128 int nactions;
129 assert(bandit != NULL);
130
131 banditdata = SCIPbanditGetData(bandit);
132 assert(banditdata != NULL);
133 nactions = SCIPbanditGetNActions(bandit);
134
135 BMSfreeBlockMemoryArray(blkmem, &banditdata->counter, nactions);
136 BMSfreeBlockMemoryArray(blkmem, &banditdata->startperm, nactions);
137 BMSfreeBlockMemoryArray(blkmem, &banditdata->meanscores, nactions);
138 BMSfreeBlockMemory(blkmem, &banditdata);
139
140 SCIPbanditSetData(bandit, NULL);
141
142 return SCIP_OKAY;
143}
144
145/** selection callback for bandit selector */
146SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
147{ /*lint --e{715}*/
148 SCIP_BANDITDATA* banditdata;
149 int nactions;
150 int* counter;
151
152 assert(bandit != NULL);
154
155 banditdata = SCIPbanditGetData(bandit);
156 assert(banditdata != NULL);
157 nactions = SCIPbanditGetNActions(bandit);
158
159 counter = banditdata->counter;
160 /* select the next uninitialized action from the start permutation */
161 if( banditdata->nselections < nactions )
162 {
163 *selection = banditdata->startperm[banditdata->nselections];
164 assert(counter[*selection] == 0);
165 }
166 else
167 {
168 /* select the action with the highest upper confidence bound */
169 SCIP_Real* meanscores;
170 SCIP_Real widthfactor;
171 SCIP_Real maxucb;
172 int i;
174 meanscores = banditdata->meanscores;
175
176 assert(rng != NULL);
177 assert(meanscores != NULL);
178
179 /* compute the confidence width factor that is common for all actions */
180 /* cppcheck-suppress unpreciseMathCall */
181 widthfactor = banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections);
182 widthfactor = sqrt(widthfactor);
183 maxucb = -1.0;
184
185 /* loop over the actions and determine the maximum upper confidence bound.
186 * The upper confidence bound of an action is the sum of its mean score
187 * plus a confidence term that decreases with increasing number of observations of
188 * this action.
189 */
190 for( i = 0; i < nactions; ++i )
191 {
192 SCIP_Real uppercb;
193 SCIP_Real rootcount;
194 assert(counter[i] > 0);
195
196 /* compute the upper confidence bound for action i */
197 uppercb = meanscores[i];
198 rootcount = sqrt((SCIP_Real)counter[i]);
199 uppercb += widthfactor / rootcount;
200 assert(uppercb > 0);
201
202 /* update maximum, breaking ties uniformly at random */
203 if( EPSGT(uppercb, maxucb, NUMEPS) || (EPSEQ(uppercb, maxucb, NUMEPS) && SCIPrandomGetReal(rng, 0.0, 1.0) >= 0.5) )
204 {
205 maxucb = uppercb;
206 *selection = i;
207 }
208 }
209 }
210
211 assert(*selection >= 0);
212 assert(*selection < nactions);
213
214 return SCIP_OKAY;
215}
216
217/** update callback for bandit algorithm */
218SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
219{ /*lint --e{715}*/
220 SCIP_BANDITDATA* banditdata;
221 SCIP_Real delta;
222
223 assert(bandit != NULL);
224
225 banditdata = SCIPbanditGetData(bandit);
226 assert(banditdata != NULL);
227 assert(selection >= 0);
229
230 /* increase the mean by the incremental formula: A_n = A_n-1 + 1/n (a_n - A_n-1) */
231 delta = score - banditdata->meanscores[selection];
232 ++banditdata->counter[selection];
233 banditdata->meanscores[selection] += delta / (SCIP_Real)banditdata->counter[selection];
234
235 banditdata->nselections++;
236
237 return SCIP_OKAY;
238}
239
240/** reset callback for bandit algorithm */
241SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
242{ /*lint --e{715}*/
243 SCIP_BANDITDATA* banditdata;
244 int nactions;
245
246 assert(bufmem != NULL);
247 assert(bandit != NULL);
248
249 banditdata = SCIPbanditGetData(bandit);
250 assert(banditdata != NULL);
251 nactions = SCIPbanditGetNActions(bandit);
252
253 /* call the data reset for the given priorities */
254 SCIP_CALL( dataReset(bufmem, bandit, banditdata, priorities, nactions) );
255
256 return SCIP_OKAY;
257}
258
259/*
260 * bandit algorithm specific interface methods
261 */
262
263/** returns the upper confidence bound of a selected action */
265 SCIP_BANDIT* ucb, /**< UCB bandit algorithm */
266 int action /**< index of the queried action */
267 )
268{
269 SCIP_Real uppercb;
270 SCIP_BANDITDATA* banditdata;
271 int nactions;
272
273 assert(ucb != NULL);
274 banditdata = SCIPbanditGetData(ucb);
275 nactions = SCIPbanditGetNActions(ucb);
276 assert(action < nactions);
277
278 /* since only scores between 0 and 1 are allowed, 1.0 is a sure upper confidence bound */
279 if( banditdata->nselections < nactions )
280 return 1.0;
281
282 /* the bandit algorithm must have picked every action once */
283 assert(banditdata->counter[action] > 0);
284 uppercb = banditdata->meanscores[action];
285
286 /* cppcheck-suppress unpreciseMathCall */
287 uppercb += sqrt(banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections) / (SCIP_Real)banditdata->counter[action]);
288
289 return uppercb;
290}
291
292/** return start permutation of the UCB bandit algorithm */
294 SCIP_BANDIT* ucb /**< UCB bandit algorithm */
295 )
296{
297 SCIP_BANDITDATA* banditdata = SCIPbanditGetData(ucb);
298
299 assert(banditdata != NULL);
300
301 return banditdata->startperm;
302}
303
304/** internal method to create and reset UCB bandit algorithm */
306 BMS_BLKMEM* blkmem, /**< block memory */
307 BMS_BUFMEM* bufmem, /**< buffer memory */
308 SCIP_BANDITVTABLE* vtable, /**< virtual function table for UCB bandit algorithm */
309 SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
310 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
311 SCIP_Real alpha, /**< parameter to increase confidence width */
312 int nactions, /**< the positive number of actions for this bandit algorithm */
313 unsigned int initseed /**< initial random seed */
314 )
315{
316 SCIP_BANDITDATA* banditdata;
317
318 if( alpha < 0.0 )
319 {
320 SCIPerrorMessage("UCB requires nonnegative alpha parameter, have %f\n", alpha);
321 return SCIP_INVALIDDATA;
322 }
323
324 SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
325 assert(banditdata != NULL);
326
327 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->counter, nactions) );
328 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->startperm, nactions) );
329 SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->meanscores, nactions) );
330
331 banditdata->alpha = alpha;
332
333 SCIP_CALL( SCIPbanditCreate(ucb, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
334
335 return SCIP_OKAY;
336}
337
338/** create and reset UCB bandit algorithm */
340 SCIP* scip, /**< SCIP data structure */
341 SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
342 SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
343 SCIP_Real alpha, /**< parameter to increase confidence width */
344 int nactions, /**< the positive number of actions for this bandit algorithm */
345 unsigned int initseed /**< initial random number seed */
346 )
347{
348 SCIP_BANDITVTABLE* vtable;
349
351 if( vtable == NULL )
352 {
353 SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
354 return SCIP_INVALIDDATA;
355 }
356
358 priorities, alpha, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
359
360 return SCIP_OKAY;
361}
362
363/** include virtual function table for UCB bandit algorithms */
365 SCIP* scip /**< SCIP data structure */
366 )
367{
368 SCIP_BANDITVTABLE* vtable;
369
371 SCIPbanditFreeUcb, SCIPbanditSelectUcb, SCIPbanditUpdateUcb, SCIPbanditResetUcb) );
372 assert(vtable != NULL);
373
374 return SCIP_OKAY;
375}
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition bandit.c:200
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition bandit.c:42
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition bandit.c:190
internal methods for bandit algorithms
#define BANDIT_NAME
static SCIP_RETCODE dataReset(BMS_BUFMEM *bufmem, SCIP_BANDIT *ucb, SCIP_BANDITDATA *banditdata, SCIP_Real *priorities, int nactions)
Definition bandit_ucb.c:68
#define NUMEPS
Definition bandit_ucb.c:45
SCIP_RETCODE SCIPincludeBanditvtableUcb(SCIP *scip)
Definition bandit_ucb.c:364
SCIP_RETCODE SCIPbanditCreateUcb(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition bandit_ucb.c:305
internal methods for UCB bandit algorithm
#define NULL
Definition def.h:266
#define LOG1P(x)
Definition def.h:221
#define SCIP_ALLOC(x)
Definition def.h:384
#define SCIP_Real
Definition def.h:172
#define EPSEQ(x, y, eps)
Definition def.h:197
#define EPSGT(x, y, eps)
Definition def.h:200
#define SCIP_CALL(x)
Definition def.h:373
void SCIPrandomPermuteIntArray(SCIP_RANDNUMGEN *randnumgen, int *array, int begin, int end)
Definition misc.c:10152
int * SCIPgetStartPermutationUcb(SCIP_BANDIT *ucb)
Definition bandit_ucb.c:293
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition bandit.c:303
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition bandit.c:293
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition scip_bandit.c:80
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)),)
Definition scip_bandit.c:48
SCIP_Real SCIPgetConfidenceBoundUcb(SCIP_BANDIT *ucb, int action)
Definition bandit_ucb.c:264
SCIP_RETCODE SCIPcreateBanditUcb(SCIP *scip, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition bandit_ucb.c:339
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition scip_mem.c:72
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition misc.c:10133
unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
void SCIPsortDownRealInt(SCIP_Real *realarray, int *intarray, int len)
return SCIP_OKAY
int selection
assert(minobj< SCIPgetCutoffbound(scip))
SCIP_Real alpha
#define BMSfreeBlockMemory(mem, ptr)
Definition memory.h:465
#define BMSduplicateBufferMemoryArray(mem, ptr, source, num)
Definition memory.h:737
#define BMSallocBlockMemory(mem, ptr)
Definition memory.h:451
#define BMSfreeBufferMemoryArray(mem, ptr)
Definition memory.h:742
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition memory.h:454
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition memory.h:467
#define BMSclearMemoryArray(ptr, num)
Definition memory.h:130
struct BMS_BufMem BMS_BUFMEM
Definition memory.h:721
struct BMS_BlkMem BMS_BLKMEM
Definition memory.h:437
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition scip_mem.c:57
public methods for bandit algorithms
public methods for message output
#define SCIPerrorMessage
Definition pub_message.h:64
public data structures and miscellaneous methods
methods for sorting joint arrays of various types
public methods for bandit algorithms
public methods for memory management
public methods for random numbers
#define SCIP_DECL_BANDITUPDATE(x)
Definition type_bandit.h:75
#define SCIP_DECL_BANDITFREE(x)
Definition type_bandit.h:63
struct SCIP_Bandit SCIP_BANDIT
Definition type_bandit.h:50
struct SCIP_BanditData SCIP_BANDITDATA
Definition type_bandit.h:56
#define SCIP_DECL_BANDITSELECT(x)
Definition type_bandit.h:69
struct SCIP_BanditVTable SCIP_BANDITVTABLE
Definition type_bandit.h:53
#define SCIP_DECL_BANDITRESET(x)
Definition type_bandit.h:82
struct SCIP_RandNumGen SCIP_RANDNUMGEN
Definition type_misc.h:126
@ SCIP_INVALIDDATA
enum SCIP_Retcode SCIP_RETCODE
struct Scip SCIP
Definition type_scip.h:39