API Reference
This section provides detailed documentation for the JAX DataLoader API.
Core Classes
DataLoader
DataLoaderConfig
- class jax_dataloader.DataLoaderConfig(batch_size=32, shuffle=True, num_workers=4, pinned_memory=True, prefetch=True, multi_gpu=False, auto_batch_size=True, cache_size=1000, augmentation=False, progress_tracking=True)
Bases: object
Configuration for JAXDataLoader.
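Every option in the signature is a keyword argument with a default, and the generated __init__ plus the typed attribute list suggest a dataclass-style record. The following is a minimal sketch, assuming (the source does not confirm it) that DataLoaderConfig behaves like a standard Python dataclass. The class here is a hypothetical stand-in that mirrors the documented defaults with the stdlib `dataclasses` module; it is not the library's own code.

```python
from dataclasses import dataclass, fields

# Hypothetical stand-in that mirrors the documented signature of
# jax_dataloader.DataLoaderConfig. Assumption: the real class is a dataclass.
@dataclass
class DataLoaderConfig:
    batch_size: int = 32
    shuffle: bool = True
    num_workers: int = 4
    pinned_memory: bool = True
    prefetch: bool = True
    multi_gpu: bool = False
    auto_batch_size: bool = True
    cache_size: int = 1000
    augmentation: bool = False
    progress_tracking: bool = True

# Every argument has a default, so a bare call is valid.
default = DataLoaderConfig()
# Keyword overrides change only the named fields; the rest keep their defaults.
custom = DataLoaderConfig(batch_size=64, shuffle=False)

print(default.batch_size, custom.batch_size, custom.num_workers)  # → 32 64 4
```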
Examples
Basic configuration:
config = DataLoaderConfig(
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pinned_memory=True,
)
Advanced configuration with memory management:
config = DataLoaderConfig(
    batch_size=32,
    auto_batch_size=True,
    cache_size=1000,
    num_workers=4,
    pinned_memory=True,
    prefetch=True,
)
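Keyword names must match the documented signature exactly: names common in other data loaders, such as `pin_memory` or `drop_last`, do not appear in it. The sketch below, assuming DataLoaderConfig is a dataclass (not confirmed by the source), uses a hypothetical stand-in class and a hypothetical helper `unknown_options` to list unsupported keywords before construction, then derives a variant configuration with `dataclasses.replace` without mutating the original.

```python
from dataclasses import dataclass, fields, replace

# Hypothetical stand-in mirroring the documented signature; not the library's code.
@dataclass
class DataLoaderConfig:
    batch_size: int = 32
    shuffle: bool = True
    num_workers: int = 4
    pinned_memory: bool = True
    prefetch: bool = True
    multi_gpu: bool = False
    auto_batch_size: bool = True
    cache_size: int = 1000
    augmentation: bool = False
    progress_tracking: bool = True

def unknown_options(**kwargs):
    """Return, sorted, the keyword names DataLoaderConfig does not accept (hypothetical helper)."""
    accepted = {f.name for f in fields(DataLoaderConfig)}
    return sorted(set(kwargs) - accepted)

# "pin_memory" and "drop_last" are not part of the documented signature.
print(unknown_options(batch_size=32, pin_memory=True, drop_last=True))
# → ['drop_last', 'pin_memory']

# Derive a larger-batch variant; `base` itself is left unchanged.
base = DataLoaderConfig(batch_size=32, num_workers=4)
large = replace(base, batch_size=128, cache_size=5000)
```

With a real dataclass, passing an unsupported keyword directly to the constructor raises `TypeError`, so a check like this mainly helps when options come from an external source such as a config file.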
Methods
- __init__(batch_size=32, shuffle=True, num_workers=4, pinned_memory=True, prefetch=True, multi_gpu=False, auto_batch_size=True, cache_size=1000, augmentation=False, progress_tracking=True)
Attributes
- batch_size: int = 32
- shuffle: bool = True
- num_workers: int = 4
- pinned_memory: bool = True
- prefetch: bool = True
- multi_gpu: bool = False
- auto_batch_size: bool = True
- cache_size: int = 1000
- augmentation: bool = False
- progress_tracking: bool = True
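The `batch_size` and `shuffle` attributes determine how examples are grouped and ordered. The sketch below illustrates those semantics under stated assumptions: `iter_index_batches` is a hypothetical helper, not part of jax_dataloader, and the trimmed DataLoaderConfig keeps only the two fields it uses. Because the documented signature has no `drop_last` option, this sketch keeps the final partial batch; the real loader's behavior may differ.

```python
import random
from dataclasses import dataclass

# Hypothetical, trimmed stand-in: only the two fields this sketch reads.
@dataclass
class DataLoaderConfig:
    batch_size: int = 32
    shuffle: bool = True

def iter_index_batches(num_examples, config, seed=0):
    """Yield lists of example indices, one list per batch (illustrative only)."""
    indices = list(range(num_examples))
    if config.shuffle:
        # Seeded RNG so the shuffled order is reproducible across runs.
        random.Random(seed).shuffle(indices)
    for start in range(0, num_examples, config.batch_size):
        # The final batch may be smaller than batch_size.
        yield indices[start:start + config.batch_size]

batches = list(iter_index_batches(10, DataLoaderConfig(batch_size=4, shuffle=False)))
print(batches)  # → [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```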