Skip to content

Commit 79479d7

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents 373e20d + 6b433b9 commit 79479d7

File tree

4 files changed

+116
-20
lines changed

4 files changed

+116
-20
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
##### Enhancements
66
- Added `norm` parameter to `WaveShape.compute()` to control normalisation of waveshape results.
7+
- Added `output` parameter to `compute_tfr()` to allow complex coefficients to be returned.
78

89
<br>
910

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@
136136
"components",
137137
"frequencies",
138138
"frequency_bands",
139+
"tapers",
139140
"x",
140141
"n_vertices",
141142
"n_faces",

src/pybispectra/utils/utils.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pooch
99
import numpy as np
1010
import scipy as sp
11-
from mne import time_frequency
11+
from mne import time_frequency, __version__ as mne_version
1212

1313
from pybispectra import __version__ as version
1414
from pybispectra.utils._defaults import _precision
@@ -164,15 +164,16 @@ def compute_tfr(
164164
zero_mean_wavelets: bool | None = None,
165165
use_fft: bool = True,
166166
multitaper_time_bandwidth: int | float = 4.0,
167+
output: str = "power",
167168
n_jobs: int = 1,
168169
verbose: bool = True,
169-
) -> tuple[np.ndarray, np.ndarray]:
170-
"""Compute the amplitude time-frequency representation (TFR) of data.
170+
) -> tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, np.ndarray]:
171+
"""Compute the time-frequency representation (TFR) of data.
171172
172173
Parameters
173174
----------
174175
data : ~numpy.ndarray, shape of [epochs, channels, times]
175-
Real-valued data to compute the amplitude TFR of.
176+
Real-valued data to compute the TFR of.
176177
177178
sampling_freq : int | float
178179
Sampling frequency (in Hz) of ``data``.
@@ -203,6 +204,14 @@ def compute_tfr(
203204
bandwidth (in Hz). Only used if ``tfr_mode = "multitaper"``. See
204205
:func:`mne.time_frequency.tfr_array_multitaper` for more information.
205206
207+
output : ``"power"`` | ``"complex"`` (default ``"power"``)
208+
Type of TFR output to return.
209+
210+
.. note::
211+
If ``output = "complex"`` and ``tfr_mode = "multitaper"``, returning weights
212+
for each taper requires MNE version 1.10 or higher.
213+
.. versionadded:: 1.3
214+
206215
n_jobs : int (default ``1``)
207216
Number of jobs to run in parallel. If ``-1``, all available CPUs are used.
208217
@@ -211,19 +220,24 @@ def compute_tfr(
211220
212221
Returns
213222
-------
214-
tfr : ~numpy.ndarray, shape of [epochs, channels, frequencies, times]
215-
Amplitude/power of the TFR of ``data``.
223+
tfr : ~numpy.ndarray, shape of [epochs, channels (, tapers), frequencies, times]
224+
TFR power or complex coefficients of ``data``. The ``tapers`` dimension is only
225+
present if ``output = "complex"`` and ``tfr_mode = "multitaper"``.
216226
217227
freqs : ~numpy.ndarray of float, shape of [frequencies]
218228
Frequencies (in Hz) in ``tfr``.
219229
230+
weights : ~numpy.ndarray, shape of [tapers, frequencies]
231+
Taper weights. Only returned if ``output = "complex"`` and ``tfr_mode =
232+
"multitaper"``.
233+
220234
Notes
221235
-----
222236
This function acts as a wrapper around the MNE TFR computation functions
223237
:func:`mne.time_frequency.tfr_array_morlet` and
224-
:func:`mne.time_frequency.tfr_array_multitaper` with ``output = "power"``.
238+
:func:`mne.time_frequency.tfr_array_multitaper`.
225239
"""
226-
tfr_func, n_jobs = _compute_tfr_input_checks(
240+
tfr_func, return_weights, n_jobs = _compute_tfr_input_checks(
227241
data,
228242
sampling_freq,
229243
freqs,
@@ -232,6 +246,7 @@ def compute_tfr(
232246
zero_mean_wavelets,
233247
use_fft,
234248
multitaper_time_bandwidth,
249+
output,
235250
n_jobs,
236251
verbose,
237252
)
@@ -242,23 +257,35 @@ def compute_tfr(
242257
"freqs": freqs,
243258
"n_cycles": n_cycles,
244259
"use_fft": use_fft,
245-
"output": "power",
260+
"output": output,
246261
"n_jobs": n_jobs,
247262
"verbose": verbose,
248263
}
249264
if zero_mean_wavelets is not None:
250265
tfr_func_kwargs["zero_mean"] = zero_mean_wavelets
251266
if tfr_mode == "multitaper":
252267
tfr_func_kwargs["time_bandwidth"] = multitaper_time_bandwidth
268+
if output == "complex":
269+
tfr_func_kwargs["return_weights"] = True
253270

254271
if verbose:
255272
print("Computing TFR of the data...")
256273

257-
tfr = np.array(tfr_func(**tfr_func_kwargs), dtype=_precision.real)
274+
out = tfr_func(**tfr_func_kwargs)
275+
if return_weights:
276+
tfr = out[0]
277+
weights = out[1]
278+
else:
279+
tfr = out
280+
tfr = np.asarray(
281+
tfr, dtype=_precision.real if output == "power" else _precision.complex
282+
)
258283

259284
if verbose:
260285
print(" [TFR computation finished]\n")
261286

287+
if return_weights:
288+
return tfr, freqs.astype(_precision.real), weights.astype(_precision.real)
262289
return tfr, freqs.astype(_precision.real)
263290

264291

@@ -271,16 +298,20 @@ def _compute_tfr_input_checks(
271298
zero_mean_wavelets: bool | None,
272299
use_fft: bool,
273300
multitaper_time_bandwidth: int | float,
301+
output: str,
274302
n_jobs: int,
275303
verbose: bool,
276-
) -> tuple[Callable, int]:
304+
) -> tuple[Callable, bool, int]:
277305
"""Check inputs for computing TFR.
278306
279307
Returns
280308
-------
281309
tfr_func
282310
Function to use to compute TFR.
283311
312+
return_weights : bool
313+
Whether or not taper weights will be returned.
314+
284315
n_jobs
285316
"""
286317
if not isinstance(data, np.ndarray):
@@ -334,6 +365,21 @@ def _compute_tfr_input_checks(
334365
if not isinstance(multitaper_time_bandwidth, _number_like):
335366
raise TypeError("`multitaper_time_bandwidth` must be an int or a float.")
336367

368+
outputs = ["power", "complex"]
369+
if not isinstance(output, str):
370+
raise TypeError("`output` must be a str.")
371+
if output not in outputs:
372+
raise ValueError(f"`output` must be one of {outputs}.")
373+
374+
return_weights = False
375+
if tfr_mode == "multitaper" and output == "complex": # pragma: no cover
376+
if Version(mne_version) < Version("1.10"):
377+
raise RuntimeError(
378+
"If `tfr_mode='multitaper'` and `output='complex'`, MNE >= 1.10 is "
379+
f"required to return taper weights (is {mne_version})."
380+
)
381+
return_weights = True
382+
337383
if not isinstance(n_jobs, _int_like):
338384
raise TypeError("`n_jobs` must be an integer.")
339385
if n_jobs < 1 and n_jobs != -1:
@@ -346,7 +392,7 @@ def _compute_tfr_input_checks(
346392
if verbose and not np.isreal(data).all():
347393
warn("`data` is expected to be real-valued.", UserWarning)
348394

349-
return tfr_func, n_jobs
395+
return tfr_func, return_weights, n_jobs
350396

351397

352398
def compute_rank(data: np.ndarray, sv_tol: int | float = 1e-5) -> int:

tests/test_util_funcs.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""Tests for toolbox utility functions."""
22

33
import os
4+
from packaging.version import Version
45

56
import numpy as np
67
import pytest
78
import scipy as sp
8-
from mne import Info
9+
from mne import Info, __version__ as mne_version
910

1011
from pybispectra.utils import (
1112
compute_fft,
@@ -106,9 +107,21 @@ def test_compute_fft(window: str) -> None:
106107

107108

108109
@pytest.mark.parametrize("tfr_mode", ["morlet", "multitaper"])
110+
@pytest.mark.parametrize("output", ["power", "complex"])
109111
@pytest.mark.parametrize("zero_mean_wavelets", [True, False, None])
110-
def test_compute_tfr(tfr_mode: str, zero_mean_wavelets: bool | None) -> None:
112+
def test_compute_tfr(
113+
tfr_mode: str, output: str, zero_mean_wavelets: bool | None
114+
) -> None:
111115
"""Test `compute_tfr`."""
116+
if (
117+
tfr_mode == "multitaper"
118+
and output == "complex"
119+
and Version(mne_version) < Version("1.10")
120+
):
121+
pytest.skip(
122+
"`output='complex'` with `tfr_mode='multitaper'` requires MNE >= 1.10."
123+
)
124+
112125
n_epochs = 5
113126
n_chans = 3
114127
n_times = 100
@@ -117,25 +130,41 @@ def test_compute_tfr(tfr_mode: str, zero_mean_wavelets: bool | None) -> None:
117130
freqs_in = np.arange(20, 50)
118131

119132
# check it runs with correct inputs
120-
tfr, freqs_out = compute_tfr(
133+
out = compute_tfr(
121134
data=data,
122135
sampling_freq=sampling_freq,
123136
freqs=freqs_in,
124137
tfr_mode=tfr_mode,
125138
zero_mean_wavelets=zero_mean_wavelets,
139+
output=output,
126140
n_jobs=1,
127141
)
142+
if tfr_mode == "multitaper" and output == "complex":
143+
tfr, freqs_out, weights = out
144+
else:
145+
tfr, freqs_out = out
146+
weights = None
147+
128148
assert isinstance(tfr, np.ndarray), "`tfr` should be a NumPy array."
129-
assert tfr.ndim == 4, "`tfr` should have 4 dimensions."
130-
assert tfr.shape[:3] == (n_epochs, n_chans, len(freqs_in)), (
131-
"The first 3 dimensions of `tfr` should have shape [epochs x channels x "
132-
"frequencies]."
133-
)
149+
if tfr_mode == "multitaper" and output == "complex":
150+
assert tfr.shape == (n_epochs, n_chans, 3, len(freqs_in), n_times), (
151+
"`tfr` should have shape [epochs x channels x tapers x frequencies x times]."
152+
) # unsure how to predict number of tapers, so use known value for this test
153+
else:
154+
assert tfr.shape == (n_epochs, n_chans, len(freqs_in), n_times), (
155+
"`tfr` should have shape [epochs x channels x frequencies x times]."
156+
)
134157
assert isinstance(freqs_out, np.ndarray), "`freqs_out` should be a NumPy array."
135158
assert np.all(freqs_in == freqs_out), (
136159
"`freqs_out` and `freqs_in` should be identical"
137160
)
138161

162+
if weights is not None:
163+
assert isinstance(weights, np.ndarray), "`weights` should be a NumPy array."
164+
assert weights.shape == tfr.shape[2:4], (
165+
"`weights` should have shape [tapers x frequencies]."
166+
)
167+
139168
# check it catches incorrect inputs
140169
with pytest.raises(TypeError, match="`data` must be a NumPy array."):
141170
compute_tfr(
@@ -297,6 +326,25 @@ def test_compute_tfr(tfr_mode: str, zero_mean_wavelets: bool | None) -> None:
297326
multitaper_time_bandwidth=[3],
298327
)
299328

329+
with pytest.raises(TypeError, match="`output` must be a str."):
330+
compute_tfr(
331+
data=data,
332+
sampling_freq=sampling_freq,
333+
freqs=freqs_in,
334+
tfr_mode=tfr_mode,
335+
zero_mean_wavelets=zero_mean_wavelets,
336+
output=[output],
337+
)
338+
with pytest.raises(ValueError, match="`output` must be one of"):
339+
compute_tfr(
340+
data=data,
341+
sampling_freq=sampling_freq,
342+
freqs=freqs_in,
343+
tfr_mode=tfr_mode,
344+
zero_mean_wavelets=zero_mean_wavelets,
345+
output="not_an_output",
346+
)
347+
300348
with pytest.raises(TypeError, match="`n_jobs` must be an integer."):
301349
compute_tfr(
302350
data=data,

0 commit comments

Comments
 (0)