overhaul batch processing mechanism: should be more reliable now, and…
… aggregation should not accidentally and uncontrollably throw information away.
lschneiderbauer committed Dec 22, 2024
1 parent 8594cfd commit c4859bc
Showing 12 changed files with 268 additions and 116 deletions.
15 changes: 11 additions & 4 deletions R/fcwt.R
@@ -121,8 +121,15 @@ fcwt <- function(signal,
dim(output) <- c(length(signal), n_freqs)
# }

new_fcwtr_scalogram(
output, sample_freq, freq_begin, freq_end,
freq_scale, sigma, remove_coi
)
sc <-
new_fcwtr_scalogram(
output, sample_freq, freq_begin, freq_end,
freq_scale, sigma
)

if (remove_coi) {
sc_set_coi_na(sc)
} else {
sc
}
}
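For illustration, here is a minimal sketch of the reworked `remove_coi` behaviour from the caller's perspective; the synthetic signal and parameter values are invented for this example and are not part of the commit:

# 3 seconds of a 440 Hz test tone, sampled at 44100 Hz
signal <- sin(2 * pi * 440 * seq(0, 3, by = 1 / 44100))

res <- fcwt(
  signal,
  sample_freq = 44100,
  freq_begin = 50,
  freq_end = 12000,
  n_freqs = 100,
  sigma = 1,
  remove_coi = TRUE # COI cells are now masked afterwards via sc_set_coi_na()
)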
55 changes: 34 additions & 21 deletions R/fcwt_batch.R
@@ -14,8 +14,8 @@
#' performed in one run.
#' For instance, in case of processing a song of 10 minutes length (assuming
#' a sampling rate of 44100 Hz), the size of the output vector is
#' `10 * 60 seconds * 44100 Hz * nfreqs * 4 bytes`,
#' which for e.g. `nfreqs = 200`, equals ~ 21 GB, hence
#' `10 * 60 seconds * 44100 Hz * nfreqs * 8 bytes`,
#' which for e.g. `nfreqs = 200`, equals ~ 42 GB, hence
#' nowadays already at the limit of the hardware of a modern personal computer.
#'
#' In cases where the required output time-resolution is smaller than the time
@@ -33,8 +33,10 @@
#'
#' @param max_batch_size
#' The maximal batch size that is used for splitting up the input sequence.
#' This limits the maximal memory that is used. Defaults to roughly 4GB.
#' The actual batch size is optimized for use with FFTW.
#' This limits the maximal memory that is used. Defaults to roughly 1 GB, which is
#' conservative and takes into account that R might make copies when further
#' processing the result.
#' The actual batch size depends on the requested `time_resolution`.
#' @param time_resolution
#' The time resolution in inverse units of `sample_freq` of the result.
#' Memory consumption is directly related to that.
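To make the new default concrete, a worked example based on the formula in the signature below (the numbers are illustrative, not part of the commit):

n_freqs <- 200
# one output time step occupies n_freqs * 8 bytes (double precision)
max_batch_size <- ceiling(1 * 10^9 / (n_freqs * 8) / 4) # 156250 samples
max_batch_size / 44100 # ~ 3.5 seconds of audio per batch at 44100 Hz
156250 * 200 * 8 / 1e6 # ~ 250 MB raw output per batch; x4 headroom ~ 1 GB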
@@ -73,10 +75,16 @@ fcwt_batch <- function(signal,
time_resolution,
freq_begin = 2 * sample_freq / length(signal),
freq_end = sample_freq / 2,
freq_scale = c("linear", "log"),
sigma = 1,
max_batch_size = ceiling(4 * 10^9 / (n_freqs * 4)),
# factor 4 as an additional safety margin
max_batch_size = ceiling(1 * 10^9 / (n_freqs * 8) / 4),
n_threads = 2L,
progress_bar = FALSE) {

# aggregation window
w <- wnd_from_resolution(time_resolution, sample_freq)

# From FFTW documentation:
# FFTW is best at handling sizes of the form 2^a 3^b 5^c 7^d 11^e 13^f,
# where e+f is
@@ -85,7 +93,12 @@
# retains O(n lg n) performance, even for prime sizes).
# Transforms whose sizes are powers of 2 are especially fast.

batch_size <- 2^floor(log2(max_batch_size))
# we also want the batch size to be a multiple of the
# window size
# batch_size <- 2^floor(log2(max_batch_size))
batch_size <- floor(max_batch_size / w$size_n) * w$size_n

signal_size <- w$size_n * floor(length(signal) / w$size_n) # cut off the rest

total_result <- NULL

@@ -102,29 +115,24 @@
diff <- 0
while (cursor < length(signal) - diff) {
begin <- cursor + 1
end <- pmin(cursor + batch_size, length(signal))
end <- pmin(cursor + batch_size, signal_size)

n <- (1 + end - begin)
reduced_n <- ceiling(n / (sample_freq * time_resolution))

result_intermediate <-
result_raw <-
fcwt(
signal[begin:end],
sample_freq = sample_freq,
freq_begin = freq_begin,
freq_end = freq_end,
n_freqs = n_freqs,
freq_scale = freq_scale,
sigma = sigma,
remove_coi = TRUE,
n_threads = n_threads
) |>
agg(n = reduced_n)
)

result <-
result_intermediate |>
rm_na_time_slices() # we fully remove COI infected time slices
time_index_interval <- sc_coi_time_interval(result_raw)

if (dim(result)[[1]] < 1) {
if (any(is.na(time_index_interval))) {
stop(paste0(
"Removing COI yields empty result. Typically that happens if ",
"the batch size is too small. ",
@@ -133,20 +141,25 @@
))
}

result_agg <-
result_raw |>
sc_agg(w) |>
sc_rm_coi_time_slices() # we fully remove COI infected time slices

# take into account that some time records are lost due to boundary
# effect cut off (that's why cursor is not just end + 1)
# TODO: check if this is really correct (so that we do not have time shifts ...)
cursor <- cursor + ceiling(dim(result)[[1]] * n / reduced_n)
# we have to compensate two times the half-boundary (once at each end)
cursor <- cursor + (1 + end - begin) - (2 * (time_index_interval[[1]] - 1))
diff <- end - cursor

if (!is.null(total_result)) {
total_result <-
tbind(
total_result,
result
result_agg
)
} else {
total_result <- result
total_result <- result_agg
}

if (progress_bar) setTxtProgressBar(pb, cursor)
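To see the new batching arithmetic at work, a worked example with illustrative values (assuming sample_freq = 44100, time_resolution = 0.01 and max_batch_size = 156250; not part of the commit):

w_size_n <- floor(44100 * 0.01) # 441 samples per aggregation window
batch_size <- floor(156250 / w_size_n) * w_size_n # 156114, an exact multiple of 441
# if sc_coi_time_interval() reports 625 as the first trustworthy time index,
# the cursor advances by the batch length minus both half-boundaries:
# cursor + 156114 - 2 * (625 - 1) # = cursor + 154866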
140 changes: 94 additions & 46 deletions R/fcwtr_scalogram.R
@@ -1,31 +1,8 @@
new_fcwtr_scalogram <- function(matrix, sample_freq, freq_begin, freq_end,
freq_scale, sigma, remove_coi) {
freq_scale, sigma) {
stopifnot(is.matrix(matrix))
stopifnot(freq_scale %in% c("linear", "log"))

if (remove_coi) {
dim_t <- dim(matrix)[[1]] # Time dimension
dim_f <- dim(matrix)[[2]] # Frequency dimension

# The standard deviation Σ of a the Gauß like wave packet at frequency f
# and sampling frequency f_s with given σ is given by
# Σ = σ / sqrt(2) f_s / f
# we choose 4Σ to define the support of a wave packet
# (and so boundary effects are expected to occur until 2Σ)
coi_pred <- \(f, t) t * f < sqrt(2) * sigma

# express in dimensionless quantities
t <- rep(1:dim_t, times = dim_f)
f <-
rep(
seq(freq_end, freq_begin, length.out = dim_f) / sample_freq,
each = dim_t
)

# check if points are inside / outside hyperbolic cone
matrix[coi_pred(f, t) | coi_pred(f, dim_t - t)] <- NA_real_
}

obj <-
structure(
matrix,
@@ -49,6 +26,93 @@ new_fcwtr_scalogram <- function(matrix, sample_freq, freq_begin, freq_end,
obj
}

sc_set_coi_na <- function(x) {
stopifnot(inherits(x, "fcwtr_scalogram"))

x[sc_coi_mask(x)] <- NA_real_

x
}

#' @return A boolean matrix of the same dimensions as `x`. `TRUE` values
#' indicate values inside the boundary "cone of influence".
#' @noRd
sc_coi_mask <- function(x) {
stopifnot(inherits(x, "fcwtr_scalogram"))

dim_t <- sc_dim_time(x) # Time dimension
dim_f <- sc_dim_freq(x) # Frequency dimension

sigma <- attr(x, "sigma")
freq_begin <- attr(x, "freq_begin")
freq_end <- attr(x, "freq_end")
sample_freq <- attr(x, "sample_freq")

# The standard deviation Σ of the Gauß-like wave packet at frequency f
# and sampling frequency f_s with given σ is given by
# Σ = σ / sqrt(2) * f_s / f
# we choose 4Σ to define the support of a wave packet
# (and so boundary effects are expected to occur until 2Σ)
coi_pred <- \(f, t) t * f < sqrt(2) * attr(x, "sigma")

# express in dimensionless quantities
t <- rep(1:dim_t, times = dim_f)
f <-
rep(
seq(freq_end, freq_begin, length.out = dim_f) / sample_freq,
each = dim_t
)

mask <- coi_pred(f, t) | coi_pred(f, dim_t - t)
dim(mask) <- c(dim_t, dim_f)

mask
}
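For intuition, a small numeric example of the boundary implied by coi_pred (assumed values, not part of the commit):

sigma <- 1
sample_freq <- 44100
f_hz <- 100 # lowest analyzed frequency
# boundary effects at f_hz reach up to t < sqrt(2) * sigma * sample_freq / f_hz
sqrt(2) * sigma * sample_freq / f_hz # ~ 624 samples affected on each side
sqrt(2) * sigma / f_hz # ~ 14 ms of the signal edges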

sc_dim_freq <- function(x) {
stopifnot(inherits(x, "fcwtr_scalogram"))

dim(x)[[2]]
}

sc_dim_time <- function(x) {
stopifnot(inherits(x, "fcwtr_scalogram"))

dim(x)[[1]]
}

#' @return A vector of two values: the first and the last time index
#'   for which all data is available and trustworthy (no boundary effects).
#' @noRd
sc_coi_time_interval <- function(x) {
stopifnot(inherits(x, "fcwtr_scalogram"))

#unique(which(is.na(x), arr.ind = TRUE)[, 1])

full_info_rows <- which(rowSums(sc_coi_mask(x)) == 0)

if (length(full_info_rows) > 0) {
c(head(full_info_rows, n = 1), tail(full_info_rows, n = 1))
} else {
c(NA_integer_, NA_integer_)
}
}

sc_rm_coi_time_slices <- function(x) {
stopifnot(inherits(x, "fcwtr_scalogram"))

interval <- sc_coi_time_interval(x)
rows_to_keep <- interval[[1]]:interval[[2]]

new_fcwtr_scalogram(
x[rows_to_keep, ],
attr(x, "sample_freq"),
attr(x, "freq_begin"), attr(x, "freq_end"),
attr(x, "freq_scale"),
attr(x, "sigma")
)
}

seq2 <- function(from = 1, to = 1, length.out, scale = c("linear", "log")) {
scale <- match.arg(scale)

@@ -66,10 +130,11 @@ seq2 <- function(from = 1, to = 1, length.out, scale = c("linear", "log")) {

# perform aggregation, if possible.
# if it's not possible, be identity
agg <- function(x, n) {
sc_agg <- function(x, wnd) {
stopifnot(inherits(x, "fcwtr_scalogram"))

poolsize <- floor(dim(x)[[1]] / n)
poolsize <- wnd$size_n
n <- floor(sc_dim_time(x) / poolsize)

if (poolsize <= 1) {
# do nothing in case we cannot aggregate
@@ -88,8 +153,7 @@ agg <- function(x, n) {
attr(x, "sample_freq") / poolsize,
attr(x, "freq_begin"), attr(x, "freq_end"),
attr(x, "freq_scale"),
attr(x, "sigma"),
remove_coi = FALSE
attr(x, "sigma")
)
}
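As a sketch of what sc_agg does with a given window (illustrative numbers, not part of the commit):

poolsize <- 441 # wnd$size_n, e.g. from wnd_from_resolution(0.01, 44100)
n <- floor(156114 / poolsize) # 354 aggregated time steps from 156114 raw ones
44100 / poolsize # new effective sample_freq: 100 Hz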

@@ -100,7 +164,7 @@ tbind <- function(..., deparse.level = 1) {

# check if attributes are identical, otherwise combination
# does not make sense
if (length(unique(lapply(args, \(arg) attr(arg, "sample_freq")))) > 1) {
if (length(unique(lapply(args, \(arg) round(attr(arg, "sample_freq"))))) > 1) {
stop("Sampling frequencies need to be identical.")
}
if (length(unique(lapply(args, \(arg) attr(arg, "freq_begin")))) > 1) {
@@ -123,23 +187,7 @@
attr(args[[1]], "sample_freq"),
attr(args[[1]], "freq_begin"), attr(args[[1]], "freq_end"),
attr(args[[1]], "freq_scale"),
attr(args[[1]], "sigma"),
remove_coi = FALSE
)
}

rm_na_time_slices <- function(x) {
stopifnot(inherits(x, "fcwtr_scalogram"))

rows_to_remove <- unique(which(is.na(x), arr.ind = TRUE)[, 1])

new_fcwtr_scalogram(
x[-rows_to_remove, ],
attr(x, "sample_freq"),
attr(x, "freq_begin"), attr(x, "freq_end"),
attr(x, "freq_scale"),
attr(x, "sigma"),
remove_coi = FALSE
attr(args[[1]], "sigma")
)
}

48 changes: 48 additions & 0 deletions R/wnd.R
@@ -0,0 +1,48 @@
new_wnd <- function(size_n, size_time) {
structure(
list(
size_n = size_n,
size_time = size_time
),
class = "wnd"
)
}

#' @param n window length in discrete time steps (sampling steps)
#' @noRd
wnd_from_dim <- function(n, orig_sample_freq) {
new_wnd(
size_n = n,
size_time = n / orig_sample_freq
)
}

#' @param secs target window length in seconds. The exact window length
#' depends on the original sample frequency `orig_sample_freq`
#' @noRd
wnd_from_secs <- function(secs, orig_sample_freq) {
n <- round(secs * orig_sample_freq)

wnd_from_dim(n, orig_sample_freq)
}

#' @param target_resolution relative window specified by a target resolution
#' in seconds (distance between discrete time steps)
#' @noRd
wnd_from_resolution <- function(target_resolution, orig_sample_freq) {
wnd_from_dim(
floor(orig_sample_freq * target_resolution),
orig_sample_freq
)
}

#' @param target_size New total target size of `sc`.
#' @noRd
wnd_from_target_size <- function(target_size, sc) {
stopifnot(inherits(sc, "fcwtr_scalogram"))

wnd_from_dim(
floor(sc_dim_time(sc) / target_size),
attr(sc, "sample_freq")
)
}
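A minimal usage sketch of the new window helpers (values are illustrative):

w <- wnd_from_resolution(target_resolution = 0.01, orig_sample_freq = 44100)
w$size_n # 441 samples per window
w$size_time # 0.01 seconds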
2 changes: 1 addition & 1 deletion README.Rmd
@@ -105,7 +105,7 @@ batch_result <-
freq_end = 12000,
n_freqs = 200,
sigma = 4,
time_resolution = 1 / 44100
time_resolution = 0.01
)
plot(batch_result)
2 changes: 1 addition & 1 deletion README.md
@@ -136,7 +136,7 @@ batch_result <-
freq_end = 12000,
n_freqs = 200,
sigma = 4,
time_resolution = 1 / 44100
time_resolution = 0.01
)

plot(batch_result)
1 change: 1 addition & 0 deletions fcwtr.Rproj
@@ -1,4 +1,5 @@
Version: 1.0
ProjectId: 0048848a-100c-4b20-8eb8-b73e3c0400e4

RestoreWorkspace: Default
SaveWorkspace: Default