Coverage for src/seqrule/generators.py: 15%
211 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-26 10:19 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-26 10:19 -0600
1"""
2Utilities for generating sequences.
4This module provides functions for generating sequences from domains of objects,
5with support for bounded generation, filtering, constraints, and prediction.
6"""
8from typing import List, Optional, Callable, Dict, Any, Set, Tuple, Iterator
9from dataclasses import dataclass
10from collections import defaultdict
11import random
12import itertools
14from .core import AbstractObject, Sequence, FormalRule
@dataclass
class Constraint:
    """A named predicate applied to a single property value.

    Calling the instance forwards the given value to ``condition`` and
    returns whatever that callable returns.
    """

    # Name of the property this constraint is about.
    property_name: str
    # Predicate invoked with the value being checked.
    condition: Callable[[Any], bool]
    # Optional human-readable explanation; empty by default.
    description: str = ""

    def __call__(self, value: Any) -> bool:
        """Evaluate the stored condition against *value*."""
        predicate = self.condition
        return predicate(value)
class PropertyPattern:
    """An ordered list of expected values for one named property.

    A cyclic pattern repeats its values indefinitely; a non-cyclic pattern
    describes exactly ``len(values)`` consecutive items.
    """

    def __init__(self, property_name: str, values: List[Any], is_cyclic: bool = False):
        self.property_name = property_name
        self.values = values
        self.is_cyclic = is_cyclic

    def _get_property_value(self, obj: Any) -> Any:
        """Read the pattern's property from *obj*.

        Lookup order: ``obj.properties[name]`` when a ``properties`` attribute
        exists, then ``obj[name]`` when the object is subscriptable, and
        finally plain attribute access.
        """
        name = self.property_name
        if hasattr(obj, 'properties'):
            return obj.properties[name]
        if hasattr(obj, '__getitem__'):
            return obj[name]
        return getattr(obj, name)

    def matches(self, sequence: List[Dict[str, Any]], start_idx: int = 0) -> bool:
        """Return whether *sequence* follows the pattern from *start_idx* on.

        An empty sequence always matches.
        """
        if not sequence:
            return True
        if self.is_cyclic:
            return self._matches_cyclic(sequence, start_idx)
        return self._matches_exact(sequence, start_idx)

    def _matches_cyclic(self, sequence: List[Dict[str, Any]], start_idx: int) -> bool:
        """Check a non-empty sequence against the repeating pattern."""
        read = self._get_property_value
        cycle = self.values
        period = len(cycle)

        # A one-value cycle simply requires every item to carry that value.
        if period == 1:
            only_value = cycle[0]
            return all(read(obj) == only_value for obj in sequence[start_idx:])

        if start_idx == 0:
            # At the very beginning the cycle must start on its first value.
            if read(sequence[0]) != cycle[0]:
                return False
            cursor = 1
            first_unchecked = 1
        else:
            # Otherwise resume the cycle right after the preceding item's value.
            previous = read(sequence[start_idx - 1])
            try:
                cursor = (cycle.index(previous) + 1) % period
            except ValueError:
                # The preceding value is not part of the pattern at all.
                return False
            first_unchecked = start_idx

        for position in range(first_unchecked, len(sequence)):
            if read(sequence[position]) != cycle[cursor]:
                return False
            cursor = (cursor + 1) % period
        return True

    def _matches_exact(self, sequence: List[Dict[str, Any]], start_idx: int) -> bool:
        """Check a non-empty sequence against the one-shot pattern."""
        # The slice from start_idx must be neither shorter nor longer than
        # the pattern itself.
        if len(sequence) - start_idx != len(self.values):
            return False

        read = self._get_property_value
        return not any(
            read(sequence[start_idx + offset]) != wanted
            for offset, wanted in enumerate(self.values)
        )

    def get_next_value(self, sequence: List[Dict[str, Any]]) -> Any:
        """Return the value the pattern expects for the item after *sequence*."""
        if not sequence:
            return self.values[0]

        if not self.is_cyclic:
            # Past the end of a one-shot pattern the last value is repeated.
            return self.values[min(len(sequence), len(self.values) - 1)]

        last_seen = self._get_property_value(sequence[-1])
        try:
            position = self.values.index(last_seen)
        except ValueError:
            # Unknown last value: restart the cycle from its first entry.
            return self.values[0]
        return self.values[(position + 1) % len(self.values)]
