Examples

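This section walks through examples of using JAX DataLoader. It starts with in-memory arrays and CSV, JSON, and image files, then covers advanced setups: multi-GPU training, data augmentation, memory management, progress tracking, and error handling.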

Basic Examples

Simple Data Loading

from jax_dataloader import DataLoader, DataLoaderConfig
import jax.numpy as jnp

# Toy dataset: 1,000 scalar samples whose label equals their value.
samples = jnp.arange(1000)
targets = jnp.arange(1000)

# Shuffled batches of 32. drop_last=True asks the loader to discard the
# trailing partial batch (1000 is not a multiple of 32).
loader_config = DataLoaderConfig(batch_size=32, shuffle=True, drop_last=True)

dataloader = DataLoader(data=samples, labels=targets, config=loader_config)

# Each iteration yields one (data, labels) pair; only the data is inspected here.
for inputs, _ in dataloader:
    print(f"Batch shape: {inputs.shape}")

Loading from Files

CSV Data

from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.data import CSVLoader

# Read "data.csv", using columns "feature1" and "feature2" as inputs and the
# "label" column (target_column) as the target.
csv_source = CSVLoader(
    "data.csv",
    feature_columns=["feature1", "feature2"],
    target_column="label",
)

# Shuffled batches of 32 rows drawn from the CSV source.
dataloader = DataLoader(
    loader=csv_source,
    config=DataLoaderConfig(batch_size=32, shuffle=True),
)

for features, labels in dataloader:
    print(
        f"Features shape: {features.shape}",
        f"Labels shape: {labels.shape}",
        sep="\n",
    )

JSON Data

from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.data import JSONLoader

# Read "data.json", taking inputs from its "features" key and targets from
# its "labels" key.
json_source = JSONLoader("data.json", data_key="features", label_key="labels")

# Shuffled batches of 32 records drawn from the JSON source.
dataloader = DataLoader(
    loader=json_source,
    config=DataLoaderConfig(batch_size=32, shuffle=True),
)

for batch_data, batch_labels in dataloader:
    print(
        f"Data shape: {batch_data.shape}",
        f"Labels shape: {batch_labels.shape}",
        sep="\n",
    )

Image Data

from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.data import ImageLoader

# Load images from "image_directory", resized to 224x224, with normalize=True.
image_source = ImageLoader(
    "image_directory",
    image_size=(224, 224),
    normalize=True,
)

# Shuffled batches of 32 images, loaded with 4 worker processes/threads.
image_config = DataLoaderConfig(batch_size=32, shuffle=True, num_workers=4)

dataloader = DataLoader(loader=image_source, config=image_config)

for images, labels in dataloader:
    print(
        f"Images shape: {images.shape}",
        f"Labels shape: {labels.shape}",
        sep="\n",
    )

Advanced Examples

Multi-GPU Training

import jax
from jax_dataloader import DataLoader, DataLoaderConfig
import jax.numpy as jnp

# 10,000 toy samples whose labels mirror the data.
data, labels = jnp.arange(10000), jnp.arange(10000)

# Spread batches across every device JAX reports; device_map="auto" lets the
# loader choose the placement.
config = DataLoaderConfig(
    batch_size=32,
    num_devices=len(jax.devices()),
    device_map="auto",
    pin_memory=True,
)

dataloader = DataLoader(data=data, labels=labels, config=config)

# Training loop
for batch_data, batch_labels in dataloader:
    # The loader has already placed batch_data and batch_labels on their
    # target devices, so they can go straight into a multi-device step.
    # Your training code here
    pass

Data Augmentation

from jax_dataloader import DataLoader, DataLoaderConfig
import jax.numpy as jnp
import jax.random as random

# Define augmentation function
def augment_fn(batch, key):
    """Add random noise and apply a random quarter-turn rotation to a batch.

    Args:
        batch: Array whose last two axes are the spatial (height, width)
            axes, e.g. shape ``(batch, H, W)``. Leading axes such as the
            batch axis are never moved. For odd quarter turns, H and W are
            swapped, so the output shape matches the input only when
            ``H == W`` (true for the 10x10 example data below).
        key: ``jax.random`` PRNG key. It is split so the noise and the
            rotation come from independent random streams.

    Returns:
        The noisy, rotated batch (float, because noise is added).

    Note:
        ``jnp.rot90`` needs a concrete rotation count. The sampled count is
        therefore converted with ``int(...)``, so this function must run
        eagerly (not inside ``jax.jit``).
    """
    # Never reuse a PRNG key for two draws: split it into independent keys.
    noise_key, rot_key = random.split(key)

    # Add Gaussian noise with standard deviation 0.1.
    noise = random.normal(noise_key, batch.shape) * 0.1
    augmented = batch + noise

    # Random rotation by a multiple of 90 degrees (rot90 supports nothing
    # finer). Rotate only the spatial axes: the default axes=(0, 1) would
    # swap the batch axis with the height axis.
    quarter_turns = int(random.randint(rot_key, (), 0, 4))
    augmented = jnp.rot90(augmented, k=quarter_turns, axes=(-2, -1))

    return augmented

# 100 samples of shape (10, 10) with one integer label per sample.
images = jnp.reshape(jnp.arange(1000), (100, 10, 10))
labels = jnp.arange(100)

# Pass the augmentation function and its base PRNG key to the loader through
# the config.
dataloader = DataLoader(
    data=images,
    labels=labels,
    config=DataLoaderConfig(
        batch_size=32,
        transform=augment_fn,
        transform_key=random.PRNGKey(0),
    ),
)

for augmented, _ in dataloader:
    print(f"Augmented batch shape: {augmented.shape}")

Memory Management

from jax_dataloader import DataLoader, DataLoaderConfig
import jax.numpy as jnp

# Create large dataset: 1,000,000 scalar samples with matching labels.
data = jnp.arange(1000000)
labels = jnp.arange(1000000)

# Configure for memory efficiency
# NOTE(review): if the batch size stays at 32, this produces 31,250 batches,
# and the loop below prints two lines for each. In real use, consider
# printing only every N batches.
config = DataLoaderConfig(
    batch_size=32,
    memory_fraction=0.8,  # presumably the share of memory the loader may use — confirm
    auto_batch_size=True,  # presumably lets the loader adjust batch_size — confirm
    cache_size=1000,
    num_workers=4
)

# Create the dataloader
dataloader = DataLoader(
    data=data,
    labels=labels,
    config=config
)

# Enable memory optimization. It is called after construction and before
# iteration begins.
dataloader.optimize_memory()

# Iterate over memory-efficient batches, reporting the loader's memory
# usage (via dataloader.memory_manager) after fetching each batch.
for batch_data, batch_labels in dataloader:
    print(f"Batch shape: {batch_data.shape}")
    print(f"Memory usage: {dataloader.memory_manager.get_memory_usage()}")

Progress Tracking

from jax_dataloader import DataLoader, DataLoaderConfig
import jax.numpy as jnp
import time

# Create sample data
data = jnp.arange(1000)
labels = jnp.arange(1000)

# Turn on the loader's built-in progress display. progress_interval=0.1 is
# presumably the refresh interval in seconds — confirm against the API reference.
config = DataLoaderConfig(
    batch_size=32,
    show_progress=True,
    progress_interval=0.1
)

# Create the dataloader
dataloader = DataLoader(
    data=data,
    labels=labels,
    config=config
)

# Time the loop with time.perf_counter(). Unlike time.time() it is monotonic,
# so system clock adjustments cannot distort the elapsed-time measurement.
start_time = time.perf_counter()
for batch_data, batch_labels in dataloader:
    # Simulate 0.1 s of processing per batch
    time.sleep(0.1)

    # The loader draws the progress bar itself; this line only marks each step.
    print("Processing batch...")

end_time = time.perf_counter()
print(f"Total time: {end_time - start_time:.2f} seconds")

Error Handling

from jax_dataloader import DataLoader, DataLoaderConfig
from jax_dataloader.exceptions import DataLoaderError
import jax.numpy as jnp

# 1,000 toy samples with matching labels.
data = jnp.arange(1000)
labels = jnp.arange(1000)

dataloader = DataLoader(
    data=data,
    labels=labels,
    config=DataLoaderConfig(batch_size=32, error_handling=True),
)

# Two layers of handling:
#   * inner: an error while processing one batch is reported, and the loop
#     moves on to the next batch;
#   * outer: an error raised while fetching batches from the loader ends the
#     loop. DataLoaderError is reported separately from anything unexpected.
try:
    for batch_data, batch_labels in dataloader:
        try:
            # Your processing code here
            pass
        except Exception as e:
            print(f"Error processing batch: {e}")
except DataLoaderError as e:
    print(f"DataLoader error: {e}")
except Exception as e:
    print(f"Unexpected error: {e}")