88import pooch
99import numpy as np
1010import scipy as sp
11- from mne import time_frequency
11+ from mne import time_frequency , __version__ as mne_version
1212
1313from pybispectra import __version__ as version
1414from pybispectra .utils ._defaults import _precision
@@ -164,15 +164,16 @@ def compute_tfr(
164164 zero_mean_wavelets : bool | None = None ,
165165 use_fft : bool = True ,
166166 multitaper_time_bandwidth : int | float = 4.0 ,
167+ output : str = "power" ,
167168 n_jobs : int = 1 ,
168169 verbose : bool = True ,
169- ) -> tuple [np .ndarray , np .ndarray ]:
170- """Compute the amplitude time-frequency representation (TFR) of data.
170+ ) -> tuple [np .ndarray , np .ndarray ] | tuple [ np . ndarray , np . ndarray , np . ndarray ] :
171+ """Compute the time-frequency representation (TFR) of data.
171172
172173 Parameters
173174 ----------
174175 data : ~numpy.ndarray, shape of [epochs, channels, times]
175- Real-valued data to compute the amplitude TFR of.
176+ Real-valued data to compute the TFR of.
176177
177178 sampling_freq : int | float
178179 Sampling frequency (in Hz) of ``data``.
@@ -203,6 +204,14 @@ def compute_tfr(
203204 bandwidth (in Hz). Only used if ``tfr_mode = "multitaper"``. See
204205 :func:`mne.time_frequency.tfr_array_multitaper` for more information.
205206
207+ output : ``"power"`` | ``"complex"`` (default ``"power"``)
208+ Type of TFR output to return.
209+
210+ .. note::
211+ If ``output = "complex"`` and ``tfr_mode = "multitaper"``, returning weights
212+ for each taper requires MNE version 1.10 or higher.
213+ .. versionadded:: 1.3
214+
206215 n_jobs : int (default ``1``)
207216 Number of jobs to run in parallel. If ``-1``, all available CPUs are used.
208217
@@ -211,19 +220,24 @@ def compute_tfr(
211220
212221 Returns
213222 -------
214- tfr : ~numpy.ndarray, shape of [epochs, channels, frequencies, times]
215- Amplitude/power of the TFR of ``data``.
223+ tfr : ~numpy.ndarray, shape of [epochs, channels (, tapers), frequencies, times]
224+ TFR power or complex coefficients of ``data``. The ``tapers`` dimension is only
225+ present if ``output = "complex"`` and ``tfr_mode = "multitaper"``.
216226
217227 freqs : ~numpy.ndarray of float, shape of [frequencies]
218228 Frequencies (in Hz) in ``tfr``.
219229
230+ weights : ~numpy.ndarray, shape of [tapers, frequencies]
231+ Taper weights. Only returned if ``output = "complex"`` and ``tfr_mode =
232+ "multitaper"``.
233+
220234 Notes
221235 -----
222236 This function acts as a wrapper around the MNE TFR computation functions
223237 :func:`mne.time_frequency.tfr_array_morlet` and
224- :func:`mne.time_frequency.tfr_array_multitaper` with ``output = "power"`` .
238+ :func:`mne.time_frequency.tfr_array_multitaper`.
225239 """
226- tfr_func , n_jobs = _compute_tfr_input_checks (
240+ tfr_func , return_weights , n_jobs = _compute_tfr_input_checks (
227241 data ,
228242 sampling_freq ,
229243 freqs ,
@@ -232,6 +246,7 @@ def compute_tfr(
232246 zero_mean_wavelets ,
233247 use_fft ,
234248 multitaper_time_bandwidth ,
249+ output ,
235250 n_jobs ,
236251 verbose ,
237252 )
@@ -242,23 +257,35 @@ def compute_tfr(
242257 "freqs" : freqs ,
243258 "n_cycles" : n_cycles ,
244259 "use_fft" : use_fft ,
245- "output" : "power" ,
260+ "output" : output ,
246261 "n_jobs" : n_jobs ,
247262 "verbose" : verbose ,
248263 }
249264 if zero_mean_wavelets is not None :
250265 tfr_func_kwargs ["zero_mean" ] = zero_mean_wavelets
251266 if tfr_mode == "multitaper" :
252267 tfr_func_kwargs ["time_bandwidth" ] = multitaper_time_bandwidth
268+ if output == "complex" :
269+ tfr_func_kwargs ["return_weights" ] = True
253270
254271 if verbose :
255272 print ("Computing TFR of the data..." )
256273
257- tfr = np .array (tfr_func (** tfr_func_kwargs ), dtype = _precision .real )
274+ out = tfr_func (** tfr_func_kwargs )
275+ if return_weights :
276+ tfr = out [0 ]
277+ weights = out [1 ]
278+ else :
279+ tfr = out
280+ tfr = np .asarray (
281+ tfr , dtype = _precision .real if output == "power" else _precision .complex
282+ )
258283
259284 if verbose :
260285 print (" [TFR computation finished]\n " )
261286
287+ if return_weights :
288+ return tfr , freqs .astype (_precision .real ), weights .astype (_precision .real )
262289 return tfr , freqs .astype (_precision .real )
263290
264291
@@ -271,16 +298,20 @@ def _compute_tfr_input_checks(
271298 zero_mean_wavelets : bool | None ,
272299 use_fft : bool ,
273300 multitaper_time_bandwidth : int | float ,
301+ output : str ,
274302 n_jobs : int ,
275303 verbose : bool ,
276- ) -> tuple [Callable , int ]:
304+ ) -> tuple [Callable , bool , int ]:
277305 """Check inputs for computing TFR.
278306
279307 Returns
280308 -------
281309 tfr_func
282310 Function to use to compute TFR.
283311
312+ return_weights : bool
313+ Whether or not taper weights will be returned.
314+
284315 n_jobs
285316 """
286317 if not isinstance (data , np .ndarray ):
@@ -334,6 +365,21 @@ def _compute_tfr_input_checks(
334365 if not isinstance (multitaper_time_bandwidth , _number_like ):
335366 raise TypeError ("`multitaper_time_bandwidth` must be an int or a float." )
336367
368+ outputs = ["power" , "complex" ]
369+ if not isinstance (output , str ):
370+ raise TypeError ("`output` must be a str." )
371+ if output not in outputs :
372+ raise ValueError (f"`output` must be one of { outputs } ." )
373+
374+ return_weights = False
375+ if tfr_mode == "multitaper" and output == "complex" : # pragma: no cover
376+ if Version (mne_version ) < Version ("1.10" ):
377+ raise RuntimeError (
378+ "If `tfr_mode='multitaper'` and `output='complex'`, MNE >= 1.10 is "
379+ f"required to return taper weights (is { mne_version } )."
380+ )
381+ return_weights = True
382+
337383 if not isinstance (n_jobs , _int_like ):
338384 raise TypeError ("`n_jobs` must be an integer." )
339385 if n_jobs < 1 and n_jobs != - 1 :
@@ -346,7 +392,7 @@ def _compute_tfr_input_checks(
346392 if verbose and not np .isreal (data ).all ():
347393 warn ("`data` is expected to be real-valued." , UserWarning )
348394
349- return tfr_func , n_jobs
395+ return tfr_func , return_weights , n_jobs
350396
351397
352398def compute_rank (data : np .ndarray , sv_tol : int | float = 1e-5 ) -> int :
0 commit comments