Coverage for src/seqrule/generators.py: 15%

211 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-26 10:19 -0600

"""
Utilities for generating sequences.

This module provides functions for generating sequences from domains of objects,
with support for bounded generation, filtering, constraints, and prediction.
"""

7 

import itertools
import random
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple

from .core import AbstractObject, FormalRule, Sequence

15 

@dataclass
class Constraint:
    """Represents a constraint on sequence generation.

    Calling an instance evaluates ``condition`` on the supplied value and
    returns its result.
    """

    # Property name associated with the constraint (stored only; it is not
    # consulted by ``__call__``).
    property_name: str
    # Predicate that receives a value and reports whether it is acceptable.
    condition: Callable[[Any], bool]
    # Optional human-readable explanation; defaults to the empty string.
    description: str = ""

    def __call__(self, value: Any) -> bool:
        """Return ``self.condition(value)``."""
        return self.condition(value)

25 

class PropertyPattern:
    """Expected progression of one property's values along a sequence.

    ``values`` lists the expected values in order.  When ``is_cyclic`` is
    true the list repeats indefinitely; otherwise the sequence (from the
    starting index) must contain exactly ``len(values)`` items.

    NOTE(review): positions inside a cyclic pattern are recovered with
    ``list.index``, which returns the first occurrence, so cyclic patterns
    containing repeated values are ambiguous.
    """

    def __init__(self, property_name: str, values: List[Any], is_cyclic: bool = False):
        self.property_name = property_name
        self.values = values
        self.is_cyclic = is_cyclic

    def _get_property_value(self, obj: Any) -> Any:
        """Get property value from either AbstractObject or dict.

        Lookup order: ``obj.properties[name]`` when the object has a
        ``properties`` attribute, then ``obj[name]`` when it is
        subscriptable, and finally ``getattr(obj, name)``.
        """
        if hasattr(obj, 'properties'):
            return obj.properties[self.property_name]
        if hasattr(obj, '__getitem__'):
            return obj[self.property_name]
        return getattr(obj, self.property_name)

    def matches(self, sequence: List[Dict[str, Any]], start_idx: int = 0) -> bool:
        """Check if the sequence matches the pattern starting from start_idx.

        An empty sequence always matches.

        Args:
            sequence: Items whose property values are compared to the pattern.
            start_idx: Index of the first item to check.

        Returns:
            True when every checked item carries the expected value.
        """
        if not sequence:
            return True
        if self.is_cyclic:
            return self._matches_cyclic(sequence, start_idx)
        return self._matches_exact(sequence, start_idx)

    def _matches_cyclic(self, sequence: List[Dict[str, Any]], start_idx: int) -> bool:
        """Check a non-empty sequence against the repeating pattern."""
        period = len(self.values)

        # Special case for single-value patterns: every checked item must
        # carry that one value.
        if period == 1:
            expected_value = self.values[0]
            return all(self._get_property_value(obj) == expected_value
                       for obj in sequence[start_idx:])

        if start_idx == 0:
            # At the beginning of the sequence the cycle must start with
            # the first pattern value.
            if self._get_property_value(sequence[0]) != self.values[0]:
                return False
            pattern_pos = 1
            check_from = 1
        else:
            # Elsewhere, the position in the cycle is derived from the item
            # just before start_idx; an unknown value there cannot match.
            prev_value = self._get_property_value(sequence[start_idx - 1])
            try:
                prev_pos = self.values.index(prev_value)
            except ValueError:
                return False
            pattern_pos = (prev_pos + 1) % period
            check_from = start_idx

        # The rest of the sequence must keep following the cycle.
        for i in range(check_from, len(sequence)):
            if self._get_property_value(sequence[i]) != self.values[pattern_pos]:
                return False
            pattern_pos = (pattern_pos + 1) % period
        return True

    def _matches_exact(self, sequence: List[Dict[str, Any]], start_idx: int) -> bool:
        """Check a non-empty sequence against the non-cyclic pattern."""
        # The part of the sequence from start_idx must be exactly as long as
        # the pattern: neither shorter ...
        if start_idx + len(self.values) > len(sequence):
            return False
        # ... nor longer.
        if len(sequence) - start_idx > len(self.values):
            return False

        for offset, value in enumerate(self.values):
            if self._get_property_value(sequence[start_idx + offset]) != value:
                return False
        return True

    def get_next_value(self, sequence: List[Dict[str, Any]]) -> Any:
        """Predict the next value in the pattern based on the current sequence.

        For an empty sequence this is the first pattern value.  For cyclic
        patterns it is the value following the last item's value in the
        cycle (or the first value when the last item's value is not in the
        pattern).  For non-cyclic patterns it is the value at index
        ``len(sequence)``, clamped to the last pattern value.
        """
        if not sequence:
            return self.values[0]

        if self.is_cyclic:
            last_value = self._get_property_value(sequence[-1])
            try:
                current_pos = self.values.index(last_value)
            except ValueError:
                # The last value isn't in the pattern: start from beginning.
                return self.values[0]
            return self.values[(current_pos + 1) % len(self.values)]

        last_pos = min(len(sequence), len(self.values) - 1)
        return self.values[last_pos]

