Merge pull request #19 from crnolan/feature/18/multi-channel
Handle multiple channels per recording site
crnolan authored Apr 30, 2024
2 parents feb4d4c + 9e9b7d6 commit 0b05d67
Showing 4 changed files with 174 additions and 69 deletions.
4 changes: 2 additions & 2 deletions src/behapy/console.py
@@ -123,8 +123,8 @@ def preprocess_dash(bidsroot):

def get_recording(index):
r = signals.iloc[index]
signal = fp.load_signal(bidsroot, r.subject, r.session, r.task, r.run,
r.label, 'iso')
signal = fp.load_signals(bidsroot, r.subject, r.session, r.task, r.run,
r.label, 'iso')
return signal

dash = PreprocessDashboard(signals, get_recording, bidsroot)
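Note: the dashboard now goes through fp.load_signals, which returns a DataFrame with one column per channel at the recording site rather than a single Series. A minimal usage sketch, assuming a BIDS-style tree; the root and entity values here are hypothetical:

```python
from behapy import fp

signal = fp.load_signals('/data/myproject', subject='f01', session='1',
                         task='cond', run='1', label='dLight',
                         artifact_channel='iso')
print(signal.columns.to_list())          # e.g. ['iso', 'dLight']
print(signal.attrs['channels'])          # signal channels, artifact excluded
print(signal.attrs['artifact_channel'])  # 'iso'
```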
169 changes: 132 additions & 37 deletions src/behapy/fp.py
@@ -23,6 +23,8 @@
def series_like(df, name, default=0.):
series = pd.Series(default, index=df.index, name=name)
series.attrs = df.attrs.copy()
_ = series.attrs.pop('artifact_channel', None)
_ = series.attrs.pop('channels', None)
_ = series.attrs.pop('iso_channel', None)
_ = series.attrs.pop('channel', None)
return series
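For context, series_like builds a constant Series sharing the frame's index and attrs; the two new pop calls strip the multi-channel bookkeeping ('channels', 'artifact_channel') alongside the old single-channel keys. A toy sketch of the intended behaviour, with made-up attrs:

```python
import pandas as pd
from behapy.fp import series_like

df = pd.DataFrame({'iso': [0.1, 0.2], 'grabDA': [1.0, 1.1]})
df.attrs = {'fs': 30.0, 'channels': ['grabDA'], 'artifact_channel': 'iso'}

s = series_like(df, 'dff')
assert (s == 0.).all()          # filled with the default value
assert s.attrs == {'fs': 30.0}  # channel metadata stripped, rest copied
```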
@@ -39,8 +41,9 @@ def load_channel(root, subject, session, task, run, label, channel):
return data, meta


def load_signal(root, subject, session, task, run, label, iso_channel='iso'):
"""Load a raw signal, including the isosbestic channel if present.
def load_signals(root, subject, session, task, run, label,
artifact_channel=None, exclude_artifact_channel=True):
"""Load all raw signals for a given site.
"""
root = Path(root).absolute()
recordings = pd.DataFrame(
@@ -50,6 +53,93 @@ def load_signal(root, subject, session, task, run, label, iso_channel='iso'):
sessions = recordings.loc[:, 'session'].unique()
tasks = recordings.loc[:, 'task'].unique()
labels = recordings.loc[:, 'label'].unique()
if any([item.shape[0] != 1
for item in [subjects, sessions, tasks, labels]]):
msg = (f'Multiple signal names found for session '
f'with subject {subject}, session {session}, task {task}, '
f'run {run} and label {label}')
logging.error(msg)
raise ValueError(msg)

# Load channels
data = []
t0 = None
fs = None
for r in recordings.itertuples():
d, meta = load_channel(root=root,
subject=r.subject,
session=r.session,
task=r.task,
run=r.run,
label=r.label,
channel=r.channel)
if fs is None:
fs = meta['fs']
if t0 is None:
t0 = meta['start_time']
if (fs != meta['fs']) or (t0 != meta['start_time']):
msg = ('Unequal sample frequencies and/or start times '
'for subject {}, session {}, task {}, run {} and label {}')
msg = msg.format(subject, session, task, run, label)
raise ValueError(msg)
t = pd.Index(np.arange(d.shape[0]) / fs + t0, name='time')
data.append(pd.Series(d, name=r.channel, index=t))

signal = pd.concat(data, axis=1)
signal.index.name = 'time'
signal.attrs['root'] = root
signal.attrs['fs'] = fs
signal.attrs['start_time'] = t0
signal.attrs['subject'] = subject
signal.attrs['session'] = session
signal.attrs['task'] = task
signal.attrs['run'] = run
signal.attrs['label'] = label
channels = signal.columns.to_list()
if artifact_channel is None:
acd = set(['iso', 'isos', 'isosbestic'])
artifact_channel = set(channels).intersection(acd)
if len(artifact_channel) == 0:
logging.warning(f'No artifact channel found for subject {subject}, '
f'session {session}, task {task}, run {run} and '
f'label {label}, using first channel {channels[0]}')
artifact_channel = channels[0]
elif len(artifact_channel) == 1:
artifact_channel = artifact_channel.pop()
else:
raise ValueError(
f'Multiple default artifact channels found for subject '
f'{subject}, session {session}, task {task}, run {run} and '
f'label {label}: {artifact_channel}')

signal.attrs['artifact_channel'] = artifact_channel
if exclude_artifact_channel:
signal.attrs['channels'] = [c for c in channels
if c != artifact_channel]
else:
signal.attrs['channels'] = channels

if len(signal.attrs['channels']) == 0:
raise ValueError(f'No channels found for subject {subject}, '
f'session {session}, task {task}, run {run} '
f'and label {label}')
return signal


def load_signal(root, subject, session, task, run, label, iso_channel='iso',
channel=None):
"""Load a raw signal, including the isosbestic channel if present.
"""
root = Path(root).absolute()
if channel is None:
channel = '*'
recordings = pd.DataFrame(
list_raw(root, subject=subject, session=session, task=task,
run=run, label=label, channel=channel))
subjects = recordings.loc[:, 'subject'].unique()
sessions = recordings.loc[:, 'session'].unique()
tasks = recordings.loc[:, 'task'].unique()
labels = recordings.loc[:, 'label'].unique()
if any([item.shape[0] != 1
for item in [subjects, sessions, tasks, labels]]):
msg = ('Multiple signal names found for session'
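The artifact-channel defaulting added above boils down to a set intersection against a few known isosbestic names, with a warning fallback to the first channel. A standalone sketch of that rule — the function name and channel lists below are illustrative only:

```python
def pick_artifact_channel(channels, candidates=('iso', 'isos', 'isosbestic')):
    """Mirror the defaulting logic in load_signals."""
    found = set(channels) & set(candidates)
    if not found:
        return channels[0]   # the real code logs a warning here
    if len(found) == 1:
        return found.pop()
    raise ValueError(f'Multiple default artifact channels found: {found}')

assert pick_artifact_channel(['iso', 'grabDA']) == 'iso'
assert pick_artifact_channel(['red', 'green']) == 'red'
```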
@@ -135,8 +225,6 @@ def downsample(signal, factor=None):
factor *= 2
ds = sig.decimate(signal.to_numpy(), factor, ftype='fir',
zero_phase=True, axis=0)
ts = (np.arange(ds.shape[0]) / (signal.attrs['fs'] / factor) +
signal.attrs['start_time'])
df = pd.DataFrame(ds, index=signal.index[::factor], columns=signal.columns)
df.attrs = signal.attrs
df.attrs['fs'] = signal.attrs['fs'] / factor
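For reference, scipy's decimate applies an anti-aliasing FIR filter and then keeps every factor-th sample, which is why the removed ts computation was redundant: signal.index[::factor] already yields the right timestamps. A self-contained sketch on a synthetic trace (the sampling rate is hypothetical):

```python
import numpy as np
import pandas as pd
import scipy.signal as sig

fs = 1017.25                        # hypothetical acquisition rate
t = np.arange(10 * int(fs)) / fs
raw = pd.DataFrame({'grabDA': np.sin(2 * np.pi * 0.5 * t)},
                   index=pd.Index(t, name='time'))
factor = 8
ds = sig.decimate(raw.to_numpy(), factor, ftype='fir', zero_phase=True, axis=0)
down = pd.DataFrame(ds, index=raw.index[::factor], columns=raw.columns)
assert len(down) == int(np.ceil(len(raw) / factor))
```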
@@ -176,19 +264,23 @@ def find_discontinuities(signal, mean_window=3, std_window=30, nstd_thresh=2):
# then use the median of a sliding window STD as our
# characteristic STD.
std_n = int(signal.attrs['fs'] * std_window)
# iso_rstds = np.std(sliding_window_view(site.iso(), std_n), axis=-1)
data = signal[signal.attrs['channel']].to_numpy()
data_rstds = bn.move_std(data, std_n, axis=-1)
data_thresh = np.median(data_rstds[~np.isnan(data_rstds)], axis=-1)
data_rmeans = bn.move_mean(np.pad(data, n, 'edge'), n, axis=-1)
iso = signal[signal.attrs['iso_channel']].to_numpy()
iso_rstds = bn.move_std(iso, std_n, axis=-1)
iso_thresh = np.median(iso_rstds[~np.isnan(iso_rstds)], axis=-1)
mean_thresh = iso_thresh * nstd_thresh
# Calculate a sliding mean
# iso_rmeans = np.mean(sliding_window_view(np.pad(site.iso(), n, 'edge'), n), axis=-1)
iso_rmeans = bn.move_mean(np.pad(iso, n, 'edge'), n, axis=-1)
d = (iso_rmeans[n:-n] - iso_rmeans[(n*2):])
if 'channels' not in signal.attrs:
logging.info('Using original single-channel format')
channels = [signal.attrs['channel']]
artifact_channel = signal.attrs['iso_channel']
else:
channels = signal.attrs['channels']
artifact_channel = signal.attrs['artifact_channel']
data = signal[channels].to_numpy()
data_rstds = bn.move_std(data, std_n, axis=0)
data_thresh = np.nanmedian(data_rstds, axis=0)
data_rmeans = bn.move_mean(np.pad(data, ((n, n), (0, 0)), 'edge'), n, axis=0)
afct = signal[artifact_channel].to_numpy()
afct_rstds = bn.move_std(afct, std_n, axis=0)
afct_thresh = np.nanmedian(afct_rstds, axis=0)
mean_thresh = afct_thresh * nstd_thresh
afct_rmeans = bn.move_mean(np.pad(afct, ((n, n)), 'edge'), n, axis=0)
d = (afct_rmeans[n:-n] - afct_rmeans[(n*2):])
d_thresh = np.abs(d) > mean_thresh
# Find the start and end of each mean shift
mean_shift_bounds = np.diff(d_thresh.astype(int))
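The mean-shift test above compares two n-sample moving means offset by one window and flags samples where their difference exceeds nstd_thresh characteristic standard deviations. A toy sketch of those moving statistics (bottleneck's move_* functions leave NaNs in the warm-up region, hence nanmedian):

```python
import bottleneck as bn
import numpy as np

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(0.0, 0.1, 1000),
                    rng.normal(5.0, 0.1, 1000)])   # step at sample 1000
n, std_n = 30, 300
rstds = bn.move_std(x, std_n)          # NaN for the first std_n - 1 samples
thresh = np.nanmedian(rstds) * 2
rmeans = bn.move_mean(np.pad(x, n, 'edge'), n)
d = rmeans[n:-n] - rmeans[2 * n:]      # window ending at i minus window after i
print(np.where(np.abs(d) > thresh)[0][[0, -1]])   # samples bracketing the step
```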
@@ -212,10 +304,10 @@ def find_discontinuities(signal, mean_window=3, std_window=30, nstd_thresh=2):
onsets = np.where(mean_shift_bounds == 1)[0]
offsets = np.where(mean_shift_bounds == -1)[0]
for i, (onset, offset) in enumerate(zip(onsets, offsets)):
k = np.argmax(np.abs(data[offset:onset:-1] - data_rmeans[onset+n]) < data_thresh)
k = np.argmax(np.abs(data[offset:onset:-1, :] - data_rmeans[[onset+n], :]) < data_thresh)
if k > 0:
onsets[i] = offset - k
k = np.argmax(np.abs(data[onset:offset:1] - data_rmeans[offset+n]) < data_thresh)
k = np.argmax(np.abs(data[onset:offset:1, :] - data_rmeans[[offset+n], :]) < data_thresh)
if k > 0:
offsets[i] = onset + k
return [(onset, offset)
@@ -227,20 +319,22 @@ def find_disconnects(signal, zero_nstd_thresh=5, mean_window=3, std_window=30,
nstd_thresh=2):
bounds = find_discontinuities(signal, mean_window=mean_window,
std_window=std_window, nstd_thresh=nstd_thresh)
# data = signal[signal.attrs['iso_channel']].to_numpy()
data = signal[signal.attrs['channel']].to_numpy()
if 'channels' not in signal.attrs:
logging.info('Using original single-channel format')
channels = [signal.attrs['channel']]
else:
channels = signal.attrs['channels']
data = signal[channels].to_numpy()
ts = signal.index.to_numpy()
std_n = int(signal.attrs['fs'] * std_window)
data_rstds = bn.move_std(data, std_n, axis=-1)
data_rstds = data_rstds[~np.isnan(data_rstds)]
zero_thresh = np.median(data_rstds, axis=-1) * zero_nstd_thresh
data_rstds = bn.move_std(data, std_n, axis=0)
zero_thresh = np.nanmedian(data_rstds, axis=0) * zero_nstd_thresh
dc_intervals = IntervalTree()
bounds = [(0, 0)] + bounds + [(len(data)-1, len(data)-1)]
bounds = [(0, 0)] + bounds + [(data.shape[0]-1, data.shape[0]-1)]
for (on0, off0), (on1, off1) in zip(bounds[:-1], bounds[1:]):
if np.mean(data[off0:on1]) < zero_thresh:
if np.any(np.mean(data[off0:on1, :], axis=0) < zero_thresh):
dc_intervals.add(Interval(ts[on0], ts[off1]))
dc_intervals.merge_overlaps()
# return [(b[0], b[1]) for b in list(dc_intervals)]
return dc_intervals
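find_disconnects now returns an intervaltree.IntervalTree keyed in seconds, and merge_overlaps() collapses adjacent dropouts into one interval. A small sketch of consuming the result (the times are made up):

```python
from intervaltree import Interval, IntervalTree

dc = IntervalTree([Interval(2.0, 5.0), Interval(4.5, 7.0)])
dc.merge_overlaps()
print(sorted((iv.begin, iv.end) for iv in dc))  # [(2.0, 7.0)]
print(bool(dc[6.2]))  # True: t = 6.2 s falls inside a disconnect
```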


@@ -272,8 +366,6 @@ def reject(signal, intervals, fill=False):
# replace with a linear interpolation between the endpoints.
signal = signal.copy()
signal.loc[~mask] = np.nan
# for start, end in interval_list:
# signal.loc[start:end] = np.nan
signal = signal.interpolate(method='linear', limit_direction='both')
signal['mask'] = mask
return signal
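When fill=True, rejected spans are set to NaN and then bridged by linear interpolation, with the boolean mask kept as a column so downstream code can still ignore the filled samples. A toy sketch of that fill step:

```python
import numpy as np
import pandas as pd

frame = pd.DataFrame({'grabDA': [1.0, 2.0, 3.0, 4.0, 5.0]})
mask = np.array([True, True, False, False, True])   # False = rejected

filled = frame.copy()
filled.loc[~mask] = np.nan
filled = filled.interpolate(method='linear', limit_direction='both')
filled['mask'] = mask
print(filled['grabDA'].to_list())   # [1.0, 2.0, 3.0, 4.0, 5.0]
```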
@@ -314,8 +406,9 @@ def smooth(data, cutoff=1):
except AttributeError:
b = sig.firwin(1001, cutoff=[cutoff], fs=data.attrs['fs'], pass_zero=True)
smooth.filter_b = b
smoothed = series_like(data, 'smoothed')
smoothed[:] = sig.filtfilt(b, 1, data)
# smoothed = series_like(data, 'smoothed')
smoothed = data.copy()
smoothed[:] = sig.filtfilt(b, 1, data.to_numpy(), axis=0).astype(np.float32)
return smoothed
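smooth (and detrend in the next hunk, with pass_zero=False for the high-pass complement) is a zero-phase FIR filter: firwin designs the taps once and filtfilt runs them forward and backward, so the output has no phase lag. Switching from series_like to data.copy() lets the same code handle a multi-column DataFrame, since filtfilt takes axis=0. A minimal sketch on a synthetic trace (fs and cutoff are assumed):

```python
import numpy as np
import scipy.signal as sig

fs = 250.0                           # hypothetical post-downsample rate
t = np.arange(int(20 * fs)) / fs
x = np.sin(2 * np.pi * 0.1 * t) + 0.3 * np.sin(2 * np.pi * 30.0 * t)

b = sig.firwin(1001, cutoff=[1.0], fs=fs, pass_zero=True)   # 1 Hz low-pass
smoothed = sig.filtfilt(b, 1, x)     # zero-phase: the slow component survives
```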


@@ -326,8 +419,9 @@ def detrend(data, numtaps=1001, cutoff=0.05):
b = sig.firwin(numtaps, cutoff=[cutoff], fs=data.attrs['fs'],
pass_zero=False)
detrend.filter_b = b
detrended = series_like(data, 'detrended')
detrended[:] = sig.filtfilt(b, 1, data)
# detrended = series_like(data, 'detrended')
detrended = data.copy()
detrended[:] = sig.filtfilt(b, 1, data.to_numpy(), axis=0).astype(np.float32)
return detrended


@@ -428,19 +522,20 @@ def preprocess(root, subject, session, task, run, label):
logging.info(f'Preprocessing subject {subject}, '
f'session {session}, task {task}, '
f'run {run}, label {label}...')
recording = load_signal(root, subject, session, task, run, label, 'iso')
recording = load_signals(root, subject, session, task, run, label, 'iso')
recording = downsample(recording, 64)
rej = reject(recording, intervals, fill=True)
ch = recording.attrs['channel']
ch = recording.attrs['channels']
# We were doing a robust regression, but the fit isn't good enough.
# Let's just detrend and divide by the smoothed signal instead.
# dff = fp.series_like(recording, name='dff')
# dff.loc[rej.index] = fp.detrend(rej[ch])
dff = detrend(rej[ch], cutoff=config['detrend_cutoff'])
dff = dff / smooth(rej[ch])
dff.name = 'dff'
dff = dff.to_frame()
# dff.name = 'dff'
# dff = dff.to_frame()
dff['mask'] = rej['mask']
dff.attrs['root'] = str(dff.attrs['root'])
data_fn = get_preprocessed_fibre_path(
root, subject, session, task, run, label, 'parquet')
meta_fn = get_preprocessed_fibre_path(
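The dF/F step is unchanged in spirit, but ch is now a list of channels, so detrend and smooth operate column-wise and dff is already a DataFrame (hence the commented-out to_frame). A sketch of the step, assuming rej is the rejected/filled recording; the channel names and cutoff are placeholders (the real code reads the cutoff from the preprocess config):

```python
# ch is a list of channel names, e.g. ['grabDA', 'rGeco'] (hypothetical).
dff = detrend(rej[ch], cutoff=0.05)   # high-pass: slow drift removed -> ~dF
dff = dff / smooth(rej[ch])           # divide by low-passed baseline -> dF/F
dff['mask'] = rej['mask']             # carry the artifact mask along
```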
10 changes: 6 additions & 4 deletions src/behapy/pathutils.py
@@ -38,14 +38,15 @@ def get_preprocessed_fibre_path(root, sub, ses, task, run, label, ext):
label=label, ext=ext)


def list_recordings(base, subject='*', session='*', task='*', run='*', label='*', ext='npy'):
def list_recordings(base, subject='*', session='*', task='*', run='*',
label='*', channel='*', ext='npy'):
Recording = namedtuple("Recording", ["subject", "session", "task", "run", "label", "channel", "file_path"])

# Set the pattern for the data files
base = Path(base)
pattern = (f'sub-{subject}/ses-{session}/fp/'
f'sub-{subject}_ses-{session}_'
f'task-{task}_run-{run}_label-{label}_channel-*.{ext}')
f'task-{task}_run-{run}_label-{label}_channel-{channel}.{ext}')
# Search for files that match the pattern
data_files = list(base.glob(str(pattern)))
# Regex pattern to extract variables from the file names
@@ -63,9 +64,10 @@ def list_recordings(base, subject='*', session='*', task='*', run='*', label='*'
return extracted_data


def list_raw(root, subject='*', session='*', task='*', run='*', label='*'):
def list_raw(root, subject='*', session='*', task='*', run='*', label='*',
channel='*'):
return list_recordings(Path(root) / 'rawdata', subject,
session, task, run, label, 'npy')
session, task, run, label, channel, 'npy')
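Threading channel through list_raw means callers can glob one channel or all of them; the match is a plain BIDS-flavoured path pattern. For illustration (the root and entity values are hypothetical):

```python
from pathlib import Path

subject, session, task, run = 'f01', '1', 'cond', '1'
label, channel, ext = 'dLight', '*', 'npy'
pattern = (f'sub-{subject}/ses-{session}/fp/'
           f'sub-{subject}_ses-{session}_'
           f'task-{task}_run-{run}_label-{label}_channel-{channel}.{ext}')
files = sorted(Path('/data/myproject/rawdata').glob(pattern))
```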


def list_preprocessed(root, subject='*', session='*', task='*', run='*',
60 changes: 34 additions & 26 deletions src/behapy/visuals.py
@@ -1,5 +1,5 @@
import logging
from functools import partial
from functools import partial, reduce
from intervaltree import Interval
import json
from . import fp
@@ -136,24 +136,16 @@ def update_intervals(self):
@param.depends("selected_index", "interval_update", watch=True)
def update_regressions(self):
rej = fp.reject(self.recording, self.intervals, fill=True)
ch = self.recording.attrs['channel']
if 'channels' not in self.recording.attrs:
ch = [self.recording.attrs['channel']]
else:
ch = self.recording.attrs['channels']
config = load_preprocess_config(self.bidsroot)
# We were doing a robust regression, but the fit isn't good enough.
# Let's just detrend and divide by the smoothed signal instead.
dff = fp.detrend(rej[ch], cutoff=config['detrend_cutoff'])
dff = dff / fp.smooth(rej[ch])
dff.name = 'dff'
# dff = fp.series_like(self.recording, name='dff')
# dff.loc[rej.index] = fp.detrend(rej[ch])
# dff = dff / fp.smooth(rej[ch])
# OLD REGRESSION CODE
# fit = fp.fit(rej)
# regression = fp.series_like(rej, name='regression')
# regression[:] = fit.fittedvalues
# self.regression = regression
# dff = fp.series_like(self.recording, name='dff')
# dff.loc[rej.index] = (rej[ch] - regression) / regression
# dff = dff / dff.std()
# dff.name = 'dff'
self.dff = dff
self.regression_update += 1

@@ -163,22 +155,38 @@ def plot_all(self):
return
regression = self.regression
tools = ['xbox_select']
isoch = self.recording.attrs['iso_channel']
ch = self.recording.attrs['channel']
if 'channels' not in self.recording.attrs:
channels = [self.recording.attrs['channel']]
ach = self.recording.attrs['iso_channel']
else:
channels = self.recording.attrs['channels']
ach = self.recording.attrs['artifact_channel']
iso_shade = datashade(
signal_curve(self.recording[isoch], y_dim='F'),
signal_curve(self.recording[ach], y_dim='F'),
aggregator=ds.count(), cmap='blue')
sig_shade = datashade(
signal_curve(self.recording[ch], y_dim='F'),
aggregator=ds.count(), cmap='red').opts(tools=tools)
dff_shade = datashade(
signal_curve(self.dff, y_dim='dF/F'),
aggregator=ds.count(), cmap='green')
sig_shades = []
dff_shades = []
cmaps = ['red', 'green', 'purple']
for ch, cmap in zip(channels, cmaps):
sig_shade = datashade(
signal_curve(self.recording[ch], y_dim='F'),
aggregator=ds.count(), cmap=cmap).opts(tools=tools)
dff_shade = datashade(
signal_curve(self.dff[ch], y_dim=ch),
aggregator=ds.count(), cmap=cmap)
sig_shades.append(sig_shade)
dff_shades.append(dff_shade)
overlay = interval_overlay_map(iso_shade, self.intervals,
self.update_intervals)
# plot = (rej_shade.opts(xaxis=None) +
plot = ((iso_shade * sig_shade * overlay).opts(xaxis=None) +
dff_shade)
# overlay = (iso_shade * sig_shade * overlay)
# raw_plot = hv.Overlay([iso_shade] + sig_shades) * overlay
raw_plot = reduce(lambda a, b: a * b,
[iso_shade, overlay] + sig_shades)
raw_plot = raw_plot.opts(xaxis=None)
# plot = ((iso_shade * sig_shade * overlay).opts(xaxis=None) +
# dff_shade)
plot = raw_plot + hv.Layout(dff_shades)
plot = plot.opts(
opts.RGB(responsive=True, min_width=600, min_height=300,
tools=tools))
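Composing a variable number of datashaded layers is what motivates the new functools.reduce import: holoviews elements combine with *, so reduce folds a list of per-channel layers into a single Overlay. A standalone sketch with toy curves (the channel count and dimension names are arbitrary):

```python
from functools import reduce
import holoviews as hv
import numpy as np

hv.extension('bokeh')
t = np.linspace(0, 10, 500)
layers = [hv.Curve((t, np.sin(t + k)), 'time', 'F') for k in range(3)]
overlay = reduce(lambda a, b: a * b, layers)   # Curve * Curve -> Overlay
layout = overlay + hv.Curve((t, np.cos(t)), 'time', 'dF/F')
```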
@@ -197,5 +205,5 @@ def view(self):
sizing_mode='stretch_both'
),
styles=dict(background='WhiteSmoke'),
sizing_mode='stretch_both'
sizing_mode='stretch_both'
)
