Usage Guide
This guide will help you get started with JAX DataLoader and show you how to use its various features effectively.
Basic Usage
The most basic usage of JAX DataLoader involves creating a DataLoader instance with your data and configuration:
```python
from jax_dataloader import DataLoader, DataLoaderConfig
import jax.numpy as jnp

# Create sample data
data = jnp.arange(1000)
labels = jnp.arange(1000)

# Configure the dataloader
config = DataLoaderConfig(
    batch_size=32,
    shuffle=True,
    drop_last=True
)

# Create the dataloader
dataloader = DataLoader(
    data=data,
    labels=labels,
    config=config
)

# Iterate over batches
for batch_data, batch_labels in dataloader:
    # Process your batch
    print(f"Batch shape: {batch_data.shape}")
```
Configuration Options
The DataLoaderConfig class provides various options to customize your data loading:
```python
config = DataLoaderConfig(
    batch_size=32,            # Size of each batch
    shuffle=True,             # Whether to shuffle the data
    drop_last=False,          # Whether to drop the last incomplete batch
    num_workers=4,            # Number of worker processes
    pin_memory=True,          # Whether to pin memory for faster GPU transfer
    prefetch_factor=2,        # Number of batches to prefetch
    persistent_workers=True   # Whether to keep workers alive between epochs
)
```
Memory Management
JAX DataLoader provides several features for efficient memory management:
```python
config = DataLoaderConfig(
    batch_size=32,
    memory_fraction=0.8,   # Maximum fraction of available memory to use
    auto_batch_size=True,  # Automatically adjust batch size based on memory
    cache_size=1000        # Number of batches to cache
)

dataloader = DataLoader(
    data=data,
    config=config
)

# Enable memory optimization
dataloader.optimize_memory()
```
Multi-GPU Support
To use multiple GPUs, you can configure the DataLoader to distribute data across devices:
```python
import jax
from jax_dataloader import DataLoader, DataLoaderConfig

# Get available devices
devices = jax.devices()

config = DataLoaderConfig(
    batch_size=32,
    num_devices=len(devices),  # Number of devices to use
    device_map="auto"          # Automatic device mapping
)

dataloader = DataLoader(
    data=data,
    config=config
)

# Data will be automatically distributed across devices
for batch in dataloader:
    # Each batch is a tuple of (data, device_id)
    batch_data, device_id = batch
```
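If you need explicit control over placement, each shard can be moved with standard JAX device placement. The following is a minimal sketch, assuming the `(batch_data, device_id)` tuples and the `devices` list from the example above; it uses only the standard `jax.device_put` call:

```python
for batch_data, device_id in dataloader:
    # jax.device_put is standard JAX; it places the array on the given device
    shard = jax.device_put(batch_data, devices[device_id])
```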
Progress Tracking
You can track the progress of data loading using the built-in progress bar:
```python
from jax_dataloader import DataLoader, DataLoaderConfig
import jax.numpy as jnp

data = jnp.arange(1000)

config = DataLoaderConfig(
    batch_size=32,
    show_progress=True,     # Enable progress bar
    progress_interval=0.1   # Update interval in seconds
)

dataloader = DataLoader(
    data=data,
    config=config
)

for batch in dataloader:
    # Progress bar updates automatically
    pass
```
Data Augmentation
JAX DataLoader supports data augmentation through the transform system:
```python
from jax_dataloader import DataLoader, DataLoaderConfig
import jax.numpy as jnp
import jax.random as random

def augment_fn(batch, key):
    # Example augmentation: add random Gaussian noise
    noise = random.normal(key, batch.shape) * 0.1
    return batch + noise

config = DataLoaderConfig(
    batch_size=32,
    transform=augment_fn,            # Apply augmentation function
    transform_key=random.PRNGKey(0)  # Random key for augmentation
)

dataloader = DataLoader(
    data=data,
    config=config
)
```
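To sanity-check an augmentation function before wiring it into the loader, you can call it directly on a dummy batch. This standalone snippet uses only the `augment_fn` defined above and standard JAX calls; the batch shape is arbitrary:

```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(42)
dummy_batch = jnp.ones((32, 8))

augmented = augment_fn(dummy_batch, key)
print(augmented.shape)  # (32, 8): same shape, with Gaussian noise added
```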
Loading Different Data Formats
JAX DataLoader supports various data formats:
CSV Files:
```python
from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.data import CSVLoader

loader = CSVLoader("data.csv")
config = DataLoaderConfig(batch_size=32)
dataloader = DataLoader(loader=loader, config=config)
```
JSON Files:
```python
from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.data import JSONLoader

loader = JSONLoader("data.json")
config = DataLoaderConfig(batch_size=32)
dataloader = DataLoader(loader=loader, config=config)
```
Image Files:
```python
from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.data import ImageLoader

loader = ImageLoader("image_directory")
config = DataLoaderConfig(
    batch_size=32,
    image_size=(224, 224)  # Resize images to 224x224
)
dataloader = DataLoader(loader=loader, config=config)
```
Best Practices
- **Batch Size Selection**
  - Start with a small batch size and increase it based on available memory.
  - Use `auto_batch_size=True` for automatic optimization.
  - Consider gradient accumulation for large models (see the gradient-accumulation sketch after this list).
- **Memory Management**
  - Enable `pin_memory=True` when using a GPU.
  - Use `memory_fraction` to limit memory usage.
  - Enable caching for frequently accessed data.
- **Performance Optimization**
  - Use `num_workers > 0` for parallel data loading.
  - Enable `persistent_workers=True` for better performance.
  - Use `prefetch_factor` to overlap data loading with computation.
- **Error Handling**
  - Wrap data loading in try-except blocks (see the error-handling sketch after this list).
  - Use the built-in error handling features.
  - Monitor memory usage and adjust the configuration accordingly.
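Gradient accumulation lets you train with an effective batch size larger than what fits in memory by averaging gradients over several micro-batches before each optimizer step. The sketch below uses only standard JAX; `loss_fn` and `params` are hypothetical placeholders for your own model code and are not part of JAX DataLoader:

```python
import jax
import jax.numpy as jnp

def accumulated_grads(loss_fn, params, micro_batches):
    # Average gradients over several micro-batches before one optimizer step.
    grad_fn = jax.grad(loss_fn)  # differentiates loss_fn w.r.t. params
    total = None
    for batch in micro_batches:
        g = grad_fn(params, batch)
        # Sum gradients leaf-by-leaf across micro-batches
        total = g if total is None else jax.tree_util.tree_map(jnp.add, total, g)
    # Divide by the number of micro-batches to get the mean gradient
    return jax.tree_util.tree_map(lambda t: t / len(micro_batches), total)
```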
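For error handling, a plain try-except around the iteration loop is enough to catch and log loader failures. This is a minimal sketch assuming the `dataloader` from the examples above; `process` is a hypothetical per-batch step, and since the exception types raised by JAX DataLoader are not documented here, it catches `Exception` broadly:

```python
try:
    for batch in dataloader:
        process(batch)  # hypothetical per-batch training/processing step
except Exception as exc:
    # Log the failure, then re-raise (or fall back, e.g. to a smaller batch_size)
    print(f"Data loading failed: {exc}")
    raise
```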
For more advanced usage and examples, check out the Examples guide.