113 

class ConstrainedGenerator:
    """Generates sequences satisfying multiple constraints.

    Constraints are callables that receive a whole candidate sequence and
    return a truthy value when it is acceptable.  Patterns are objects
    exposing ``property_name``, ``values``, ``is_cyclic``, ``matches``,
    ``get_next_value`` and ``_get_property_value`` (as ``PropertyPattern``
    does).
    """

    def __init__(self, domain: List[Dict[str, Any]]):
        self.domain = domain
        self.constraints: List[Callable[[List[Dict[str, Any]]], bool]] = []
        self.patterns: List['PropertyPattern'] = []

    def add_constraint(self, constraint: Callable[[List[Dict[str, Any]]], bool]) -> 'ConstrainedGenerator':
        """Add a constraint and return self for chaining."""
        self.constraints.append(constraint)
        return self

    def add_pattern(self, pattern: 'PropertyPattern') -> 'ConstrainedGenerator':
        """Add a pattern and return self for chaining."""
        self.patterns.append(pattern)
        return self

    def _satisfies_constraints(self, sequence: List[Dict[str, Any]]) -> bool:
        """Return True when every registered constraint accepts ``sequence``."""
        return all(constraint(sequence) for constraint in self.constraints)

    def _satisfies_patterns(self, sequence: List[Dict[str, Any]], start_idx: int = 0) -> bool:
        """Return True when every registered pattern matches ``sequence``."""
        return all(pattern.matches(sequence, start_idx) for pattern in self.patterns)

    def predict_next(self, sequence: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Predict possible next items that maintain all patterns and constraints.

        Args:
            sequence: The sequence built so far.

        Returns:
            The domain items (in domain order) whose property values equal
            what every pattern expects next and whose addition to
            ``sequence`` satisfies every constraint.
        """
        # Expected next value per property.  Patterns sharing a property
        # name overwrite each other, so the last one added wins.
        expected_values = {}
        for pattern in self.patterns:
            expected_values[pattern.property_name] = pattern.get_next_value(sequence)

        predictions = []
        for item in self.domain:
            mismatch = any(
                pattern._get_property_value(item) != expected_values[pattern.property_name]
                for pattern in self.patterns
            )
            if mismatch:
                continue
            if self._satisfies_constraints(sequence + [item]):
                predictions.append(item)
        return predictions

    def generate(self, max_length: int = 10) -> Iterator[List[Dict[str, Any]]]:
        """Generate sequences that satisfy all constraints and patterns.

        Sequences are produced depth-first: a sequence is yielded before any
        of its extensions.  Without cyclic patterns generation starts from
        the empty sequence; with cyclic patterns it starts from the domain
        items carrying the first value of every cyclic pattern.

        NOTE(review): seed items are yielded as length-1 sequences before
        ``max_length`` is compared, so ``max_length=0`` still produces them
        when cyclic patterns are present - confirm this is intended.

        Args:
            max_length: Sequences of this length are yielded but not
                extended further.

        Yields:
            Lists of domain items.
        """
        def extend_sequence(current: List[Dict[str, Any]],
                            start_with_first: bool = True) -> Iterator[List[Dict[str, Any]]]:
            if not current and start_with_first:
                # Seed with the items that carry the first value of every
                # cyclic pattern.
                initial_items = []
                cyclic_patterns = [p for p in self.patterns if p.is_cyclic]

                if cyclic_patterns:
                    candidates = self.domain
                    for pattern in cyclic_patterns:
                        first_value = pattern.values[0]
                        candidates = [
                            item for item in candidates
                            if pattern._get_property_value(item) == first_value
                        ]
                    initial_items = candidates

                if not initial_items:
                    # No item can open the cyclic patterns: only the empty
                    # sequence is produced.
                    yield []
                    return

                for item in initial_items:
                    if self._satisfies_constraints([item]) and self._satisfies_patterns([item]):
                        yield from extend_sequence([item], False)
                return

            yield current

            if len(current) >= max_length:
                return

            for item in self.predict_next(current):
                new_sequence = current + [item]
                if self._satisfies_patterns(new_sequence):
                    yield from extend_sequence(new_sequence, False)

        # Start with empty sequence if no cyclic patterns
        if not any(pattern.is_cyclic for pattern in self.patterns):
            yield from extend_sequence([], False)
        else:
            yield from extend_sequence([])

210 

def generate_counter_examples(rule: FormalRule, domain: List[AbstractObject],
                              max_length: int, max_attempts: int = 1000,
                              max_examples: int = 5) -> List[Sequence]:
    """Generate sequences that violate the given rule.

    Random sequences of length 1..``max_length`` are drawn from ``domain``.
    A sequence is kept when ``rule`` rejects it and it is minimal, meaning
    that removing any single element yields a sequence the rule accepts.
    The same counter example may appear more than once.

    Args:
        rule: Callable returning a truthy value for sequences that satisfy it.
        domain: Objects to draw sequence elements from.
        max_length: Largest sequence length to try.
        max_attempts: Maximum number of random sequences to draw.
        max_examples: Stop once this many counter examples were collected.

    Returns:
        A list of at most ``max_examples`` minimal violating sequences.  The
        list is empty when ``domain`` is empty or ``max_length`` is below 1,
        because no non-empty sequence can be drawn in those cases.
    """
    counter_examples = []

    # random.choice / random.randint cannot operate on an empty domain or
    # an empty length range; there is nothing to sample, so report no
    # counter examples instead of failing.
    if not domain or max_length < 1:
        return counter_examples

    for _ in range(max_attempts):
        if len(counter_examples) >= max_examples:
            break

        # Generate random sequence
        length = random.randint(1, max_length)
        sequence = [random.choice(domain) for _ in range(length)]

        if rule(sequence):
            continue

        # Minimal: every single-element deletion must satisfy the rule.
        is_minimal = all(
            rule(sequence[:i] + sequence[i + 1:]) for i in range(len(sequence))
        )
        if is_minimal:
            counter_examples.append(sequence)

    return counter_examples

236 

def generate_sequences(domain, max_length=10, filter_rule=None):
    """Generate all valid sequences up to max_length from the given domain.

    The empty sequence comes first, followed by every sequence of length 1,
    then length 2, and so on, each group in ``itertools.product`` order.
    When a filter is given it is applied to every candidate, including the
    empty sequence, for every ``max_length``.

    Args:
        domain: List of objects to generate sequences from
        max_length: Maximum length of sequences to generate
        filter_rule: Optional function to filter sequences

    Returns:
        List of valid sequences.  If the filter raises, the error is printed
        and an empty list is returned.

    Raises:
        ValueError: If ``max_length`` is negative.
    """
    if max_length < 0:
        raise ValueError("max_length must be non-negative")

    sequences = [[]]  # Start with empty sequence
    for length in range(1, max_length + 1):
        sequences.extend(list(seq) for seq in itertools.product(domain, repeat=length))

    if filter_rule:
        try:
            sequences = [seq for seq in sequences if filter_rule(seq)]
        except Exception as e:
            # Report the error but continue with empty result
            print(f"Error applying filter: {str(e)}")
            return []

    return sequences

268 

class LazyGenerator:
    """Generates sequences lazily one at a time.

    Each call returns the next sequence, or ``None`` when nothing is left.
    Sequences are enumerated by increasing length starting at length 1;
    within one length they follow the base-``len(domain)`` counting order
    of their domain indices.

    NOTE(review): the empty sequence is produced only in two special cases
    (empty domain, or ``max_length == 1`` with no filter), and the filter is
    never applied to it - confirm this is the intended contract.
    """

    def __init__(self, domain, max_length=10, filter_rule=None):
        """Store the configuration and reset the enumeration state.

        Args:
            domain: Indexable collection of objects to build sequences from
            max_length: Maximum length of sequences to generate
            filter_rule: Optional function to filter sequences
        """
        self.domain = domain
        self.max_length = max_length
        self.filter_rule = filter_rule
        self.current_length = 1  # Always start with length 1
        self.current_index = 0
        # Number of sequences of length 0..max_length (the empty sequence is
        # included in this count).
        self.total_sequences = sum(len(domain) ** i for i in range(max_length + 1))
        self.empty_returned = False

    def __call__(self):
        """Generate the next sequence.

        Returns:
            The next sequence accepted by the filter, or ``None`` when the
            enumeration is exhausted.  If the filter raises, the error is
            printed and ``None`` is returned for that call; the sequence
            that caused it has already been consumed.
        """
        # Handle empty domain case first: a single empty sequence.
        if not self.domain:
            if not self.empty_returned:
                self.empty_returned = True
                return []
            return None

        # Handle empty sequence only if max_length=1 and no filter
        if not self.empty_returned and not self.filter_rule and self.max_length == 1:
            self.empty_returned = True
            return []

        base = len(self.domain)
        while self.current_length <= self.max_length:
            # Digits of current_index in base len(domain), least
            # significant first.
            indices = []
            remaining = self.current_index
            for _ in range(self.current_length):
                remaining, digit = divmod(remaining, base)
                indices.append(digit)

            # Reverse indices to get correct order
            indices.reverse()
            sequence = [self.domain[i] for i in indices]

            # Update state for next call; roll over to the next length once
            # every index of the current length was used.
            self.current_index += 1
            if self.current_index >= base ** self.current_length:
                self.current_index = 0
                self.current_length += 1

            try:
                if not self.filter_rule or self.filter_rule(sequence):
                    return sequence
            except Exception as e:
                print(f"Error applying filter: {str(e)}")
                return None

        return None

    def __iter__(self):
        """Make the generator iterable.

        Iteration calls the instance until it returns ``None``, so it shares
        and advances the same enumeration state as direct calls.
        """
        while True:
            sequence = self()
            if sequence is None:
                break
            yield sequence

329 

def generate_lazy(domain, max_length=10, filter_rule=None):
    """Create a lazy sequence generator.

    Args:
        domain: List of objects to generate sequences from
        max_length: Maximum length of sequences to generate
        filter_rule: Optional function to filter sequences

    Returns:
        LazyGenerator instance
    """
    return LazyGenerator(domain, max_length, filter_rule)