# Source code for neuroconv.tools.roiextractors.imagingextractordatachunkiterator

"""General purpose iterator for all ImagingExtractor data."""

import numpy as np
from roiextractors import ImagingExtractor
from tqdm import tqdm

from neuroconv.tools.hdmf import GenericDataChunkIterator
from neuroconv.tools.iterative_write import (
    get_image_series_buffer_shape,
    get_image_series_chunk_shape,
)


class ImagingExtractorDataChunkIterator(GenericDataChunkIterator):
    """DataChunkIterator for ImagingExtractor objects primarily used when writing imaging data to an NWB file."""

    def __init__(
        self,
        imaging_extractor: ImagingExtractor,
        buffer_gb: float | None = None,
        buffer_shape: tuple | None = None,
        chunk_mb: float | None = None,
        chunk_shape: tuple | None = None,
        display_progress: bool = False,
        progress_bar_class: type[tqdm] | None = None,
        progress_bar_options: dict | None = None,
    ):
        """
        Initialize an Iterable object which returns DataChunks with data and their selections on each iteration.

        Parameters
        ----------
        imaging_extractor : ImagingExtractor
            The ImagingExtractor object which handles the data access.
        buffer_gb : float, optional
            The upper bound on size in gigabytes (GB) of each selection from the iteration.
            The buffer_shape will be set implicitly by this argument.
            Cannot be set if `buffer_shape` is also specified.
            The default is 1GB.
        buffer_shape : tuple, optional
            Manual specification of buffer shape to return on each iteration.
            Must be a multiple of chunk_shape along each axis.
            Cannot be set if `buffer_gb` is also specified.
            The default is None.
        chunk_mb : float, optional
            The upper bound on size in megabytes (MB) of the internal chunk for the HDF5 dataset.
            The chunk_shape will be set implicitly by this argument.
            Cannot be set if `chunk_shape` is also specified.
            The default is 10MB, as recommended by the HDF5 group.
            For more details, search the hdf5 documentation for "Improving IO Performance Compressed Datasets".
        chunk_shape : tuple, optional
            Manual specification of the internal chunk shape for the HDF5 dataset.
            Cannot be set if `chunk_mb` is also specified.
            The default is None.
        display_progress : bool, default=False
            Display a progress bar with iteration rate and estimated completion time.
        progress_bar_class : type, optional
            The progress bar class to use.
            Defaults to tqdm.tqdm if the TQDM package is installed.
        progress_bar_options : dict, optional
            Dictionary of keyword arguments to be passed directly to tqdm.
            See https://github.com/tqdm/tqdm#parameters for options.
        """
        # Must be assigned first: the shape/dtype helpers called below read it.
        self.imaging_extractor = imaging_extractor

        assert not (buffer_gb and buffer_shape), "Only one of 'buffer_gb' or 'buffer_shape' can be specified!"
        assert not (chunk_mb and chunk_shape), "Only one of 'chunk_mb' or 'chunk_shape' can be specified!"
        if chunk_mb and buffer_gb:
            # Compare in bytes so a single chunk can never exceed the buffer it lives in.
            assert chunk_mb * 1e6 <= buffer_gb * 1e9, "chunk_mb must be less than or equal to buffer_gb!"

        # Chunk shape first: the default buffer shape is derived as a multiple of it.
        if chunk_mb is None and chunk_shape is None:
            chunk_mb = 10.0
        if chunk_shape is None:
            chunk_shape = self._get_default_chunk_shape(chunk_mb=chunk_mb)

        if buffer_gb is None and buffer_shape is None:
            buffer_gb = 1.0
        if buffer_shape is None:
            buffer_shape = self._get_scaled_buffer_shape(buffer_gb=buffer_gb, chunk_shape=chunk_shape)

        super().__init__(
            buffer_shape=buffer_shape,
            chunk_shape=chunk_shape,
            display_progress=display_progress,
            progress_bar_class=progress_bar_class,
            progress_bar_options=progress_bar_options,
        )

    def _get_sample_shape(self) -> tuple:
        """Translate the roiextractors frame shape to the NWB convention by swapping height and width.

        Returns
        -------
        tuple
            ``(width, height)`` for planar data, or ``(width, height, num_planes)`` for volumetric data.
        """
        roi_extractors_frame_shape = self.imaging_extractor.get_frame_shape()
        height, width = roi_extractors_frame_shape[0], roi_extractors_frame_shape[1]
        nwb_frame_shape = (width, height)
        if self.imaging_extractor.is_volumetric:
            num_planes = self.imaging_extractor.get_num_planes()
            return nwb_frame_shape + (num_planes,)
        return nwb_frame_shape

    def _get_default_chunk_shape(self, chunk_mb: float) -> tuple:
        """Select the chunk_shape less than the threshold of chunk_mb while keeping the original image size."""
        assert chunk_mb > 0, f"chunk_mb ({chunk_mb}) must be greater than zero!"

        chunk_shape = get_image_series_chunk_shape(
            num_samples=self.imaging_extractor.get_num_samples(),
            sample_shape=self._get_sample_shape(),
            dtype=self._get_dtype(),
            chunk_mb=chunk_mb,
        )
        return chunk_shape

    def _get_scaled_buffer_shape(self, buffer_gb: float, chunk_shape: tuple) -> tuple:
        """Select the buffer_shape less than the threshold of buffer_gb that is also a multiple of the chunk_shape."""
        assert buffer_gb > 0, f"buffer_gb ({buffer_gb}) must be greater than zero!"
        assert all(
            np.array(chunk_shape) > 0
        ), f"Some dimensions of chunk_shape ({chunk_shape}) are less than or equal to zero!"

        buffer_shape = get_image_series_buffer_shape(
            chunk_shape=chunk_shape,
            sample_shape=self._get_sample_shape(),
            series_shape=self._get_maxshape(),
            dtype=self._get_dtype(),
            buffer_gb=buffer_gb,
        )
        return buffer_shape

    def _get_dtype(self) -> np.dtype:
        """Return the dtype reported by the wrapped imaging extractor."""
        return self.imaging_extractor.get_dtype()

    def _get_maxshape(self) -> tuple:
        """Return the full series shape in NWB layout: ``(num_samples,) + sample_shape``."""
        num_frames = self.imaging_extractor.get_num_samples()
        return (num_frames,) + self._get_sample_shape()

    def _get_data(self, selection: tuple[slice, ...]) -> np.ndarray:
        """Read the frames addressed by ``selection`` and return them in NWB (time, width, height[, planes]) layout.

        Only the time range of ``selection`` is passed to the extractor; the spatial part of the
        selection is applied after the frames have been transposed.
        """
        data = self.imaging_extractor.get_series(
            start_sample=selection[0].start,
            end_sample=selection[0].stop,
        )
        # Swap height and width (axes 1 and 2) to match _get_sample_shape; a trailing plane axis stays in place.
        transpose_axes = (0, 2, 1) if len(data.shape) == 3 else (0, 2, 1, 3)
        # `data` already spans exactly the requested time range, so keep all of it along axis 0.
        # Slicing by the buffer length here would truncate any selection longer than one buffer.
        sliced_selection = (slice(None),) + selection[1:]
        return data.transpose(transpose_axes)[sliced_selection]