Welcome to JAX DataLoader’s documentation!
JAX DataLoader is a high-performance data loading library for JAX applications, providing efficient data loading, batching, and preprocessing capabilities.
Features
Efficient data loading with automatic batching
Memory management and optimization
Multi-GPU support
Progress tracking
Automatic batch size tuning
Support for multiple data formats (CSV, JSON, images)
Data augmentation capabilities
Caching system for improved performance
Installation
You can install JAX DataLoader using pip:
pip install jax-dataloaders
For development installation:
git clone https://github.com/carrycooldude/JAX-Dataloader.git
cd JAX-Dataloader
pip install -e .
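After either installation method, a quick check can confirm that the package is importable. The sketch below uses only the Python standard library and assumes the import name is jax_dataloader, as in the Quick Start example; note that this differs from the pip distribution name, jax-dataloaders.

```python
import importlib.util

# "jax_dataloader" is the import name used in this documentation's examples;
# the pip distribution name is "jax-dataloaders".
spec = importlib.util.find_spec("jax_dataloader")
installed = spec is not None

print("jax_dataloader is importable" if installed else "jax_dataloader was not found")
```

If the package is reported as not found, the pip command may have run against a different Python environment than the one executing this check.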
Quick Start
Here’s a simple example of how to use JAX DataLoader:
from jax_dataloader import DataLoader, DataLoaderConfig
import jax.numpy as jnp
# Create some sample data
data = jnp.arange(1000)
labels = jnp.arange(1000)
# Configure the dataloader
config = DataLoaderConfig(
    batch_size=32,
    shuffle=True,
    drop_last=True,
)
# Create the dataloader
dataloader = DataLoader(
    data=data,
    labels=labels,
    config=config,
)
# Iterate over batches
for batch_data, batch_labels in dataloader:
    # Process your batch
    print(f"Batch shape: {batch_data.shape}")
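To make the batch_size, shuffle, and drop_last settings concrete, here is a minimal plain-JAX sketch of what such options conventionally mean. It is an illustration only, not the library's implementation: the helper iterate_batches is hypothetical, and it assumes the usual semantics, where shuffling permutes sample indices and drop_last discards a final partial batch.

```python
import jax
import jax.numpy as jnp


def iterate_batches(data, labels, batch_size, shuffle, drop_last, key):
    """Yield (data, labels) batches. Hypothetical helper, not part of the library."""
    n = data.shape[0]
    # Shuffle by permuting indices so data and labels stay aligned.
    indices = jax.random.permutation(key, n) if shuffle else jnp.arange(n)
    # With drop_last, stop at the last full batch and discard the remainder.
    end = (n // batch_size) * batch_size if drop_last else n
    for start in range(0, end, batch_size):
        batch_idx = indices[start:start + batch_size]
        yield data[batch_idx], labels[batch_idx]


data = jnp.arange(1000)
labels = jnp.arange(1000)
key = jax.random.PRNGKey(0)

batches = list(iterate_batches(data, labels, 32, True, True, key))
print(len(batches))         # 1000 // 32 = 31 full batches
print(batches[0][0].shape)  # (32,)
```

Under these assumptions, 1,000 samples with a batch size of 32 yield 31 full batches, and the remaining 8 samples are dropped. With drop_last set to False, a 32nd batch of 8 samples would be yielded instead.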
For more detailed examples and usage instructions, see the Usage Guide.