Coverage for src/seqrule/generators/constrained.py: 34%
58 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-27 10:39 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-27 10:39 -0600
1"""
2Constrained sequence generation.
4This module provides functionality for generating sequences
5that satisfy a set of constraints.
6"""
8import random
9from dataclasses import dataclass
10from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, Union
12from ..core import AbstractObject, Sequence
13from .patterns import PropertyPattern
# Generic type variable declared at module level; nothing in this module's
# visible code references it.
T = TypeVar("T")


@dataclass()
class GeneratorConfig:
    """Tunable knobs for a :class:`ConstrainedGenerator`.

    Field order matters: it defines the positional order accepted by the
    generated ``__init__`` (max_attempts, randomize_candidates,
    max_candidates_per_step, backtracking_enabled).
    """

    # NOTE(review): not read by ConstrainedGenerator in this module;
    # confirm whether anything else consumes it.
    max_attempts: int = 100
    # When True, generate() shuffles the candidates at each expansion step.
    randomize_candidates: bool = True
    # Read by generate() when trimming each step's candidate list;
    # a value <= 0 means "no cap".
    max_candidates_per_step: int = 10
    # NOTE(review): not read by ConstrainedGenerator in this module;
    # presumably reserved for a future backtracking search.
    backtracking_enabled: bool = False
class ConstrainedGenerator:
    """Generator that produces sequences satisfying constraints and patterns.

    Sequences are grown breadth-first starting from the empty sequence.
    Each sequence taken from the work queue is yielded if it satisfies every
    registered constraint and pattern. It is then extended by one item using
    the candidates returned by :meth:`predict_next`, until ``max_length``
    is reached.
    """

    def __init__(
        self,
        domain: List[Union[Dict[str, Any], AbstractObject]],
        config: Optional[GeneratorConfig] = None,
    ):
        """
        Initialize with a domain of possible objects.

        Args:
            domain: List of objects that can be included in the sequence.
                Items that are not already ``AbstractObject`` instances are
                converted with ``AbstractObject(**obj)``, so they are
                expected to be keyword mappings.
            config: Optional configuration settings for the generator.
                A default ``GeneratorConfig()`` is used when omitted.
        """
        # Normalize domain to ensure all items are AbstractObjects
        self.domain = [
            obj if isinstance(obj, AbstractObject) else AbstractObject(**obj)
            for obj in domain
        ]
        self.constraints: List[Callable[[Sequence], bool]] = []
        self.patterns: List[PropertyPattern] = []
        self.config = config if config is not None else GeneratorConfig()

    def add_constraint(
        self, constraint: Callable[[Sequence], bool]
    ) -> "ConstrainedGenerator":
        """
        Add a constraint function that the generated sequences must satisfy.

        Args:
            constraint: A function that takes a sequence and returns True if
                the constraint is satisfied.

        Returns:
            Self for method chaining.
        """
        self.constraints.append(constraint)
        return self

    def add_pattern(self, pattern: PropertyPattern) -> "ConstrainedGenerator":
        """
        Add a property pattern that the generated sequences must follow.

        Args:
            pattern: A PropertyPattern instance defining a pattern to match.

        Returns:
            Self for method chaining.
        """
        self.patterns.append(pattern)
        return self

    def _satisfies_constraints(self, sequence: Sequence) -> bool:
        """Return True if ``sequence`` satisfies every registered constraint."""
        return all(constraint(sequence) for constraint in self.constraints)

    def _satisfies_patterns(self, sequence: Sequence, start_idx: int = 0) -> bool:
        """Return True if every pattern matches ``sequence`` from ``start_idx``."""
        return all(pattern.matches(sequence, start_idx) for pattern in self.patterns)

    def predict_next(self, sequence: Sequence) -> List[AbstractObject]:
        """
        Predict the next possible items in the sequence.

        Each domain item is tentatively appended to ``sequence``. The item
        is kept if the extended sequence satisfies all constraints and all
        patterns.

        Args:
            sequence: The current sequence to predict next items for. An
                empty (or otherwise falsy) sequence is treated as having no
                prefix.

        Returns:
            List of candidate objects, in domain order, that could be
            appended to the sequence.
        """
        if not sequence:
            # No prefix: candidates are single-item sequences, and patterns
            # are checked from the beginning.
            prefix = []
            start_idx = 0
        else:
            prefix = sequence
            # Pattern checking starts at the last existing item rather than
            # at 0. Presumably this lets patterns that relate consecutive
            # items see the new pair. NOTE(review): confirm against
            # PropertyPattern.matches semantics.
            start_idx = len(sequence) - 1

        candidates = []
        for item in self.domain:
            extended = prefix + [item]
            if self._satisfies_constraints(extended) and self._satisfies_patterns(
                extended, start_idx
            ):
                candidates.append(item)
        return candidates

    def generate(self, max_length: int = 10) -> Iterator[Sequence]:
        """
        Generate sequences satisfying all constraints and patterns.

        Sequences are explored breadth-first from the empty sequence, so
        they are yielded in non-decreasing length order. At each expansion
        step the candidates from :meth:`predict_next` are optionally shuffled
        (``config.randomize_candidates``). They are then capped at
        ``config.max_candidates_per_step`` when that value is positive.

        Args:
            max_length: Maximum length of generated sequences.

        Yields:
            Valid sequences of increasing length.
        """
        from collections import deque

        # FIFO work queue gives breadth-first order. deque.popleft is O(1),
        # whereas list.pop(0) shifts the whole (rapidly growing) queue.
        sequences_to_process = deque([[]])

        while sequences_to_process:
            current = sequences_to_process.popleft()

            # Yield if valid
            if self._satisfies_constraints(current) and self._satisfies_patterns(
                current
            ):
                yield current

            # Stop extending if we've reached max length
            if len(current) >= max_length:
                continue

            candidates = self.predict_next(current)
            if not candidates:
                # No valid candidates to extend this sequence
                continue

            # Randomize order to get variety (copy so the caller-visible
            # list from predict_next is never shuffled in place).
            if self.config.randomize_candidates:
                candidates = list(candidates)
                random.shuffle(candidates)

            # The per-step cap applies whether or not candidates were
            # shuffled. Previously it was silently ignored when
            # randomize_candidates was False.
            max_candidates = self.config.max_candidates_per_step
            if max_candidates > 0:
                candidates = candidates[:max_candidates]

            # Re-validate with patterns checked from index 0: predict_next
            # checks patterns from a later start index, so it is not a
            # guarantee by itself.
            for candidate in candidates:
                new_sequence = current + [candidate]
                if self._satisfies_constraints(  # pragma: no branch
                    new_sequence
                ) and self._satisfies_patterns(new_sequence):
                    sequences_to_process.append(new_sequence)