Coverage for src/seqrule/generators/constrained.py: 34%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-27 10:39 -0600

1""" 

2Constrained sequence generation. 

3 

4This module provides functionality for generating sequences 

5that satisfy a set of constraints. 

6""" 

7 

import random
from collections import deque
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, Union

from ..core import AbstractObject, Sequence
from .patterns import PropertyPattern

14 

# Module-level generic type variable.
# NOTE(review): not referenced by any definition visible in this module —
# confirm it is still needed before removing.
T = TypeVar("T")

16 

17 

@dataclass()
class GeneratorConfig:
    """Tuning options consumed by ``ConstrainedGenerator``.

    Fields, in positional-argument order:
        max_attempts: Integer option, default 100.
        randomize_candidates: Whether each expansion step shuffles its
            candidates (default True).
        max_candidates_per_step: Cap on candidates explored per step
            (default 10); a value <= 0 means "no cap".
        backtracking_enabled: Boolean option, default False.
    """

    # NOTE(review): not read by the generator code in this module —
    # confirm where (or whether) it is consumed.
    max_attempts: int = 100
    # ConstrainedGenerator.generate shuffles each step's candidates when True.
    randomize_candidates: bool = True
    # Upper bound on branches explored from one sequence. generate only
    # consults it when randomize_candidates is True; <= 0 disables the cap.
    max_candidates_per_step: int = 10
    # NOTE(review): not read by the generator code in this module —
    # confirm where (or whether) it is consumed.
    backtracking_enabled: bool = False

26 

27 

class ConstrainedGenerator:
    """Generator that produces sequences satisfying constraints and patterns.

    Sequences are built from items of a fixed ``domain``. A sequence counts as
    valid when every registered constraint returns True for it and every
    registered ``PropertyPattern`` matches it.
    """

    def __init__(
        self,
        domain: List[Union[Dict[str, Any], AbstractObject]],
        config: Optional[GeneratorConfig] = None,
    ):
        """
        Initialize with a domain of possible objects.

        Args:
            domain: List of objects that can be included in the sequence.
                Items that are not already ``AbstractObject`` instances are
                converted with ``AbstractObject(**obj)``.
            config: Optional configuration settings for the generator; a
                default ``GeneratorConfig()`` is used when omitted.
        """
        # Normalize domain to ensure all items are AbstractObjects
        self.domain = [
            obj if isinstance(obj, AbstractObject) else AbstractObject(**obj)
            for obj in domain
        ]
        self.constraints: List[Callable[[Sequence], bool]] = []
        self.patterns: List[PropertyPattern] = []
        self.config = config if config is not None else GeneratorConfig()

    def add_constraint(
        self, constraint: Callable[[Sequence], bool]
    ) -> "ConstrainedGenerator":
        """
        Add a constraint function that the generated sequences must satisfy.

        Args:
            constraint: A function that takes a sequence and returns True if
                the constraint is satisfied

        Returns:
            Self for method chaining
        """
        self.constraints.append(constraint)
        return self

    def add_pattern(self, pattern: PropertyPattern) -> "ConstrainedGenerator":
        """
        Add a property pattern that the generated sequences must follow.

        Args:
            pattern: A PropertyPattern instance defining a pattern to match

        Returns:
            Self for method chaining
        """
        self.patterns.append(pattern)
        return self

    def _satisfies_constraints(self, sequence: Sequence) -> bool:
        """Check if the sequence satisfies all constraints."""
        return all(constraint(sequence) for constraint in self.constraints)

    def _satisfies_patterns(self, sequence: Sequence, start_idx: int = 0) -> bool:
        """Check if the sequence satisfies all patterns starting from start_idx."""
        return all(pattern.matches(sequence, start_idx) for pattern in self.patterns)

    def _is_valid(self, sequence: Sequence, start_idx: int = 0) -> bool:
        """Return True if ``sequence`` passes all constraints, then all patterns.

        Constraints are evaluated first; patterns are only evaluated when every
        constraint passed (``and`` short-circuits).
        """
        return self._satisfies_constraints(sequence) and self._satisfies_patterns(
            sequence, start_idx
        )

    def predict_next(self, sequence: Sequence) -> List[AbstractObject]:
        """
        Predict the next possible items in the sequence.

        Each domain item is appended to ``sequence`` and kept if the extended
        sequence satisfies all constraints and patterns. Patterns receive the
        start index ``len(sequence) - 1``, or 0 when ``sequence`` is empty.

        Args:
            sequence: The current sequence to predict next items for. A falsy
                value is treated as the empty sequence.

        Returns:
            List of candidate objects that could be appended to the sequence,
            in domain order.
        """
        prefix = sequence if sequence else []
        # For an empty prefix this is 0, matching the default start index.
        start_idx = max(len(prefix) - 1, 0)
        return [
            item for item in self.domain if self._is_valid(prefix + [item], start_idx)
        ]

    def generate(self, max_length: int = 10) -> Iterator[Sequence]:
        """
        Generate sequences satisfying all constraints and patterns.

        Sequences are explored breadth-first starting from the empty
        sequence, so they are yielded in order of non-decreasing length. The
        empty sequence itself is yielded if it is valid.

        Args:
            max_length: Maximum length of generated sequences

        Yields:
            Valid sequences of increasing length
        """
        # FIFO queue => breadth-first order. deque.popleft() is O(1), whereas
        # list.pop(0) shifts every remaining element on each pop.
        sequences_to_process = deque([[]])

        while sequences_to_process:
            current = sequences_to_process.popleft()

            # Yield if valid
            if self._is_valid(current):
                yield current

            # Stop extending if we've reached max length
            if len(current) >= max_length:
                continue

            candidates = self.predict_next(current)
            if not candidates:
                # No valid candidates to extend this sequence
                continue

            if self.config.randomize_candidates:
                # Shuffle a copy for variety, then optionally cap how many
                # branches are explored from this sequence.
                batch = list(candidates)
                random.shuffle(batch)
                if self.config.max_candidates_per_step > 0:
                    batch = batch[: self.config.max_candidates_per_step]
            else:
                # NOTE(review): max_candidates_per_step is only honoured when
                # randomizing; deterministic mode explores every candidate.
                batch = candidates

            # Queue the extensions that validate from index 0 (predict_next
            # checked patterns from a later start index).
            for candidate in batch:
                new_sequence = current + [candidate]
                if self._is_valid(new_sequence):  # pragma: no branch
                    sequences_to_process.append(new_sequence)