class ConstrainedGenerator:
    """Generates sequences satisfying multiple constraints.

    Constraints are callables evaluated on a whole candidate sequence;
    patterns are ``PropertyPattern``-like objects that dictate the value a
    property must take at each position.
    """

    def __init__(self, domain: List[Dict[str, Any]]):
        self.domain = domain
        self.constraints: List[Callable[[List[Dict[str, Any]]], bool]] = []
        self.patterns: List["PropertyPattern"] = []

    def add_constraint(self, constraint: Callable[[List[Dict[str, Any]]], bool]) -> 'ConstrainedGenerator':
        """Add a constraint and return self for chaining."""
        self.constraints.append(constraint)
        return self

    def add_pattern(self, pattern: "PropertyPattern") -> 'ConstrainedGenerator':
        """Add a pattern and return self for chaining."""
        self.patterns.append(pattern)
        return self

    def _satisfies_constraints(self, sequence: List[Dict[str, Any]]) -> bool:
        """Return True when every registered constraint accepts *sequence*."""
        return all(constraint(sequence) for constraint in self.constraints)

    def _satisfies_patterns(self, sequence: List[Dict[str, Any]], start_idx: int = 0) -> bool:
        """Return True when every registered pattern matches *sequence*."""
        return all(pattern.matches(sequence, start_idx) for pattern in self.patterns)

    def predict_next(self, sequence: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Predict possible next items that maintain all patterns and constraints.

        Args:
            sequence: The sequence built so far.

        Returns:
            Domain items (in domain order) whose property values equal the
            next value expected by every pattern and whose addition to
            *sequence* satisfies every constraint.
        """
        # Pair each pattern with its own expectation. Keying expectations by
        # property name would let a later pattern on the same property
        # silently replace an earlier one.
        expectations = [
            (pattern, pattern.get_next_value(sequence))
            for pattern in self.patterns
        ]

        predictions = []
        for item in self.domain:
            if any(
                pattern._get_property_value(item) != expected
                for pattern, expected in expectations
            ):
                continue
            if self._satisfies_constraints(sequence + [item]):
                predictions.append(item)
        return predictions

    def generate(self, max_length: int = 10) -> Iterator[List[Dict[str, Any]]]:
        """Generate sequences that satisfy all constraints and patterns.

        Without cyclic patterns the empty sequence is yielded first and then
        extended depth-first. With cyclic patterns, generation is seeded from
        the domain items that carry the first value of every cyclic pattern;
        if there are none, only the empty sequence is yielded.

        Args:
            max_length: Sequences are not extended beyond this length.

        Yields:
            Each accepted sequence, shorter prefixes before their extensions.
        """
        def extend_sequence(current: List[Dict[str, Any]], start_with_first: bool = True) -> Iterator[List[Dict[str, Any]]]:
            if not current and start_with_first:
                # Seed with items matching the first value of each cyclic pattern.
                cyclic_patterns = [p for p in self.patterns if p.is_cyclic]
                initial_items = []
                if cyclic_patterns:
                    candidates = self.domain
                    for pattern in cyclic_patterns:
                        first_value = pattern.values[0]
                        candidates = [
                            item for item in candidates
                            if pattern._get_property_value(item) == first_value
                        ]
                    initial_items = candidates

                if not initial_items:
                    # Nothing can start the cycles: only the empty sequence remains.
                    yield []
                    return

                for item in initial_items:
                    if self._satisfies_constraints([item]) and self._satisfies_patterns([item]):
                        yield from extend_sequence([item], False)
                return

            yield current

            if len(current) >= max_length:
                return

            for item in self.predict_next(current):
                new_sequence = current + [item]
                if self._satisfies_patterns(new_sequence):
                    yield from extend_sequence(new_sequence, False)

        # Seeding from first values only applies when a cyclic pattern exists.
        has_cyclic = any(pattern.is_cyclic for pattern in self.patterns)
        yield from extend_sequence([], has_cyclic)
def generate_counter_examples(rule: "FormalRule", domain: "List[AbstractObject]",
                              max_length: int, max_attempts: int = 1000,
                              max_examples: int = 5) -> "List[Sequence]":
    """Generate sequences that violate the given rule.

    Random sequences of length 1..max_length are drawn from *domain*. A
    sequence is kept when ``rule`` rejects it and removing any single element
    produces a sequence that ``rule`` accepts.

    Args:
        rule: Callable returning a truthy value for valid sequences.
        domain: Objects to draw sequence elements from.
        max_length: Largest sequence length to try.
        max_attempts: Maximum number of random sequences to draw.
        max_examples: Stop as soon as this many counter-examples are found.

    Returns:
        The counter-examples found, at most ``max_examples`` of them. The
        list is empty when *domain* is empty or *max_length* is below 1.
    """
    counter_examples = []

    # random.choice needs a non-empty domain and random.randint(1, max_length)
    # needs max_length >= 1; no sequence can be drawn otherwise.
    if len(domain) == 0 or max_length < 1:
        return counter_examples

    for _ in range(max_attempts):
        if len(counter_examples) >= max_examples:
            break

        # Generate random sequence
        length = random.randint(1, max_length)
        sequence = [random.choice(domain) for _ in range(length)]

        if rule(sequence):
            continue

        # Minimal means: dropping any one element makes the rule pass again.
        is_minimal = all(
            rule(sequence[:i] + sequence[i + 1:])
            for i in range(len(sequence))
        )
        if is_minimal:
            counter_examples.append(sequence)

    return counter_examples
def generate_sequences(domain, max_length=10, filter_rule=None):
    """Generate all valid sequences up to max_length from the given domain.

    The empty sequence is always a candidate, followed by every combination
    of domain items for each length from 1 to ``max_length``.

    Args:
        domain: List of objects to generate sequences from
        max_length: Maximum length of sequences to generate
        filter_rule: Optional function to filter sequences; it is applied to
            every candidate, including the empty sequence

    Returns:
        List of valid sequences. If ``filter_rule`` raises, the error is
        printed and an empty list is returned.

    Raises:
        ValueError: If ``max_length`` is negative.
    """
    if max_length < 0:
        raise ValueError("max_length must be non-negative")

    sequences = [[]]  # Start with empty sequence
    # No early return for max_length == 0: the range below is already empty,
    # and the empty sequence must still pass through filter_rule.
    for length in range(1, max_length + 1):
        sequences.extend(
            list(combo) for combo in itertools.product(domain, repeat=length)
        )

    if filter_rule:
        try:
            sequences = [seq for seq in sequences if filter_rule(seq)]
        except Exception as e:
            # Report the error but continue with empty result
            print(f"Error applying filter: {str(e)}")
            return []

    return sequences
class LazyGenerator:
    """Produces sequences over a domain one call at a time.

    Each call returns the next sequence accepted by ``filter_rule`` (or the
    next sequence at all when no filter is set) and ``None`` once exhausted.
    """

    def __init__(self, domain, max_length=10, filter_rule=None):
        self.domain = domain
        self.max_length = max_length
        self.filter_rule = filter_rule
        # Enumeration state: the length being produced and the position
        # within that length. Enumeration always begins at length 1.
        self.current_length = 1
        self.current_index = 0
        self.total_sequences = sum(len(domain) ** i for i in range(max_length + 1))
        self.empty_returned = False

    def __call__(self):
        """Return the next sequence, or None when there is nothing left."""
        # An empty domain yields the empty sequence exactly once.
        if not self.domain:
            if self.empty_returned:
                return None
            self.empty_returned = True
            return []

        # The empty sequence is emitted first only for max_length == 1
        # without a filter.
        if not self.empty_returned and not self.filter_rule and self.max_length == 1:
            self.empty_returned = True
            return []

        base = len(self.domain)
        while self.current_length <= self.max_length:
            width = self.current_length

            # Write current_index as a fixed-width base-`base` number,
            # most significant digit first.
            digits = [0] * width
            remainder = self.current_index
            for slot in range(width - 1, -1, -1):
                remainder, digits[slot] = divmod(remainder, base)
            sequence = [self.domain[digit] for digit in digits]

            # Advance the state before filtering so the next call moves on
            # regardless of the outcome.
            self.current_index += 1
            if self.current_index >= base ** width:
                self.current_index = 0
                self.current_length = width + 1

            try:
                accepted = not self.filter_rule or bool(self.filter_rule(sequence))
            except Exception as e:
                print(f"Error applying filter: {str(e)}")
                return None
            if accepted:
                return sequence

        return None

    def __iter__(self):
        """Yield sequences until a call returns None."""
        yield from iter(self, None)
def generate_lazy(domain, max_length=10, filter_rule=None):
    """Build a LazyGenerator over the given domain.

    Args:
        domain: List of objects to generate sequences from
        max_length: Maximum length of sequences to generate
        filter_rule: Optional function to filter sequences

    Returns:
        A LazyGenerator constructed from the three arguments
    """
    lazy = LazyGenerator(
        domain=domain,
        max_length=max_length,
        filter_rule=filter_rule,
    )
    return lazy