# Source code for jax_dataloader.jax_dataloader

import os
import numpy as np
import jax.numpy as jnp
import jax
from jax import vmap, device_put, pmap
from concurrent.futures import ThreadPoolExecutor
import pandas as pd
from PIL import Image
import json
import psutil
import time
from typing import Optional, Tuple, List, Union, Dict, Any
import logging
from dataclasses import dataclass
import gc

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class DataLoaderConfig:
    """Configuration for JAXDataLoader.

    Raises:
        ValueError: If ``batch_size`` or ``num_workers`` is smaller than 1.
    """

    # Number of samples requested per batch (may be lowered by auto-tuning).
    batch_size: int = 32
    # Shuffle the sample order at construction and at the start of each epoch.
    shuffle: bool = True
    # Upper bound on the size of the loader's thread pool.
    num_workers: int = 4
    # Budget a second host copy of each batch during device transfer.
    pinned_memory: bool = True
    # NOTE(review): not read by the loader code in this file - confirm intent.
    prefetch: bool = True
    # Split every batch across all devices reported by ``jax.device_count()``.
    multi_gpu: bool = False
    # Let the loader shrink ``batch_size`` to fit the available memory.
    auto_batch_size: bool = True
    # Maximum number of preprocessed samples kept in the loader's cache.
    cache_size: int = 1000
    # Apply random flips to every batch.
    augmentation: bool = False
    # Log the throughput after every batch.
    progress_tracking: bool = True

    def __post_init__(self):
        # A non-positive batch size never advances the read cursor, so the
        # loader would yield empty batches forever; reject it up front.
        if self.batch_size < 1:
            raise ValueError(
                f"batch_size must be at least 1, got {self.batch_size}"
            )
        # ThreadPoolExecutor rejects max_workers <= 0 with a ValueError as
        # well; fail here with a message that names the offending field.
        if self.num_workers < 1:
            raise ValueError(
                f"num_workers must be at least 1, got {self.num_workers}"
            )
class MemoryManager:
    """Bookkeeping for an approximate host-memory budget.

    The budget is fixed at construction time to 80% of the memory that
    ``psutil`` reports as available. Nothing is actually reserved; the class
    only keeps a running tally of the sizes callers report.
    """

    def __init__(self):
        self.allocated_memory = 0
        # 80% of the memory available when the manager is created.
        self.max_memory = psutil.virtual_memory().available * 0.8
        # Sizes of the reservations that have not been released yet.
        self._allocation_stack = []

    def allocate(self, size: int) -> bool:
        """Record a reservation of ``size`` bytes.

        Returns:
            ``True`` if the reservation fits into the budget and was recorded,
            ``False`` otherwise (the tally is left unchanged).
        """
        if self.allocated_memory + size > self.max_memory:
            return False
        self.allocated_memory += size
        self._allocation_stack.append(size)
        return True

    def deallocate(self, size: int):
        """Release a reservation of ``size`` bytes.

        The most recent matching entry is removed from the allocation stack,
        wherever it sits, so releases that arrive out of order do not leave
        stale entries behind. The tally never drops below zero.
        """
        stack = self._allocation_stack
        for pos in range(len(stack) - 1, -1, -1):
            if stack[pos] == size:
                del stack[pos]
                break
        self.allocated_memory = max(0, self.allocated_memory - size)

    def get_available_memory(self) -> int:
        """Return the memory currently reported as available by ``psutil``."""
        return psutil.virtual_memory().available

    def reset(self):
        """Forget every recorded reservation."""
        self.allocated_memory = 0
        self._allocation_stack.clear()


