Installation Guide
==================
This guide will help you install JAX DataLoader and its dependencies.

Requirements
------------

* Python 3.7 or higher
* JAX and JAXlib
* NumPy
* Optional: PyTorch (for some data format support)
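
To check these requirements before installing, you can run a small standard-library script such as the sketch below. It assumes the import names ``jax``, ``jaxlib``, ``numpy`` and ``torch`` for the packages listed above; the helpers ``python_ok`` and ``missing_modules`` are illustrative and are not part of JAX DataLoader.

```python
import importlib.util
import sys

# Minimum Python version from the requirements list.
MIN_PYTHON = (3, 7)

# Import names of the required packages (JAX, JAXlib, NumPy).
REQUIRED_MODULES = ["jax", "jaxlib", "numpy"]

# PyTorch is optional; its import name is "torch".
OPTIONAL_MODULES = ["torch"]


def python_ok(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True if the given interpreter version meets the minimum."""
    return tuple(version_info[:2]) >= minimum


def missing_modules(names):
    """Return the module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    status = "OK" if python_ok() else "too old"
    print(f"Python {sys.version.split()[0]}: {status}")
    print("Missing required packages:", missing_modules(REQUIRED_MODULES) or "none")
    print("Missing optional packages:", missing_modules(OPTIONAL_MODULES) or "none")
```

Checking with ``find_spec`` only locates each package; it does not import it, so the check stays fast even for large libraries.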

Basic Installation
------------------

The simplest way to install JAX DataLoader is using pip:

.. code-block:: bash

   pip install jax-dataloaders

This will install the latest stable version from PyPI along with its core dependencies.
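
To confirm which version pip installed, you can query the package metadata from Python. This is a minimal sketch using the standard ``importlib.metadata`` module (Python 3.8 or later; on Python 3.7 the ``importlib_metadata`` backport offers the same API). The helper ``installed_version`` is hypothetical. Running ``pip show jax-dataloaders`` reports the same version from the shell.

```python
from importlib import metadata

# PyPI distribution name from the pip command above. Note that the import
# name used in Python code is different ("jax_dataloader").
DIST_NAME = "jax-dataloaders"


def installed_version(dist_name=DIST_NAME):
    """Return the installed version string, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print(f"{DIST_NAME}: {installed_version() or 'not installed'}")
```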

Development Installation
------------------------

If you want to contribute to the project or need the latest features, you can install from source:

.. code-block:: bash

   # Clone the repository
   git clone https://github.com/carrycooldude/JAX-Dataloader.git
   cd JAX-Dataloader

   # Install in editable mode
   pip install -e .

   # Install development dependencies
   pip install -e ".[dev]"

Installing with Optional Dependencies
-------------------------------------

JAX DataLoader has several optional dependencies that you can install based on your needs.

For CSV support:

.. code-block:: bash

   pip install "jax-dataloaders[csv]"

For JSON support:

.. code-block:: bash

   pip install "jax-dataloaders[json]"

For image support:

.. code-block:: bash

   pip install "jax-dataloaders[image]"

For all optional dependencies:

.. code-block:: bash

   pip install "jax-dataloaders[all]"
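
The extras available can differ between releases. As a sketch, the hypothetical helper below lists the extra names that the installed distribution declares in its metadata (its ``Provides-Extra`` fields), using only the standard library (Python 3.8 or later).

```python
from importlib import metadata


def declared_extras(dist_name="jax-dataloaders"):
    """List the optional-dependency groups ("extras") a distribution declares.

    Reads the Provides-Extra fields from the installed package's metadata.
    Returns an empty list if the package is missing or declares no extras.
    """
    try:
        meta = metadata.metadata(dist_name)
    except metadata.PackageNotFoundError:
        return []
    return sorted(set(meta.get_all("Provides-Extra") or []))


if __name__ == "__main__":
    print("Available extras:", ", ".join(declared_extras()) or "none found")
```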
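
GPU Support
-----------

To use JAX DataLoader with a GPU, you'll need to install the JAX build that matches your CUDA version. Follow the official JAX installation guide for detailed instructions.

For example, for CUDA 11.8 with older JAX releases:

.. code-block:: bash

   pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

The name of the CUDA extra changes between JAX releases (recent releases use extras such as ``jax[cuda12]``), so check the official guide for the command that matches your JAX and CUDA versions.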
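
After installing a GPU-enabled JAX build, you can check which backend JAX actually picked up. The sketch below calls ``jax.default_backend()`` and ``jax.devices()``; the wrapper function ``describe_backend`` is hypothetical. If it reports ``cpu`` on a machine with a GPU, JAX did not find a usable CUDA installation.

```python
def describe_backend():
    """Report the backend and devices JAX will use, if JAX is installed."""
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    # default_backend() returns the platform name, e.g. "cpu", "gpu" or "tpu".
    backend = jax.default_backend()
    devices = jax.devices()
    return f"backend={backend}, devices={devices}"


if __name__ == "__main__":
    print(describe_backend())
```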

Verifying Installation
----------------------

You can verify your installation by running a simple test:

.. code-block:: python

   from jax_dataloader import DataLoader, DataLoaderConfig
   import jax.numpy as jnp

   # Create test data: the integers 0..99
   data = jnp.arange(100)
   config = DataLoaderConfig(batch_size=10)
   dataloader = DataLoader(data=data, config=config)

   # Fetch the first batch, print its shape, and stop
   for batch in dataloader:
       print(f"Batch shape: {batch.shape}")
       break

If you see the batch shape printed without any errors, your installation is successful!

Troubleshooting
---------------

Common issues and their solutions:

``ModuleNotFoundError: No module named 'jax'``
   JAX is not installed in the Python environment you are using. Reinstall JAX following the official JAX installation guide.

CUDA errors
   Ensure your CUDA version matches the JAX version you installed. Check the JAX installation guide for compatibility.

Memory issues
   If you encounter memory errors, try reducing the batch size or enabling memory optimization features.
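
A common cause of the ``No module named 'jax'`` error is installing packages with one Python interpreter and running code with another. The sketch below uses only the standard library to report the interpreter in use and where, if anywhere, a module would be imported from; the helper name ``diagnose`` is hypothetical.

```python
import importlib.util
import sys


def diagnose(module_name="jax"):
    """Collect facts that help debug "No module named ..." errors."""
    spec = importlib.util.find_spec(module_name)
    return {
        # The interpreter running this code; packages must be installed into
        # this same environment to be importable.
        "python": sys.executable,
        "found": spec is not None,
        # File the module would be loaded from, if it was found.
        "location": spec.origin if spec is not None else None,
    }


if __name__ == "__main__":
    for key, value in diagnose().items():
        print(f"{key}: {value}")
```

If ``found`` is ``False``, running ``python -m pip install ...`` with the interpreter shown under ``python`` installs the package into that same environment.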

For more help, see the Usage Guide or open an issue on the GitHub repository.