Welcome to jax-dataloader’s documentation!

A data loader that batches data for JAX models.

jax_dataloader.jax_dataloader.JAXDataLoader.data

The input data for the model (features).

Type: np.ndarray

jax_dataloader.jax_dataloader.JAXDataLoader.labels

The labels corresponding to the input data.

Type: np.ndarray

jax_dataloader.jax_dataloader.JAXDataLoader.batch_size

The size of each batch, by default 32.

Type: int, optional

jax_dataloader.jax_dataloader.JAXDataLoader.shuffle

Whether to shuffle the data before each epoch, by default True.

Type: bool, optional

jax_dataloader.jax_dataloader.JAXDataLoader.num_workers

The number of workers for loading data in parallel, by default 4.

Type: int, optional

__iter__():

Resets the loader and returns the loader itself as the iterator.

__next__():

Returns the next batch of data and labels.
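
The iteration protocol described by __iter__() and __next__() can be illustrated with a minimal sketch. This is not the library's implementation: SketchLoader, its seed argument, and the choice to keep a final partial batch are all assumptions made for illustration. Only the attribute names data, labels, batch_size and shuffle come from the documentation above.

```python
import numpy as np


class SketchLoader:
    """Hypothetical stand-in that mirrors the documented attributes."""

    def __init__(self, data, labels, batch_size=32, shuffle=True, seed=0):
        self.data = np.asarray(data)
        self.labels = np.asarray(labels)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self._rng = np.random.default_rng(seed)  # seed is an assumption
        self._order = np.arange(len(self.data))
        self._cursor = 0

    def __iter__(self):
        # Reset the position and, if requested, draw a new sample order.
        self._cursor = 0
        if self.shuffle:
            self._order = self._rng.permutation(len(self.data))
        else:
            self._order = np.arange(len(self.data))
        return self

    def __next__(self):
        # Signal the end of the epoch once every sample has been served.
        if self._cursor >= len(self.data):
            raise StopIteration
        idx = self._order[self._cursor:self._cursor + self.batch_size]
        self._cursor += self.batch_size
        # Index data and labels with the same indices so they stay aligned.
        return self.data[idx], self.labels[idx]


features = np.arange(20, dtype=np.float32).reshape(10, 2)
targets = np.arange(10)
loader = SketchLoader(features, targets, batch_size=4, shuffle=True)
batch_shapes = [(x.shape, y.shape) for x, y in loader]
```

Because __iter__() resets the cursor, the same loader object can be used in a new for loop for each epoch. Whether the real loader keeps or drops a final partial batch is not stated in this reference.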

_parallel_process(batch_data, batch_labels):

Processes the given batch of data and labels in parallel.
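
One way such per-batch parallelism could look is sketched below with a thread pool from the standard library. The function names parallel_process and scale_sample are hypothetical, and the use of threads is an assumption; this reference does not say how the library distributes the work.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def scale_sample(sample):
    # Hypothetical per-sample work; stands in for the library's preprocessing.
    return sample / 255.0


def parallel_process(batch_data, batch_labels, num_workers=4):
    # Executor.map returns results in input order, so rows stay aligned
    # with their labels even though the work runs concurrently.
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        processed = list(pool.map(scale_sample, batch_data))
    return np.stack(processed), np.asarray(batch_labels)


batch = np.arange(24, dtype=np.float64).reshape(8, 3)
labels = np.arange(8)
out_data, out_labels = parallel_process(batch, labels)
```

The labels pass through unchanged in this sketch; only the samples are transformed.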

_preprocess(sample):

Preprocesses a single sample by normalizing it.
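
The reference does not say which normalization is applied. As an illustration only, the sketch below assumes per-sample standardization to zero mean and unit variance; the function name preprocess and the eps parameter are hypothetical.

```python
import numpy as np


def preprocess(sample, eps=1e-8):
    # Assumed normalization: shift to zero mean, scale to unit variance.
    # eps guards against division by zero for constant samples.
    sample = np.asarray(sample, dtype=np.float32)
    return (sample - sample.mean()) / (sample.std() + eps)


normalized = preprocess([1.0, 2.0, 3.0, 4.0])
```

Other schemes, such as min-max scaling or dataset-level statistics, would fit the same description equally well.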