Skip to content

fix(LearnerTorch): don't log during private .train #401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/CallbackSetEarlyStopping.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ CallbackSetEarlyStopping = R6Class("CallbackSetEarlyStopping",
improvement = multiplier * (self$ctx$last_scores_valid[[1L]] - self$best_score)

if (is.na(improvement)) {
lg$warn("Learner %s in epoch %s: Difference between subsequent validation performances is NA",
warningf("Learner %s in epoch %s: Difference between subsequent validation performances is NA",
self$ctx$learner$id, self$ctx$epoch)
return(NULL)
}
Expand Down
24 changes: 17 additions & 7 deletions R/CallbackSetUnfreeze.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#' @param unfreeze (`data.table`)\cr
#' A `data.table` with a column `weights` (a list column of `Select`s) and a column `epoch` or `batch`.
#' The selector indicates which parameters to unfreeze, while the `epoch` or `batch` column indicates when to do so.
#' @param verbose (`logical(1)`)\cr
#' Whether to print messages to the console.
#'
#' @family Callback
#' @export
Expand All @@ -35,9 +37,10 @@ CallbackSetUnfreeze = R6Class("CallbackSetUnfreeze",
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(starting_weights, unfreeze) {
initialize = function(starting_weights, unfreeze, verbose) {
self$starting_weights = starting_weights
self$unfreeze = unfreeze
self$verbose = verbose
private$.batchwise = "batch" %in% names(self$unfreeze)
},
#' @description
Expand All @@ -49,7 +52,9 @@ CallbackSetUnfreeze = R6Class("CallbackSetUnfreeze",
walk(self$ctx$network$parameters[frozen_weights], function(param) param$requires_grad_(FALSE))

frozen_weights_str = paste(trainable_weights, collapse = ", ")
lg$info(sprintf("Training the following weights at the start: %s", paste0(trainable_weights, collapse = ", ")))
if (self$verbose) {
messagef("Training the following weights at the start: %s", paste0(trainable_weights, collapse = ", "))
}
},
#' @description
#' Unfreezes weights if the training is at the correct epoch
Expand All @@ -58,11 +63,13 @@ CallbackSetUnfreeze = R6Class("CallbackSetUnfreeze",
if (self$ctx$epoch %in% self$unfreeze$epoch) {
weights = (self$unfreeze[get("epoch") == self$ctx$epoch]$weights)[[1]](names(self$ctx$network$parameters))
if (!length(weights)) {
lg$warn(paste0("No weights unfrozen at epoch ", self$ctx$epoch, " , check the specification of the Selector"))
warningf(paste0("No weights unfrozen at epoch ", self$ctx$epoch, " , check the specification of the Selector"))
} else {
walk(self$ctx$network$parameters[weights], function(param) param$requires_grad_(TRUE))
weights_str = paste(weights, collapse = ", ")
lg$info(paste0("Unfreezing at epoch ", self$ctx$epoch, ": ", weights_str))
if (self$verbose) {
messagef(paste0("Unfreezing at epoch ", self$ctx$epoch, ": ", weights_str))
}
}

}
Expand All @@ -76,11 +83,13 @@ CallbackSetUnfreeze = R6Class("CallbackSetUnfreeze",
if (batch_num %in% self$unfreeze$batch) {
weights = (self$unfreeze[get("batch") == batch_num]$weights)[[1]](names(self$ctx$network$parameters))
if (!length(weights)) {
lg$warn(paste0("No weights unfrozen at batch ", batch_num, " , check the specification of the Selector"))
warningf(paste0("No weights unfrozen at batch ", batch_num, " , check the specification of the Selector"))
} else {
walk(self$ctx$network$parameters[weights], function(param) param$requires_grad_(TRUE))
weights_str = paste(weights, collapse = ", ")
lg$info(paste0("Unfreezing at batch ", batch_num, ": ", weights_str))
if (self$verbose) {
messagef(paste0("Unfreezing at batch ", batch_num, ": ", weights_str))
}
}
}
}
Expand All @@ -100,7 +109,8 @@ mlr3torch_callbacks$add("unfreeze", function() {
unfreeze = p_uty(
tags = c("train", "required"),
custom_check = check_unfreeze_dt
)
),
verbose = p_lgl(init = FALSE, tags = c("train", "required"))
),
id = "unfreeze",
label = "Unfreeze",
Expand Down
1 change: 0 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
# Resolve the "auto" device placeholder to a concrete torch device.
#
# @param device (`character(1)` or `NULL`)\cr
#   Device identifier. If `"auto"`, it is resolved to `"cuda"` when a CUDA
#   device is available and to `"cpu"` otherwise. Any other value
#   (including `NULL`) is returned unchanged.
# @return The resolved device identifier (or `NULL` if `NULL` was passed).
auto_device = function(device = NULL) {
  # identical() is safe for NULL / zero-length input, where `device == "auto"`
  # would yield logical(0) and make the `if` condition error with
  # "argument is of length zero" (note the function's own default is NULL).
  if (identical(device, "auto")) {
    device = if (cuda_is_available()) "cuda" else "cpu"
  }
  return(device)
}
Expand Down
4 changes: 2 additions & 2 deletions R/with_torch_settings.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ with_torch_settings = function(seed, num_threads = 1, num_interop_threads = 1, e
old_num_threads = torch_get_num_threads()
if (running_on_mac()) {
if (!isTRUE(all.equal(num_threads, 1L))) {
lg$warn("Cannot set number of threads on macOS.")
warningf("Cannot set number of threads on macOS.")
}
} else {
on.exit({torch_set_num_threads(old_num_threads)},
Expand All @@ -14,7 +14,7 @@ with_torch_settings = function(seed, num_threads = 1, num_interop_threads = 1, e
if (num_interop_threads != torch_get_num_interop_threads()) {
result = try(torch::torch_set_num_interop_threads(num_interop_threads), silent = TRUE)
if (inherits(result, "try-error")) {
lg$warn(sprintf("Can only set the interop threads once, keeping the previous value %s", torch_get_num_interop_threads()))
warningf(sprintf("Can only set the interop threads once, keeping the previous value %s", torch_get_num_interop_threads()))
}
}
# sets the seed back when exiting the function
Expand Down
2 changes: 1 addition & 1 deletion man/mlr_callback_set.unfreeze.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading