import os
import numpy as np
import jax.numpy as jnp
import jax
from jax import device_put
from concurrent.futures import ThreadPoolExecutor
import pandas as pd
from PIL import Image
import json
import psutil
import time
from typing import Optional, Tuple, List, Union, Dict, Any
import logging
from dataclasses import dataclass
import gc
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class DataLoaderConfig:
"""Configuration for JAXDataLoader."""
batch_size: int = 32
shuffle: bool = True
num_workers: int = 4
pinned_memory: bool = True
prefetch: bool = True
multi_gpu: bool = False
auto_batch_size: bool = True
cache_size: int = 1000
augmentation: bool = False
progress_tracking: bool = True
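# Hedged usage sketch (not part of the original API): one way a config might be
# customized for image-style data; every field name comes from the dataclass above.
def example_dataloader_config() -> DataLoaderConfig:
    """Return a sample config with a larger batch and augmentation enabled."""
    return DataLoaderConfig(batch_size=64, num_workers=2, augmentation=True, multi_gpu=False)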
class MemoryManager:
"""Manages memory allocation and cleanup."""
def __init__(self):
self.allocated_memory = 0
self.max_memory = psutil.virtual_memory().available * 0.8 # 80% of available memory
self._allocation_stack = [] # Stack to track allocations
def allocate(self, size: int) -> bool:
"""Attempt to allocate memory."""
if self.allocated_memory + size > self.max_memory:
return False
self.allocated_memory += size
self._allocation_stack.append(size)
return True
def deallocate(self, size: int):
"""Deallocate memory."""
if self._allocation_stack:
# Pop the last allocation if it matches
if self._allocation_stack[-1] == size:
self._allocation_stack.pop()
self.allocated_memory = max(0, self.allocated_memory - size)
def get_available_memory(self) -> int:
"""Get available memory."""
return psutil.virtual_memory().available
def reset(self):
"""Reset memory tracking."""
self.allocated_memory = 0
self._allocation_stack.clear()
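# Hedged usage sketch: the allocate/deallocate protocol tracked by MemoryManager;
# sizes are in bytes and the 1 KiB figure here is purely illustrative.
def example_memory_manager_roundtrip(block_bytes: int = 1024) -> bool:
    """Allocate a block, release it, and report whether it fit under the cap."""
    manager = MemoryManager()
    fitted = manager.allocate(block_bytes)
    if fitted:
        manager.deallocate(block_bytes)
    return fitted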
class JAXDataLoader:
"""
A high-performance JAX DataLoader with advanced features:
- Pinned memory with automatic management
- Multi-GPU support with distributed batch loading
- Memory monitoring and auto-tuning
- Data augmentation
- Progress tracking
- Caching
"""
def __init__(self, data: np.ndarray, labels: np.ndarray, config: Optional[DataLoaderConfig] = None):
        if data is None or labels is None:
            raise ValueError("Data and labels cannot be None")
self.data = np.asarray(data, dtype=np.float32) if not isinstance(data, np.ndarray) else data
self.labels = np.asarray(labels, dtype=np.int32) if not isinstance(labels, np.ndarray) else labels
if len(self.data) != len(self.labels):
raise ValueError("Data and labels must have the same length")
self.config = config or DataLoaderConfig()
self.memory_manager = MemoryManager()
self.cache = {}
self.progress = {'batches_processed': 0, 'start_time': None}
# Initialize indices
self.indices = np.arange(len(self.data))
self.current_index = 0
# Setup devices
self.num_devices = jax.device_count() if self.config.multi_gpu else 1
self.device_batch_size = self.config.batch_size // self.num_devices
# Auto-tune batch size if enabled
if self.config.auto_batch_size:
self._auto_tune_batch_size()
# Setup workers
self._setup_workers()
if self.config.shuffle:
np.random.shuffle(self.indices)
def _setup_workers(self):
"""Setup worker pool with optimal number of workers."""
        cpu_count = os.cpu_count() or 1  # os.cpu_count() can return None
self.num_workers = min(self.config.num_workers, cpu_count)
self.worker_pool = ThreadPoolExecutor(max_workers=self.num_workers)
def _auto_tune_batch_size(self):
"""Automatically tune batch size based on available memory."""
# Only auto-tune if enabled and batch size is too large
if not self.config.auto_batch_size:
return
sample_size = self.data[0].nbytes
available_memory = self.memory_manager.get_available_memory()
# Calculate memory needed for one batch
# We need to account for:
# 1. Original data in memory
# 2. Pinned memory copy
# 3. GPU memory copy
# 4. Potential augmentation copy
memory_factor = 4 if self.config.augmentation else 3
# Calculate max batch size based on available memory
max_batch_size = int(available_memory / (sample_size * memory_factor))
# Ensure batch size is at least 1 and divisible by number of devices
max_batch_size = max(1, max_batch_size - (max_batch_size % max(1, self.num_devices)))
        # Guard rail: halve very large configured batch sizes (> 1000) before
        # comparing against the memory-derived cap
if self.config.batch_size > 1000:
max_batch_size = min(max_batch_size, self.config.batch_size // 2)
# Only reduce batch size if necessary
if max_batch_size < self.config.batch_size:
logger.warning(f"Reducing batch size from {self.config.batch_size} to {max_batch_size} due to memory constraints")
self.config.batch_size = max_batch_size
self.device_batch_size = self.config.batch_size // max(1, self.num_devices)
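        # Worked example (hypothetical numbers): with 8 GB available, float32 samples
        # of shape (64, 64, 3) take ~48 KiB each; with memory_factor = 3 the cap is
        # roughly 8e9 / (49_152 * 3) ≈ 54,000 samples, so a configured batch size of
        # 32 is left untouched.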
def __iter__(self):
self.current_index = 0
self.progress['batches_processed'] = 0
self.progress['start_time'] = time.time()
if self.config.shuffle:
np.random.shuffle(self.indices)
return self
def __next__(self):
if self.current_index >= len(self.data):
self._cleanup()
raise StopIteration
try:
batch_indices = self.indices[self.current_index:self.current_index + self.config.batch_size]
self.current_index += self.config.batch_size
# Load batch with caching
batch_data, batch_labels = self._load_batch(batch_indices)
# Apply augmentation if enabled
if self.config.augmentation:
batch_data = self._apply_augmentation(batch_data)
# Move to GPU with memory management
batch_data, batch_labels = self._transfer_to_gpu(batch_data, batch_labels)
# Update progress
self._update_progress()
return batch_data, batch_labels
except Exception as e:
logger.error(f"Error loading batch: {str(e)}")
self._cleanup()
raise
def _load_batch(self, indices: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Load batch with caching support."""
batch_data = []
batch_labels = []
for idx in indices:
if idx in self.cache:
data, label = self.cache[idx]
else:
data, label = self._fetch_sample(idx)
if len(self.cache) < self.config.cache_size:
self.cache[idx] = (data, label)
batch_data.append(data)
batch_labels.append(label)
return np.array(batch_data), np.array(batch_labels)
def _transfer_to_gpu(self, batch_data: np.ndarray, batch_labels: np.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Transfer batch to GPU with memory management."""
# Calculate memory requirements
required_memory = batch_data.nbytes + batch_labels.nbytes
# Account for pinned memory
if self.config.pinned_memory:
required_memory *= 2 # Double for pinned copy
batch_data = np.asarray(batch_data, dtype=np.float32)
batch_labels = np.asarray(batch_labels, dtype=np.int32)
# Try to allocate memory
if not self.memory_manager.allocate(required_memory):
raise RuntimeError("Insufficient memory for batch transfer")
try:
# Move to device
batch_data, batch_labels = device_put((batch_data, batch_labels))
# Distribute across devices if needed
if self.config.multi_gpu:
batch_data, batch_labels = self._distribute_batches(batch_data, batch_labels)
return batch_data, batch_labels
except Exception as e:
# Make sure to deallocate on error
self.memory_manager.deallocate(required_memory)
raise e
def _apply_augmentation(self, batch_data: np.ndarray) -> np.ndarray:
"""Apply data augmentation to the batch."""
augmented_data = []
for sample in batch_data:
            # Example augmentations: random horizontal/vertical flips
            # (assumes image-like samples with at least 2 dimensions)
if np.random.random() > 0.5:
sample = np.fliplr(sample)
if np.random.random() > 0.5:
sample = np.flipud(sample)
augmented_data.append(sample)
return np.array(augmented_data)
def _update_progress(self):
"""Update progress tracking."""
self.progress['batches_processed'] += 1
if self.config.progress_tracking:
elapsed_time = time.time() - self.progress['start_time']
batches_per_second = self.progress['batches_processed'] / elapsed_time
logger.info(f"Processed {self.progress['batches_processed']} batches at {batches_per_second:.2f} batches/second")
def _cleanup(self):
"""Cleanup resources."""
self.cache.clear()
self.worker_pool.shutdown(wait=True)
self.memory_manager.reset() # Reset memory tracking
gc.collect()
def get_stats(self) -> Dict[str, Any]:
"""Get dataloader statistics."""
return {
'total_samples': len(self.data),
'batch_size': self.config.batch_size,
'num_workers': self.num_workers,
'num_devices': self.num_devices,
'cache_size': len(self.cache),
'memory_allocated': self.memory_manager.allocated_memory,
'progress': self.progress
}
def _fetch_sample(self, idx):
return self._preprocess(self.data[idx]), self.labels[idx]
def _distribute_batches(self, batch_data, batch_labels):
"""Splits the batch across multiple GPUs using `jax.pmap()`."""
try:
# Ensure batch size is divisible by number of devices
if batch_data.shape[0] % self.num_devices != 0:
# Adjust batch size down to nearest multiple
new_size = (batch_data.shape[0] // self.num_devices) * self.num_devices
batch_data = batch_data[:new_size]
batch_labels = batch_labels[:new_size]
# Reshape for device distribution
batch_data = batch_data.reshape((self.num_devices, -1) + batch_data.shape[1:])
batch_labels = batch_labels.reshape((self.num_devices, -1))
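            # Shape sketch (hypothetical sizes): on 4 devices, a (64, 32, 32, 3) batch
            # becomes (4, 16, 32, 32, 3) and labels of shape (64,) become (4, 16).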
return batch_data, batch_labels
except Exception as e:
logger.error(f"Error distributing batch: {str(e)}")
raise
def _prefetch(self, batch_data, batch_labels):
"""Prefetches data to GPU asynchronously using `jax.jit`."""
return jax.jit(lambda x, y: (x, y))(batch_data, batch_labels)
@staticmethod
def _preprocess(sample):
"""Example preprocessing: Normalize sample values to [0,1]."""
return jnp.array(sample) / 255.0
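# Hedged end-to-end sketch: driving JAXDataLoader directly from in-memory NumPy
# arrays; the shapes and sample count are arbitrary illustrative values.
def example_inmemory_loader() -> JAXDataLoader:
    """Build a loader over random data shaped like small RGB images."""
    rng = np.random.default_rng(0)
    data = rng.random((256, 32, 32, 3), dtype=np.float32)
    labels = rng.integers(0, 10, size=256).astype(np.int32)
    config = DataLoaderConfig(batch_size=32, shuffle=True, auto_batch_size=False)
    return JAXDataLoader(data, labels, config)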
def load_custom_data(file_path, file_type='csv', batch_size=32, target_column=None,
pinned_memory=True, multi_gpu=False):
"""Loads data from CSV, JSON, or Image folders."""
if file_type == 'csv':
data, labels = load_csv_data(file_path, target_column)
elif file_type == 'json':
data, labels = load_json_data(file_path)
elif file_type == 'image':
data, labels = load_image_data(file_path)
else:
raise ValueError(f"Unsupported file type: {file_type}")
config = DataLoaderConfig(
batch_size=batch_size,
pinned_memory=pinned_memory,
multi_gpu=multi_gpu
)
return JAXDataLoader(data, labels, config)
def load_csv_data(file_path, target_column=None):
"""Loads structured data from a CSV file."""
df = pd.read_csv(file_path)
print("CSV Columns:", df.columns.tolist())
if target_column not in df.columns:
raise KeyError(f"'{target_column}' column not found in CSV. Available columns: {df.columns.tolist()}")
data = df.drop(target_column, axis=1).values
labels = df[target_column].values
return data, labels
def load_json_data(file_path):
"""Loads structured data from a JSON file."""
with open(file_path, 'r') as f:
data = json.load(f)
features = np.array([item['features'] for item in data])
labels = np.array([item['label'] for item in data])
return features, labels
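# Hedged format note: load_json_data expects a top-level list of records like
# [{"features": [0.1, 0.2, 0.3], "label": 1}, ...]; the helper below writes a tiny
# illustrative file in that layout (the file name is arbitrary).
def example_write_json_dataset(path: str = 'example_dataset.json') -> str:
    """Write a two-record JSON dataset matching the schema read above."""
    records = [
        {"features": [0.1, 0.2, 0.3], "label": 0},
        {"features": [0.4, 0.5, 0.6], "label": 1},
    ]
    with open(path, 'w') as f:
        json.dump(records, f)
    return path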
def load_image_data(image_folder_path, img_size=(64, 64)):
"""Loads image data from a folder and resizes it."""
image_files = [f for f in os.listdir(image_folder_path) if f.endswith(('.jpg', '.png'))]
data = []
labels = []
for img_file in image_files:
        img = Image.open(os.path.join(image_folder_path, img_file)).convert('RGB')
img = img.resize(img_size)
data.append(np.array(img))
label = int(img_file.split('_')[0]) # Assuming labels are part of file name (e.g., "0_image1.jpg")
labels.append(label)
return np.array(data), np.array(labels)
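# Hedged layout note: load_image_data assumes a flat folder of "<label>_<name>.jpg"
# or ".png" files, e.g. images_folder/0_cat.jpg and images_folder/1_dog.png, with the
# integer class label encoded before the first underscore.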
# Example usage: Loading custom dataset and iterating over it
if __name__ == "__main__":
dataset_path = 'dataset.csv' # Replace with actual dataset path
batch_size = 64
# Example 1: Loading CSV
dataloader = load_custom_data(dataset_path, file_type='csv', batch_size=batch_size,
target_column='median_house_value', multi_gpu=True)
# Example 2: Loading JSON
# dataloader = load_custom_data('dataset.json', file_type='json', batch_size=batch_size, multi_gpu=True)
# Example 3: Loading Images
# dataloader = load_custom_data('images_folder/', file_type='image', batch_size=batch_size, multi_gpu=True)
for batch_x, batch_y in dataloader:
print("Batch Shape:", batch_x.shape, batch_y.shape)