class JAXDataLoader:
    """Iterate over an in-memory dataset in batches of JAX arrays.

    Each batch is assembled from (optionally cached) preprocessed samples,
    optionally augmented with random flips, moved to the default device with
    ``jax.device_put`` and, when ``multi_gpu`` is set, reshaped so that the
    leading axis enumerates the devices.

    A thread pool is created from ``config.num_workers``; the code in this
    class does not submit any work to it.
    """

    def __init__(self, data: np.ndarray, labels: np.ndarray,
                 config: Optional["DataLoaderConfig"] = None):
        """Create a loader over ``data`` and ``labels``.

        Args:
            data: Samples; non-array input is converted to ``float32``.
            labels: One label per sample; non-array input is converted to
                ``int32``.
            config: Loader options; defaults to ``DataLoaderConfig()``. The
                ``batch_size`` of this object is modified in place when
                auto-tuning lowers it.

        Raises:
            ValueError: If ``data`` is ``None`` or the lengths differ.
        """
        if data is None:
            raise ValueError("Data cannot be None")

        self.data = (
            data if isinstance(data, np.ndarray)
            else np.asarray(data, dtype=np.float32)
        )
        self.labels = (
            labels if isinstance(labels, np.ndarray)
            else np.asarray(labels, dtype=np.int32)
        )
        if len(self.data) != len(self.labels):
            raise ValueError("Data and labels must have the same length")

        self.config = config or DataLoaderConfig()
        self.memory_manager = MemoryManager()
        self.cache = {}
        self.progress = {'batches_processed': 0, 'start_time': None}

        # Iteration state: a permutation of sample indices and a read cursor.
        self.indices = np.arange(len(self.data))
        self.current_index = 0

        # Device layout.
        self.num_devices = jax.device_count() if self.config.multi_gpu else 1
        self.device_batch_size = self.config.batch_size // self.num_devices

        if self.config.auto_batch_size:
            self._auto_tune_batch_size()

        self._setup_workers()

        if self.config.shuffle:
            np.random.shuffle(self.indices)

    def _setup_workers(self):
        """Create the thread pool, capped at the number of CPUs."""
        # os.cpu_count() may return None; fall back to the configured value.
        cpu_count = os.cpu_count() or self.config.num_workers
        self.num_workers = min(self.config.num_workers, cpu_count)
        self.worker_pool = ThreadPoolExecutor(max_workers=self.num_workers)

    def _auto_tune_batch_size(self):
        """Lower ``config.batch_size`` if a batch would not fit in memory.

        The estimate assumes three copies of every sample (host, pinned,
        device) plus a fourth when augmentation is enabled.
        """
        if not self.config.auto_batch_size:
            return
        # Nothing to measure on an empty dataset.
        if len(self.data) == 0:
            return

        sample_size = self.data[0].nbytes
        # Zero-byte samples cannot exhaust memory (and would divide by zero).
        if sample_size == 0:
            return

        available_memory = self.memory_manager.get_available_memory()
        memory_factor = 4 if self.config.augmentation else 3
        max_batch_size = int(available_memory / (sample_size * memory_factor))

        # Round down to a multiple of the device count, but never below 1.
        devices = max(1, self.num_devices)
        max_batch_size = max(1, max_batch_size - (max_batch_size % devices))

        # Kept from the original implementation ("for testing purposes"):
        # batch sizes above 1000 are at least halved.
        if self.config.batch_size > 1000:
            max_batch_size = min(max_batch_size, self.config.batch_size // 2)

        if max_batch_size < self.config.batch_size:
            logger.warning(
                "Reducing batch size from %s to %s due to memory constraints",
                self.config.batch_size, max_batch_size,
            )
            self.config.batch_size = max_batch_size
            self.device_batch_size = self.config.batch_size // devices

    def __iter__(self):
        """Start a new epoch: rewind, reset progress and reshuffle."""
        self.current_index = 0
        self.progress['batches_processed'] = 0
        self.progress['start_time'] = time.time()
        if self.config.shuffle:
            np.random.shuffle(self.indices)
        return self

    def __next__(self):
        """Return the next ``(batch_data, batch_labels)`` pair.

        Raises:
            StopIteration: When every sample of the epoch has been served.
        """
        if self.current_index >= len(self.data):
            # End of epoch: only the memory tally is reset. The cache and the
            # worker pool survive so that later epochs can reuse them.
            self._finish_epoch()
            raise StopIteration

        try:
            start = self.current_index
            batch_indices = self.indices[start:start + self.config.batch_size]
            self.current_index += self.config.batch_size

            batch_data, batch_labels = self._load_batch(batch_indices)

            if self.config.augmentation:
                batch_data = self._apply_augmentation(batch_data)

            batch_data, batch_labels = self._transfer_to_gpu(batch_data, batch_labels)

            self._update_progress()
            return batch_data, batch_labels
        except Exception as e:
            logger.error("Error loading batch: %s", e)
            self._cleanup()
            raise

    def _load_batch(self, indices: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Collect the samples for ``indices``, filling the cache on misses."""
        batch_data = []
        batch_labels = []
        for idx in indices:
            cached = self.cache.get(idx)
            if cached is None:
                cached = self._fetch_sample(idx)
                # The cache is append-only up to ``cache_size`` entries.
                if len(self.cache) < self.config.cache_size:
                    self.cache[idx] = cached
            sample, label = cached
            batch_data.append(sample)
            batch_labels.append(label)
        return np.array(batch_data), np.array(batch_labels)

    def _transfer_to_gpu(self, batch_data: np.ndarray, batch_labels: np.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Move a batch to the default device under the memory budget.

        Raises:
            RuntimeError: If the batch does not fit into the memory budget.
        """
        required_memory = batch_data.nbytes + batch_labels.nbytes

        if self.config.pinned_memory:
            required_memory *= 2  # budget for the second host copy
            batch_data = np.asarray(batch_data, dtype=np.float32)
            batch_labels = np.asarray(batch_labels, dtype=np.int32)

        if not self.memory_manager.allocate(required_memory):
            raise RuntimeError("Insufficient memory for batch transfer")

        try:
            batch_data, batch_labels = device_put((batch_data, batch_labels))

            if self.config.multi_gpu:
                batch_data, batch_labels = self._distribute_batches(batch_data, batch_labels)

            return batch_data, batch_labels
        finally:
            # The reservation only covers the host-side staging of this
            # batch. Releasing it on success as well as on failure keeps the
            # tally from growing with every batch until allocation fails.
            self.memory_manager.deallocate(required_memory)

    def _apply_augmentation(self, batch_data: np.ndarray) -> np.ndarray:
        """Randomly flip each sample left/right and up/down."""
        augmented_data = []
        for sample in batch_data:
            # Both random draws are always consumed, in this order.
            if np.random.random() > 0.5:
                # np.fliplr requires at least two dimensions; lower-rank
                # samples skip this flip instead of raising.
                if np.ndim(sample) >= 2:
                    sample = np.fliplr(sample)
            if np.random.random() > 0.5:
                sample = np.flipud(sample)
            augmented_data.append(sample)
        return np.array(augmented_data)

    def _update_progress(self):
        """Count the finished batch and optionally log the throughput."""
        self.progress['batches_processed'] += 1
        if not self.config.progress_tracking:
            return

        # next() may be called without iter(); start the clock lazily then.
        if self.progress['start_time'] is None:
            self.progress['start_time'] = time.time()

        elapsed_time = time.time() - self.progress['start_time']
        # The clock may not have advanced yet on coarse timers.
        if elapsed_time > 0:
            batches_per_second = self.progress['batches_processed'] / elapsed_time
        else:
            batches_per_second = float('inf')
        logger.info(
            "Processed %s batches at %.2f batches/second",
            self.progress['batches_processed'], batches_per_second,
        )

    def _finish_epoch(self):
        """Reset per-epoch bookkeeping while keeping the cache and the pool."""
        self.memory_manager.reset()

    def _cleanup(self):
        """Release everything: cache, worker pool and memory tally."""
        self.cache.clear()
        self.worker_pool.shutdown(wait=True)
        self.memory_manager.reset()
        gc.collect()

    def close(self):
        """Release all resources held by the loader."""
        self._cleanup()

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the loader's counters and settings."""
        return {
            'total_samples': len(self.data),
            'batch_size': self.config.batch_size,
            'num_workers': self.num_workers,
            'num_devices': self.num_devices,
            'cache_size': len(self.cache),
            'memory_allocated': self.memory_manager.allocated_memory,
            'progress': self.progress
        }

    def _fetch_sample(self, idx):
        """Return the preprocessed sample and the raw label at ``idx``."""
        return self._preprocess(self.data[idx]), self.labels[idx]

    def _distribute_batches(self, batch_data, batch_labels):
        """Reshape a batch so that its leading axis enumerates the devices.

        Samples that do not divide evenly across the devices are dropped from
        the end of the batch.

        Raises:
            ValueError: If the batch holds fewer samples than there are
                devices.
        """
        try:
            batch_len = batch_data.shape[0]
            usable = (batch_len // self.num_devices) * self.num_devices
            if usable == 0:
                # Reshaping an empty array with an inferred axis fails with
                # an opaque error; report the actual cause instead.
                raise ValueError(
                    f"Batch of {batch_len} samples cannot be split across "
                    f"{self.num_devices} devices"
                )
            if usable != batch_len:
                batch_data = batch_data[:usable]
                batch_labels = batch_labels[:usable]

            batch_data = batch_data.reshape((self.num_devices, -1) + batch_data.shape[1:])
            batch_labels = batch_labels.reshape((self.num_devices, -1))
            return batch_data, batch_labels
        except Exception as e:
            logger.error("Error distributing batch: %s", e)
            raise

    def _prefetch(self, batch_data, batch_labels):
        """Pass the batch through a jitted identity function."""
        return jax.jit(lambda x, y: (x, y))(batch_data, batch_labels)

    @staticmethod
    def _preprocess(sample):
        """Example preprocessing: divide the sample values by 255."""
        return jnp.array(sample) / 255.0


def load_custom_data(file_path, file_type='csv', batch_size=32, target_column=None, pinned_memory=True, multi_gpu=False):
    """Load data from a CSV file, a JSON file or an image folder.

    Args:
        file_path: Path of the file, or of the folder for ``'image'``.
        file_type: One of ``'csv'``, ``'json'`` or ``'image'``.
        batch_size: Batch size of the returned loader.
        target_column: Label column; only used for ``'csv'``.
        pinned_memory: Forwarded to the loader configuration.
        multi_gpu: Forwarded to the loader configuration.

    Returns:
        A ``JAXDataLoader`` over the loaded data.

    Raises:
        ValueError: If ``file_type`` is not supported.
    """
    if file_type == 'csv':
        data, labels = load_csv_data(file_path, target_column)
    elif file_type == 'json':
        data, labels = load_json_data(file_path)
    elif file_type == 'image':
        data, labels = load_image_data(file_path)
    else:
        raise ValueError(f"Unsupported file type: {file_type}")

    config = DataLoaderConfig(
        batch_size=batch_size,
        pinned_memory=pinned_memory,
        multi_gpu=multi_gpu
    )
    return JAXDataLoader(data, labels, config)


def load_csv_data(file_path, target_column=None):
    """Load features and labels from a CSV file.

    Returns:
        ``(data, labels)`` where ``labels`` is the ``target_column`` column
        and ``data`` holds all remaining columns.

    Raises:
        KeyError: If ``target_column`` is not a column of the file.
    """
    df = pd.read_csv(file_path)
    columns = df.columns.tolist()
    # Library code reports through the module logger instead of stdout.
    logger.info("CSV Columns: %s", columns)
    if target_column not in df.columns:
        raise KeyError(f"'{target_column}' column not found in CSV. Available columns: {columns}")
    data = df.drop(target_column, axis=1).values
    labels = df[target_column].values
    return data, labels


def load_json_data(file_path):
    """Load features and labels from a JSON file.

    The file must contain a list of objects with a ``'features'`` and a
    ``'label'`` entry each.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    features = np.array([item['features'] for item in data])
    labels = np.array([item['label'] for item in data])
    return features, labels


def load_image_data(image_folder_path, img_size=(64, 64)):
    """Load and resize every ``.jpg``/``.png`` image of a folder.

    The label is the integer before the first underscore of the file name
    (e.g. ``"0_image1.jpg"`` has label 0).

    Raises:
        ValueError: If a file name does not start with an integer label.
    """
    # Sorted so that the sample order does not depend on the file system.
    image_files = sorted(
        f for f in os.listdir(image_folder_path) if f.endswith(('.jpg', '.png'))
    )
    data = []
    labels = []
    for img_file in image_files:
        # The context manager closes the underlying file; the pixels are
        # materialised into an array before that happens.
        with Image.open(os.path.join(image_folder_path, img_file)) as img:
            data.append(np.array(img.resize(img_size)))
        labels.append(int(img_file.split('_')[0]))
    return np.array(data), np.array(labels)


# Example usage: Loading custom dataset and iterating over it
if __name__ == "__main__":
    dataset_path = 'dataset.csv'  # Replace with actual dataset path
    batch_size = 64

    # Example 1: Loading CSV
    dataloader = load_custom_data(dataset_path, file_type='csv', batch_size=batch_size, target_column='median_house_value', multi_gpu=True)

    # Example 2: Loading JSON
    # dataloader = load_custom_data('dataset.json', file_type='json', batch_size=batch_size, multi_gpu=True)

    # Example 3: Loading Images
    # dataloader = load_custom_data('images_folder/', file_type='image', batch_size=batch_size, multi_gpu=True)

    for batch_x, batch_y in dataloader:
        print("Batch Shape:", batch_x.shape, batch_y.shape)