Welcome to jax-dataloader’s documentation!
A data loader that loads and batches data for JAX models.
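Batching splits the samples into consecutive groups of `batch_size`, so the number of batches per epoch follows from simple arithmetic. This is a minimal sketch that assumes the loader keeps the final, smaller batch rather than dropping it (this page does not say which). The helper names `batches_per_epoch` and `last_batch_size` are hypothetical and are not part of `jax_dataloader`.

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int = 32) -> int:
    """Number of batches needed to visit every sample once.

    Assumption: a final partial batch is kept, not dropped.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return math.ceil(num_samples / batch_size)


def last_batch_size(num_samples: int, batch_size: int = 32) -> int:
    """Size of the final batch; equals batch_size when it divides evenly."""
    if num_samples == 0:
        return 0
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size


# 1000 samples with the default batch size of 32:
# 31 full batches plus one final batch of 8 samples.
print(batches_per_epoch(1000), last_batch_size(1000))
```

If a loader instead drops the incomplete batch, the count becomes `num_samples // batch_size` and every batch has exactly `batch_size` samples.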
- jax_dataloader.jax_dataloader.JAXDataLoader.data
The input data (features) for the model.
- Type: np.ndarray
- jax_dataloader.jax_dataloader.JAXDataLoader.labels
The labels corresponding to the input data.
- Type: np.ndarray
- jax_dataloader.jax_dataloader.JAXDataLoader.batch_size
The number of samples in each batch, by default 32.
- Type: int, optional
- jax_dataloader.jax_dataloader.JAXDataLoader.shuffle
Whether to shuffle the data before each epoch, by default True.
- Type: bool, optional
- jax_dataloader.jax_dataloader.JAXDataLoader.num_workers
The number of workers used to load data in parallel, by default 4.
- Type: int, optional
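The interaction between `data`, `labels`, `batch_size`, and `shuffle` can be sketched with NumPy alone. This is an illustrative sketch, not the library's code: it assumes shuffling draws a single index permutation per epoch and applies it to both arrays, which is what keeps every label paired with its sample. The function `shuffled_batches` and its `seed` parameter are hypothetical.

```python
import numpy as np


def shuffled_batches(data, labels, batch_size=32, shuffle=True, seed=None):
    """Yield (data, labels) batches for one epoch.

    Hypothetical helper illustrating the documented attributes; not part of
    jax_dataloader.
    """
    if len(data) != len(labels):
        raise ValueError("data and labels must have the same length")
    indices = np.arange(len(data))
    if shuffle:
        rng = np.random.default_rng(seed)
        # One permutation indexes both arrays, so (sample, label) pairs stay aligned.
        indices = rng.permutation(indices)
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        yield data[batch_idx], labels[batch_idx]


data = np.arange(10, dtype=np.float32).reshape(10, 1)
labels = np.arange(10)
for x, y in shuffled_batches(data, labels, batch_size=4, seed=0):
    # Each row's feature value equals its label here, so alignment is easy to check.
    assert np.array_equal(x[:, 0], y)
```

With `seed=None` (the default) each call draws a fresh permutation, so every epoch sees a different order. A fixed seed reproduces the same order every time.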
- __iter__():
Resets the loader and returns itself for iteration.
- __next__():
Returns the next batch of data and labels.
- _parallel_process(batch_data, batch_labels):
Processes a batch of data and its labels in parallel.
- _preprocess(sample):
Preprocesses a single sample by normalizing it.
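The methods listed above follow Python's iterator protocol: `__iter__` resets state and returns the loader, and `__next__` returns one batch at a time until the data is exhausted. The following is a minimal sketch of how these pieces could fit together, not the library's source. `MiniLoader` is a hypothetical class, its `seed` parameter is an addition, using `concurrent.futures.ThreadPoolExecutor` inside `_parallel_process` is an assumption, and so is the per-sample z-score normalization in `_preprocess` (this page only says the sample is normalized).

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


class MiniLoader:
    """Hypothetical, simplified stand-in for JAXDataLoader's iteration logic."""

    def __init__(self, data, labels, batch_size=32, shuffle=True,
                 num_workers=4, seed=None):
        self.data = np.asarray(data)
        self.labels = np.asarray(labels)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self._rng = np.random.default_rng(seed)
        self._indices = np.arange(len(self.data))
        self._pos = 0

    def __iter__(self):
        # Reset the cursor, and reshuffle if requested, at the start of each epoch.
        self._pos = 0
        if self.shuffle:
            self._indices = self._rng.permutation(len(self.data))
        return self

    def __next__(self):
        if self._pos >= len(self.data):
            raise StopIteration
        idx = self._indices[self._pos:self._pos + self.batch_size]
        self._pos += self.batch_size
        return self._parallel_process(self.data[idx], self.labels[idx])

    def _parallel_process(self, batch_data, batch_labels):
        # Assumed: samples are preprocessed concurrently on a thread pool.
        # map() yields results in input order, so labels stay aligned.
        with ThreadPoolExecutor(max_workers=self.num_workers) as pool:
            processed = list(pool.map(self._preprocess, batch_data))
        return np.stack(processed), batch_labels

    def _preprocess(self, sample):
        # Assumed normalization: zero mean, unit variance per sample.
        sample = np.asarray(sample, dtype=np.float32)
        return (sample - sample.mean()) / (sample.std() + 1e-8)


loader = MiniLoader(np.arange(20, dtype=np.float32).reshape(5, 4),
                    np.arange(5), batch_size=2, seed=0)
for epoch in range(2):
    # Calling iter() (implicitly, via the for loop) resets the loader each epoch.
    for batch_x, batch_y in loader:
        print(epoch, batch_x.shape, batch_y)
```

Threads are used here for simplicity. NumPy releases the GIL for many array operations, but if preprocessing is dominated by pure-Python work, a process pool may parallelize better.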