86 return all(pattern.matches(sequence, start_idx) for pattern in self.patterns) 

87 

88 def predict_next(self, sequence: Sequence) -> List[AbstractObject]: 

89 """ 

90 Predict the next possible items in the sequence. 

91 

92 Returns a list of possible next items that would satisfy all 

93 constraints and patterns. 

94 

95 Args: 

96 sequence: The current sequence to predict next items for 

97 

98 Returns: 

99 List of candidate objects that could be appended to the sequence 

100 """ 

101 if not sequence: 

102 # Empty sequence - return all domain items that satisfy constraints 

103 return [ 

104 item 

105 for item in self.domain 

106 if self._satisfies_constraints([item]) 

107 and self._satisfies_patterns([item]) 

108 ] 

109 

110 # Try each domain item as a potential next item 

111 candidates = [] 

112 for item in self.domain: 

113 # Create a new sequence with this item added 

114 new_sequence = sequence + [item] 

115 

116 # Check if it satisfies constraints and patterns 

117 if self._satisfies_constraints(new_sequence) and self._satisfies_patterns( 

118 new_sequence, len(sequence) - 1 

119 ): 

120 candidates.append(item) 

121 

122 return candidates 

123 

124 def generate(self, max_length: int = 10) -> Iterator[Sequence]: 

125 """ 

126 Generate sequences satisfying all constraints and patterns. 

127 

128 Args: 

129 max_length: Maximum length of generated sequences 

130 

131 Yields: 

132 Valid sequences of increasing length 

133 """ 

134 # Start with empty sequence 

135 sequences_to_process = [[]] 

136 

137 while sequences_to_process: 

138 current = sequences_to_process.pop(0) 

139 

140 # Yield if valid 

141 if self._satisfies_constraints(current) and self._satisfies_patterns( 

142 current 

143 ): 

144 yield current 

145 

146 # Stop extending if we've reached max length 

147 if len(current) >= max_length: 

148 continue 

149 

150 # Get candidates for the next position 

151 candidates = self.predict_next(current) 

152 if not candidates: 

153 # No valid candidates to extend this sequence 

154 continue 

155 

156 # Randomize order to get variety 

157 if self.config.randomize_candidates: 

158 shuffled = list(candidates) 

159 random.shuffle(shuffled) 

160 

161 # Limit the number of candidates if configured 

162 if self.config.max_candidates_per_step > 0: 

163 shuffled = shuffled[: self.config.max_candidates_per_step] 

164 else: 

165 shuffled = candidates 

166 

167 # Add new sequences to process 

168 for candidate in shuffled: 

169 new_sequence = current + [candidate] 

170 if self._satisfies_constraints( 

171 new_sequence 

172 ) and self._satisfies_patterns( 

173 new_sequence 

174 ): # pragma: no branch 

175 sequences_to_process.append(new_sequence)