Skip to content

Commit ae82e78

Browse files
committed
refactor: drop assertion of insample performance of surrogate as a feature completely
1 parent b745d5f commit ae82e78

10 files changed

+7
-304
lines changed

R/Surrogate.R

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,6 @@ Surrogate = R6Class("Surrogate",
161161
}
162162
},
163163

164-
#' @field insample_perf (`numeric()`)\cr
165-
#' Surrogate model's current insample performance.
166-
insample_perf = function(rhs) {
167-
if (missing(rhs)) {
168-
private$.insample_perf %??% NaN
169-
} else {
170-
stop("$insample_perf is read-only.")
171-
}
172-
},
173-
174164
#' @field param_set ([paradox::ParamSet])\cr
175165
#' Set of hyperparameters.
176166
param_set = function(rhs) {
@@ -181,11 +171,6 @@ Surrogate = R6Class("Surrogate",
181171
}
182172
},
183173

184-
#' @template field_assert_insample_perf_surrogate
185-
assert_insample_perf = function(rhs) {
186-
stop("Abstract.")
187-
},
188-
189174
#' @template field_packages_surrogate
190175
packages = function(rhs) {
191176
if (missing(rhs)) {
@@ -231,8 +216,6 @@ Surrogate = R6Class("Surrogate",
231216

232217
.cols_y = NULL,
233218

234-
.insample_perf = NULL,
235-
236219
.param_set = NULL,
237220

238221
.update = function() {

R/SurrogateLearner.R

Lines changed: 1 addition & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,6 @@
55
#'
66
#' @section Parameters:
77
#' \describe{
8-
#' \item{`assert_insample_perf`}{`logical(1)`\cr
9-
#' Should the insample performance of the [mlr3::LearnerRegr] be asserted after updating the surrogate?
10-
#' If the assertion fails (i.e., the insample performance based on the `perf_measure` does not meet the
11-
#' `perf_threshold`), an error is thrown.
12-
#' Default is `FALSE`.
13-
#' }
14-
#' \item{`perf_measure`}{[mlr3::MeasureRegr]\cr
15-
#' Performance measure which should be used to assert the insample performance of the [mlr3::LearnerRegr].
16-
#' Only relevant if `assert_insample_perf = TRUE`.
17-
#' Default is [mlr3::mlr_measures_regr.rsq].
18-
#' }
19-
#' \item{`perf_threshold`}{`numeric(1)`\cr
20-
#' Threshold the insample performance of the [mlr3::LearnerRegr] should be asserted against.
21-
#' Only relevant if `assert_insample_perf = TRUE`.
22-
#' Default is `0`.
23-
#' }
248
#' \item{`catch_errors`}{`logical(1)`\cr
259
#' Should errors during updating the surrogate be caught and propagated to the `loop_function` which can then handle
2610
#' the failed acquisition function optimization (as a result of the failed surrogate) appropriately by, e.g., proposing a randomly sampled point for evaluation?
@@ -89,15 +73,10 @@ SurrogateLearner = R6Class("SurrogateLearner",
8973
assert_string(col_y, null.ok = TRUE)
9074

9175
ps = ps(
92-
assert_insample_perf = p_lgl(),
93-
perf_measure = p_uty(custom_check = function(x) check_r6(x, classes = "MeasureRegr")), # FIXME: actually want check_measure
94-
perf_threshold = p_dbl(lower = -Inf, upper = Inf),
9576
catch_errors = p_lgl(),
9677
impute_method = p_fct(c("mean", "random"), default = "random")
9778
)
98-
ps$values = list(assert_insample_perf = FALSE, catch_errors = TRUE, impute_method = "random")
99-
ps$add_dep("perf_measure", on = "assert_insample_perf", cond = CondEqual$new(TRUE))
100-
ps$add_dep("perf_threshold", on = "assert_insample_perf", cond = CondEqual$new(TRUE))
79+
ps$values = list(catch_errors = TRUE, impute_method = "random")
10180

10281
super$initialize(learner = learner, archive = archive, cols_x = cols_x, cols_y = col_y, param_set = ps)
10382
},
@@ -138,47 +117,6 @@ SurrogateLearner = R6Class("SurrogateLearner",
138117
1L
139118
},
140119

141-
#' @template field_assert_insample_perf_surrogate
142-
assert_insample_perf = function(rhs) {
143-
if (missing(rhs)) {
144-
if (!self$param_set$values$assert_insample_perf) {
145-
return(invisible(self$insample_perf))
146-
}
147-
148-
perf_measure = self$param_set$values$perf_measure %??% mlr_measures$get("regr.rsq")
149-
perf_threshold = self$param_set$values$perf_threshold %??% 0
150-
check = if (perf_measure$minimize) {
151-
self$insample_perf < perf_threshold
152-
} else {
153-
self$insample_perf > perf_threshold
154-
}
155-
156-
if (!check) {
157-
stop("Current insample performance of the Surrogate Model does not meet the performance threshold.")
158-
}
159-
invisible(self$insample_perf)
160-
} else {
161-
stop("$assert_insample_perf is read-only.")
162-
}
163-
164-
if (!self$param_set$values$assert_insample_perf) {
165-
return(invisible(self$insample_perf))
166-
}
167-
168-
perf_measure = self$param_set$values$perf_measure %??% mlr_measures$get("regr.rsq")
169-
perf_threshold = self$param_set$values$perf_threshold %??% 0
170-
check = if (perf_measure$minimize) {
171-
self$insample_perf < perf_threshold
172-
} else {
173-
self$insample_perf > perf_threshold
174-
}
175-
176-
if (!check) {
177-
stop("Current insample performance of the Surrogate Model does not meet the performance threshold.")
178-
}
179-
invisible(self$insample_perf)
180-
},
181-
182120
#' @template field_packages_surrogate
183121
packages = function(rhs) {
184122
if (missing(rhs)) {
@@ -218,23 +156,15 @@ SurrogateLearner = R6Class("SurrogateLearner",
218156

219157
private = list(
220158
# Train learner with new data.
221-
# Also calculates the insample performance based on the `perf_measure` hyperparameter if `assert_insample_perf = TRUE`.
222159
.update = function() {
223160
xydt = self$archive$data[, c(self$cols_x, self$cols_y), with = FALSE]
224161
task = TaskRegr$new(id = "surrogate_task", backend = xydt, target = self$cols_y)
225162
assert_learnable(task, learner = self$learner)
226163
self$learner$train(task)
227-
228-
if (self$param_set$values$assert_insample_perf) {
229-
measure = assert_measure(self$param_set$values$perf_measure %??% mlr_measures$get("regr.rsq"), task = task, learner = self$learner)
230-
private$.insample_perf = self$learner$predict(task)$score(measure, task = task, learner = self$learner)
231-
self$assert_insample_perf
232-
}
233164
},
234165

235166
# Train learner with new data.
236167
# Operates on an asynchronous archive and performs imputation as needed.
237-
# Also calculates the insample performance based on the `perf_measure` hyperparameter if `assert_insample_perf = TRUE`.
238168
.update_async = function() {
239169
xydt = self$archive$rush$fetch_tasks_with_state(states = c("queued", "running", "finished"))[, c(self$cols_x, self$cols_y, "state"), with = FALSE]
240170
if (self$param_set$values$impute_method == "mean") {
@@ -250,12 +180,6 @@ SurrogateLearner = R6Class("SurrogateLearner",
250180
task = TaskRegr$new(id = "surrogate_task", backend = xydt, target = self$cols_y)
251181
assert_learnable(task, learner = self$learner)
252182
self$learner$train(task)
253-
254-
if (self$param_set$values$assert_insample_perf) {
255-
measure = assert_measure(self$param_set$values$perf_measure %??% mlr_measures$get("regr.rsq"), task = task, learner = self$learner)
256-
private$.insample_perf = self$learner$predict(task)$score(measure, task = task, learner = self$learner)
257-
self$assert_insample_perf
258-
}
259183
},
260184

261185
.reset = function() {

R/SurrogateLearnerCollection.R

Lines changed: 1 addition & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,6 @@
77
#'
88
#' @section Parameters:
99
#' \describe{
10-
#' \item{`assert_insample_perf`}{`logical(1)`\cr
11-
#' Should the insample performance of the [mlr3::LearnerRegr] be asserted after updating the surrogate?
12-
#' If the assertion fails (i.e., the insample performance based on the `perf_measures` does not meet the
13-
#' `perf_thresholds`), an error is thrown.
14-
#' Default is `FALSE`.
15-
#' }
16-
#' \item{`perf_measures`}{List of [mlr3::MeasureRegr]\cr
17-
#' Performance measures which should be used to assert the insample performance of the [mlr3::LearnerRegr].
18-
#' Only relevant if `assert_insample_perf = TRUE`.
19-
#' Default is [mlr3::mlr_measures_regr.rsq] for each learner.
20-
#' }
21-
#' \item{`perf_thresholds`}{List of `numeric(1)`\cr
22-
#' Thresholds the insample performance of the [mlr3::LearnerRegr] should be asserted against.
23-
#' Only relevant if `assert_insample_perf = TRUE`.
24-
#' Default is `0` for each learner.
25-
#' }
2610
#' \item{`catch_errors`}{`logical(1)`\cr
2711
#' Should errors during updating the surrogate be caught and propagated to the `loop_function` which can then handle
2812
#' the failed acquisition function optimization (as a result of the failed surrogate) appropriately by, e.g., proposing a randomly sampled point for evaluation?
@@ -104,15 +88,10 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
10488
assert_character(cols_y, len = length(learners), null.ok = TRUE)
10589

10690
ps = ps(
107-
assert_insample_perf = p_lgl(),
108-
perf_measures = p_uty(custom_check = function(x) check_list(x, types = "MeasureRegr", any.missing = FALSE, len = length(learners))), # FIXME: actually want check_measures
109-
perf_thresholds = p_uty(custom_check = function(x) check_double(x, lower = -Inf, upper = Inf, any.missing = FALSE, len = length(learners))),
11091
catch_errors = p_lgl(),
11192
impute_method = p_fct(c("mean", "random"), default = "random")
11293
)
113-
ps$values = list(assert_insample_perf = FALSE, catch_errors = TRUE, impute_method = "random")
114-
ps$add_dep("perf_measures", on = "assert_insample_perf", cond = CondEqual$new(TRUE))
115-
ps$add_dep("perf_thresholds", on = "assert_insample_perf", cond = CondEqual$new(TRUE))
94+
ps$values = list(catch_errors = TRUE, impute_method = "random")
11695

11796
super$initialize(learner = learners, archive = archive, cols_x = cols_x, cols_y = cols_y, param_set = ps)
11897
},
@@ -159,33 +138,6 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
159138
length(self$learner)
160139
},
161140

162-
#' @template field_assert_insample_perf_surrogate
163-
assert_insample_perf = function(rhs) {
164-
if (missing(rhs)) {
165-
check = all(pmap_lgl(
166-
list(
167-
insample_perf = self$insample_perf,
168-
perf_threshold = self$param_set$values$perf_thresholds %??% rep(0, self$n_learner),
169-
perf_measure = self$param_set$values$perf_measures %??% replicate(self$n_learner, mlr_measures$get("regr.rsq"), simplify = FALSE)
170-
),
171-
.f = function(insample_perf, perf_threshold, perf_measure) {
172-
if (perf_measure$minimize) {
173-
insample_perf < perf_threshold
174-
} else {
175-
insample_perf > perf_threshold
176-
}
177-
})
178-
)
179-
180-
if (!check) {
181-
stop("Current insample performance of the Surrogate Model does not meet the performance threshold.")
182-
}
183-
invisible(self$insample_perf)
184-
} else {
185-
stop("$assert_insample_perf is read-only.")
186-
}
187-
},
188-
189141
#' @template field_packages_surrogate
190142
packages = function(rhs) {
191143
if (missing(rhs)) {
@@ -230,7 +182,6 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
230182
private = list(
231183

232184
# Train learner with new data.
233-
# Also calculates the insample performance based on the `perf_measures` hyperparameter if `assert_insample_perf = TRUE`.
234185
.update = function() {
235186
assert_true((length(self$cols_y) == length(self$learner)) || length(self$cols_y) == 1L) # either as many cols_y as learner or only one
236187
one_to_multiple = length(self$cols_y) == 1L
@@ -255,21 +206,10 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
255206
} else {
256207
names(self$learner) = self$cols_y
257208
}
258-
259-
if (self$param_set$values$assert_insample_perf) {
260-
private$.insample_perf = setNames(pmap_dbl(list(learner = self$learner, task = tasks, perf_measure = self$param_set$values$perf_measures %??% replicate(self$n_learner, mlr_measures$get("regr.rsq"), simplify = FALSE)),
261-
.f = function(learner, task, perf_measure) {
262-
assert_measure(perf_measure, task = task, learner = learner)
263-
learner$predict(task)$score(perf_measure, task = task, learner = learner)
264-
}
265-
), nm = map_chr(self$param_set$values$perf_measures, "id"))
266-
self$assert_insample_perf
267-
}
268209
},
269210

270211
# Train learner with new data.
271212
# Operates on an asynchronous archive and performs imputation as needed.
272-
# Also calculates the insample performance based on the `perf_measures` hyperparameter if `assert_insample_perf = TRUE`.
273213
.update_async = function() {
274214
assert_true((length(self$cols_y) == length(self$learner)) || length(self$cols_y) == 1L) # either as many cols_y as learner or only one
275215
one_to_multiple = length(self$cols_y) == 1L
@@ -309,16 +249,6 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
309249
} else {
310250
names(self$learner) = self$cols_y
311251
}
312-
313-
if (self$param_set$values$assert_insample_perf) {
314-
private$.insample_perf = setNames(pmap_dbl(list(learner = self$learner, task = tasks, perf_measure = self$param_set$values$perf_measures %??% replicate(self$n_learner, mlr_measures$get("regr.rsq"), simplify = FALSE)),
315-
.f = function(learner, task, perf_measure) {
316-
assert_measure(perf_measure, task = task, learner = learner)
317-
learner$predict(task)$score(perf_measure, task = task, learner = learner)
318-
}
319-
), nm = map_chr(self$param_set$values$perf_measures, "id"))
320-
self$assert_insample_perf
321-
}
322252
},
323253

324254
.reset = function() {

man-roxygen/field_assert_insample_perf_surrogate.R

Lines changed: 0 additions & 2 deletions
This file was deleted.

man/Surrogate.Rd

Lines changed: 0 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/SurrogateLearner.Rd

Lines changed: 0 additions & 19 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/SurrogateLearnerCollection.Rd

Lines changed: 0 additions & 19 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)