Coverage for src/seqrule/rulesets/general.py: 7%
313 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-27 10:39 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-27 10:39 -0600
"""
General purpose sequence rules.

This module provides a collection of commonly useful rule patterns that can be
applied across different domains. These patterns are abstracted from common
use cases seen in specific domains like card games, DNA sequences, music,
and tea processing.

Common use cases:
- Pattern matching and cycles
- Property-based rules
- Numerical constraints
- Historical patterns
- Meta-rules and combinations
"""
17from typing import Any, Callable, Dict, List, Optional, Set, TypeVar
19from ..core import AbstractObject, Sequence
20from ..dsl import DSLRule
# Generic type variable.
# NOTE(review): T is not referenced by any rule factory in this module as
# visible here — confirm whether it is still needed before removing it.
T = TypeVar("T")
def create_property_match_rule(property_name: str, value: Any) -> DSLRule:
    """
    Build a rule demanding that every object carries one fixed property value.

    An object lacking the property is read as ``None``, so it only passes when
    ``value`` itself is ``None``. An empty sequence passes vacuously.

    Args:
        property_name: Name of the property to inspect on each object.
        value: The value every object must have for that property.

    Example:
        color_is_red = create_property_match_rule("color", "red")
    """

    def check_property(seq: Sequence) -> bool:
        for item in seq:
            actual = item.properties.get(property_name)
            if not (actual == value):
                return False
        return True

    return DSLRule(check_property, f"all objects have {property_name}={value}")
def create_property_cycle_rule(*properties: str) -> DSLRule:
    """
    Build a rule requiring each listed property to repeat a fixed cycle.

    For every property, the cycle is the run of values before the first value
    that reappears. The whole sequence must then repeat that run. If no value
    ever repeats, the property does not form a cycle and the rule fails.
    Sequences of zero or one object pass. A property whose check raises
    TypeError is skipped.

    Args:
        *properties: Names of the properties that must each form a cycle.

    Example:
        color_cycle = create_property_cycle_rule("color")  # Values must cycle
    """

    def _first_repeat_index(values: List[Any]) -> Optional[int]:
        """Return the index of the first value already seen earlier, or None."""
        for index, current in enumerate(values):
            if current in values[:index]:
                return index
        return None

    def check_cycle(seq: Sequence) -> bool:
        if not seq:
            return True

        for prop in properties:
            try:
                values = [obj.properties.get(prop) for obj in seq]
                if len(values) <= 1:
                    continue

                period = _first_repeat_index(values)
                if period is None:
                    return False  # nothing ever repeats, so there is no cycle

                # values[:period] is the cycle; every element must line up with it.
                if any(values[i] != values[i % period] for i in range(len(values))):
                    return False
            except TypeError:
                continue  # properties whose values can't be compared are skipped

        return True

    return DSLRule(check_cycle, f"properties {properties} form cycles")
def create_alternation_rule(property_name: str) -> DSLRule:
    """
    Build a rule forbidding two neighbouring objects from sharing a value.

    A pair is only compared when both objects have a non-``None`` value for the
    property. Pairs with a missing side are ignored. Sequences of zero or one
    object pass.

    Args:
        property_name: Property whose adjacent values must differ.

    Example:
        alternating_colors = create_alternation_rule("color")
    """

    def check_alternation(seq: Sequence) -> bool:
        if len(seq) <= 1:
            return True

        for earlier, later in zip(seq, seq[1:]):
            first = earlier.properties.get(property_name)
            second = later.properties.get(property_name)
            if first is None or second is None:
                continue
            if first == second:
                return False
        return True

    return DSLRule(check_alternation, f"{property_name} values must alternate")
def create_numerical_range_rule(
    property_name: str, min_value: float, max_value: float
) -> DSLRule:
    """
    Build a rule keeping a numeric property inside an inclusive range.

    Each present value is converted with ``float()``. Missing or ``None`` values
    are ignored, and so is any value whose conversion or comparison raises
    ValueError or TypeError.

    Args:
        property_name: Property holding the numeric value.
        min_value: Inclusive lower bound.
        max_value: Inclusive upper bound.

    Example:
        valid_temperature = create_numerical_range_rule("temperature", 20, 30)
    """

    def check_range(seq: Sequence) -> bool:
        for obj in seq:
            try:
                raw = obj.properties.get(property_name)
                if raw is not None and not (min_value <= float(raw) <= max_value):
                    return False
            except (ValueError, TypeError):
                continue  # non-numeric values do not count against the rule
        return True

    return DSLRule(
        check_range, f"{property_name} must be between {min_value} and {max_value}"
    )
def create_sum_rule(
    property_name: str, target: float, tolerance: float = 0.001
) -> DSLRule:
    """
    Build a rule requiring a property to add up to a target value.

    Every object must carry the property, and every value must convert with
    ``float()``. The total passes when it is within ``tolerance`` (absolute) of
    ``target``. An empty sequence passes.

    Args:
        property_name: Property whose values are summed.
        target: Expected total.
        tolerance: Maximum absolute difference allowed between total and target.

    Raises:
        ValueError: When an object lacks the property or holds a value that
            cannot be converted to float (chained from the conversion error).

    Example:
        total_duration = create_sum_rule("duration", 60.0)  # Sum to 60
    """

    def check_sum(seq: Sequence) -> bool:
        if not seq:
            return True  # Empty sequence is valid

        amounts = []
        for obj in seq:
            props = obj.properties
            if property_name not in props:
                raise ValueError(f"Missing required property: {property_name}")
            try:
                amounts.append(float(props[property_name]))
            except (ValueError, TypeError) as e:
                raise ValueError(f"Invalid value for {property_name}") from e

        # Kept as sum() over a list so float accumulation matches the builtin exactly.
        return abs(sum(amounts) - target) <= tolerance

    return DSLRule(check_sum, f"sum of {property_name} must be {target}")
def create_pattern_rule(pattern: List[Any], property_name: str) -> DSLRule:
    """
    Creates a rule requiring property values to repeat a fixed pattern.

    The i-th object's value must equal ``pattern[i % len(pattern)]``. So a
    sequence shorter than the pattern passes when it matches the pattern's
    prefix, and a longer one must keep repeating the pattern.

    The pattern is copied when the rule is created, so later changes to the
    caller's list affect neither the check nor the description.

    Args:
        pattern: Values to match, repeated cyclically along the sequence.
        property_name: Property compared against the pattern.

    Returns:
        A DSLRule that fails for an empty sequence. If ``pattern`` is empty,
        no sequence can match, and the rule fails rather than dividing by zero.

    Example:
        color_pattern = create_pattern_rule(["red", "black", "red"], "color")
    """
    expected = list(pattern)

    def check_pattern(seq: Sequence) -> bool:
        if not seq or not expected:
            return False

        period = len(expected)
        # Cyclic indexing also covers sequences shorter than the pattern:
        # there i % period == i, i.e. a prefix match.
        return all(
            obj.properties.get(property_name) == expected[i % period]
            for i, obj in enumerate(seq)
        )

    return DSLRule(check_pattern, f"{property_name} must match pattern {pattern}")
def create_historical_rule(
    window: int, condition: Callable[[List[AbstractObject]], bool]
) -> DSLRule:
    """
    Build a rule applying ``condition`` to every sliding window of objects.

    Sequences shorter than ``window`` pass. Any exception raised while building
    or evaluating a window makes that window count as skipped, not failed.

    Args:
        window: Number of consecutive objects per window.
        condition: Predicate called with each window slice.

    Example:
        def no_repeats(window): return len(set(obj["value"] for obj in window)) == len(window)
        unique_values = create_historical_rule(3, no_repeats)
    """

    def check_historical(seq: Sequence) -> bool:
        if len(seq) < window:
            return True

        window_starts = range(len(seq) - window + 1)
        for start in window_starts:
            try:
                satisfied = condition(seq[start : start + window])
                if not satisfied:
                    return False
            except Exception:  # a window whose condition errors is skipped
                continue
        return True

    return DSLRule(
        check_historical, f"condition must hold over {window}-object windows"
    )
def create_dependency_rule(
    property_name: str, dependencies: Dict[Any, Set[Any]]
) -> DSLRule:
    """
    Build a rule where some property values require others to be present.

    When an object's value equals a key of ``dependencies``, every value in
    that key's set must appear somewhere in the sequence for the same
    property. Objects whose value is ``None`` impose no requirements.

    Args:
        property_name: Property holding the dependent values.
        dependencies: Maps a value to the set of values it requires.

    Raises:
        KeyError: When a visited object lacks the property, or when a TypeError
            occurs while checking an object (chained as the cause).

    Example:
        stage_deps = create_dependency_rule("stage", {"deploy": {"test", "build"}})
    """

    def check_dependencies(seq: Sequence) -> bool:
        def present(wanted: Any) -> bool:
            """True if any object in the sequence has ``wanted`` for the property."""
            return any(o.properties.get(property_name) == wanted for o in seq)

        for obj in seq:
            try:
                if property_name not in obj.properties:
                    raise KeyError(property_name)
                current = obj.properties[property_name]
                if current is None:
                    continue
                for trigger, prerequisites in dependencies.items():
                    if not (current == trigger):
                        continue
                    if not all(present(req) for req in prerequisites):
                        return False
            except (KeyError, TypeError) as e:
                raise KeyError(property_name) from e
        return True

    return DSLRule(check_dependencies, f"dependencies between {property_name} values")
def create_meta_rule(rules: List[DSLRule], required_count: int) -> DSLRule:
    """
    Build a rule that passes when enough of the given rules pass.

    Every rule is evaluated (no short-circuiting), and exceptions from a rule
    propagate. An empty rule list always passes.

    Args:
        rules: Rules to evaluate against the sequence.
        required_count: Minimum number of rules that must return truthy.

    Example:
        any_two = create_meta_rule([rule1, rule2, rule3], 2)  # Any 2 must pass
    """

    def check_meta(seq: Sequence) -> bool:
        if not rules:
            return True  # Empty rule list passes

        satisfied = 0
        for rule in rules:
            if rule(seq):
                satisfied += 1
        return satisfied >= required_count

    return DSLRule(check_meta, f"at least {required_count} rules must be satisfied")
def create_group_rule(
    group_size: int, condition: Callable[[List[AbstractObject]], bool]
) -> DSLRule:
    """
    Build a rule applying ``condition`` to each run of consecutive objects.

    Groups overlap: one starts at every offset that leaves room for
    ``group_size`` objects. Sequences shorter than ``group_size`` pass. A group
    whose evaluation raises any exception is skipped.

    Args:
        group_size: Number of consecutive objects per group.
        condition: Predicate called with each group slice.

    Example:
        def ascending(group):
            return all(group[i]["value"] < group[i+1]["value"]
                       for i in range(len(group)-1))
        ascending_pairs = create_group_rule(2, ascending)
    """

    def check_groups(seq: Sequence) -> bool:
        if len(seq) < group_size:
            return True

        last_offset = len(seq) - group_size
        for offset in range(last_offset + 1):
            try:
                chunk = seq[offset : offset + group_size]
                if not condition(chunk):
                    return False
            except Exception:  # groups whose condition raises are skipped
                continue
        return True

    return DSLRule(check_groups, f"condition must hold for each group of {group_size}")
# Common rule combinations
def create_bounded_sequence_rule(
    min_length: int, max_length: int, inner_rule: DSLRule
) -> DSLRule:
    """
    Build a rule combining an inclusive length bound with another rule.

    The inner rule is only invoked when the length check passes. Otherwise
    the result is ``False``.

    Args:
        min_length: Smallest allowed sequence length.
        max_length: Largest allowed sequence length.
        inner_rule: Rule applied to sequences whose length is in bounds; its
            ``description`` is embedded in the combined description.

    Example:
        valid_sequence = create_bounded_sequence_rule(2, 5, pattern_rule)
    """

    def check_bounded(seq: Sequence) -> bool:
        if not (min_length <= len(seq) <= max_length):
            return False
        return inner_rule(seq)

    return DSLRule(
        check_bounded, f"length {min_length}-{max_length} and {inner_rule.description}"
    )
def create_composite_rule(rules: List[DSLRule], mode: str = "all") -> DSLRule:
    """
    Build a rule combining several rules with AND (``"all"``) or OR (``"any"``).

    Behaviour by mode:
    - ``"all"``: stops at the first falsy result or raised exception and fails.
    - ``"any"``: stops at the first truthy result and passes. Rules that raise
      are ignored.
    - Any other mode string skips rules that raise and ORs the remaining results.
    - In every mode, the rule passes when no rule produced a result (an empty
      list, or every rule raised).

    Args:
        rules: Rules to combine.
        mode: ``"all"`` for AND semantics, ``"any"`` for OR semantics.

    Example:
        all_rules = create_composite_rule([rule1, rule2], mode="all")
        any_rule = create_composite_rule([rule1, rule2], mode="any")
    """
    require_all = mode == "all"

    def check_composite(seq: Sequence) -> bool:
        collected = []
        for rule in rules:
            try:
                outcome = rule(seq)
                collected.append(outcome)
                if require_all and not outcome:
                    return False  # AND: one failure decides
                if mode == "any" and outcome:
                    return True  # OR: one success decides
            except Exception:
                if require_all:
                    return False  # AND: an erroring rule counts as failure

        if not collected:
            return True  # No valid results means pass
        return all(collected) if require_all else any(collected)

    mode_desc = "all" if require_all else "any"
    return DSLRule(check_composite, f"{mode_desc} of the rules must be satisfied")
def create_ratio_rule(
    property_name: str,
    min_ratio: float,
    max_ratio: float,
    filter_rule: Optional[Callable[[AbstractObject], bool]] = None,
) -> DSLRule:
    """
    Build a rule bounding the fraction of objects that satisfy a condition.

    Only objects that have a non-``None`` value for ``property_name`` are
    counted. With ``filter_rule``, the numerator is the number of those objects
    it accepts. Without one, the numerator is the number of objects whose value
    equals the first counted object's value. The sequence passes in these
    cases:
    - it is empty;
    - no object is countable;
    - ``filter_rule`` raises.

    Args:
        property_name: Property an object must have to be counted.
        min_ratio: Inclusive lower bound on the fraction.
        max_ratio: Inclusive upper bound on the fraction.
        filter_rule: Optional predicate selecting the objects in the numerator.

    Example:
        # At least 40% but no more than 60% GC content
        gc_content = create_ratio_rule("base", 0.4, 0.6, lambda obj: obj["base"] in ["G", "C"])
    """

    def check_ratio(seq: Sequence) -> bool:
        if not seq:
            return True

        def countable(obj: AbstractObject) -> bool:
            """True when obj has a non-None value for the property."""
            try:
                return (
                    property_name in obj.properties
                    and obj.properties[property_name] is not None
                )
            except Exception:  # any access error makes the object uncountable
                return False

        candidates = [obj for obj in seq if countable(obj)]
        if not candidates:
            return True

        if filter_rule is None:
            reference = candidates[0].properties[property_name]
            hits = len(
                [obj for obj in candidates if obj.properties[property_name] == reference]
            )
        else:
            try:
                hits = len([obj for obj in candidates if filter_rule(obj)])
            except Exception:  # a failing filter function lets the sequence pass
                return True

        fraction = hits / len(candidates)
        return min_ratio <= fraction <= max_ratio

    return DSLRule(
        check_ratio, f"ratio must be between {min_ratio:.1%} and {max_ratio:.1%}"
    )
def create_transition_rule(
    property_name: str, valid_transitions: Dict[Any, Set[Any]]
) -> DSLRule:
    """
    Creates a rule enforcing valid transitions between property values.

    Consecutive usable values are compared. An object's value is unusable when
    it is missing, ``None``, or unhashable; such objects are skipped, so the
    check bridges over them. A transition is constrained only when the earlier
    value is a key of ``valid_transitions``; a value without an entry may be
    followed by anything.

    Args:
        property_name: Property whose consecutive values are compared.
        valid_transitions: Maps a value to the set of values allowed to follow it.

    Returns:
        A DSLRule that fails on the first disallowed transition and otherwise
        passes (including sequences of zero or one object).

    Example:
        # Valid note transitions in a scale
        scale_rule = create_transition_rule("pitch", {
            "C": {"D"}, "D": {"E"}, "E": {"F"}, "F": {"G"},
            "G": {"A"}, "A": {"B"}, "B": {"C"}
        })
    """

    def check_transitions(seq: Sequence) -> bool:
        if len(seq) <= 1:
            return True

        previous = None  # last usable value; None means none seen yet
        for obj in seq:
            try:
                value = obj.properties[property_name]
                # An unhashable value cannot be looked up in the transition
                # table. Letting it become `previous` would make every later
                # lookup raise TypeError and silently disable the remaining
                # checks, so skip it like a missing value.
                hash(value)
            except (KeyError, TypeError):
                continue
            if value is None:
                continue

            if previous is not None and previous in valid_transitions:
                try:
                    if value not in valid_transitions[previous]:
                        return False
                except TypeError:
                    # The allowed-values entry is not a container we can test
                    # against; skip this object and keep `previous` unchanged.
                    continue
            previous = value

        return True

    return DSLRule(
        check_transitions, f"transitions between {property_name} values must be valid"
    )
def create_running_stat_rule(
    property_name: str,
    stat_func: Callable[[List[float]], float],
    min_value: float,
    max_value: float,
    window: int,
) -> DSLRule:
    """
    Build a rule bounding a statistic computed over every sliding window.

    A window is skipped in either of these cases:
    - any of its objects lacks the property or holds a value that ``float()``
      rejects;
    - ``stat_func`` raises ValueError or ZeroDivisionError.

    Sequences shorter than ``window`` pass.

    Args:
        property_name: Property holding the numeric samples.
        stat_func: Reduces a window of floats to one number.
        min_value: Inclusive lower bound on the statistic.
        max_value: Inclusive upper bound on the statistic.
        window: Number of consecutive objects per window.

    Example:
        # Moving average of temperatures must be between 20-30
        moving_avg = create_running_stat_rule(
            "temp", lambda x: sum(x)/len(x), 20, 30, window=3
        )
    """

    def check_stat(seq: Sequence) -> bool:
        if len(seq) < window:
            return True

        def numeric_window(start: int) -> Optional[List[float]]:
            """Floats for the window at ``start``, or None if any sample is unusable."""
            numbers = []
            for obj in seq[start : start + window]:
                try:
                    numbers.append(float(obj.properties[property_name]))
                except (ValueError, TypeError, KeyError):
                    return None
            return numbers

        for start in range(len(seq) - window + 1):
            numbers = numeric_window(start)
            if numbers is None or len(numbers) < window:
                continue
            try:
                stat = stat_func(numbers)
                if not (min_value <= stat <= max_value):
                    return False
            except (ValueError, ZeroDivisionError):
                continue

        # Every window was either skipped or within bounds.
        return True

    return DSLRule(
        check_stat, f"running statistic must be between {min_value} and {max_value}"
    )
def create_unique_property_rule(property_name: str, scope: str = "global") -> DSLRule:
    """
    Build a rule forbidding repeated property values within a scope.

    Scopes:
    - ``"global"``: all values in the sequence must be distinct. They are
      compared through a ``set``, so they must be hashable.
    - ``"adjacent"``: only neighbouring objects must differ.
    - Any other scope passes unconditionally, as does an empty sequence.

    Args:
        property_name: Property whose values must be unique.
        scope: ``"global"`` or ``"adjacent"``.

    Raises:
        KeyError: When an object examined by the chosen scope lacks the property.

    Example:
        # No duplicate IDs globally
        unique_ids = create_unique_property_rule("id", scope="global")
        # No adjacent duplicate values
        no_adjacent = create_unique_property_rule("value", scope="adjacent")
    """

    def check_unique(seq: Sequence) -> bool:
        if not seq:
            return True

        def lookup(obj: AbstractObject) -> Any:
            """Return the property value, raising KeyError when it is absent."""
            if property_name not in obj.properties:
                raise KeyError(property_name)
            return obj.properties[property_name]

        if scope == "global":
            values = [lookup(obj) for obj in seq]
            return len(values) == len(set(values))

        if scope == "adjacent":
            for left, right in zip(seq, seq[1:]):
                if lookup(left) == lookup(right):
                    return False
            return True

        return True

    return DSLRule(
        check_unique, f"{property_name} values must be unique within {scope} scope"
    )
def create_property_trend_rule(
    property_name: str, trend: str = "increasing"
) -> DSLRule:
    """
    Build a rule requiring numeric property values to follow a trend.

    Only values that are present, non-``None`` and accepted by ``float()`` take
    part. The trend is checked between consecutive participating values.
    Supported trends are ``"increasing"``, ``"decreasing"`` (both strict),
    ``"non-increasing"`` and ``"non-decreasing"``. The rule passes in these
    cases:
    - the trend is unrecognised;
    - fewer than two values participate.

    Args:
        property_name: Property holding the numeric values.
        trend: Name of the required ordering between neighbours.

    Example:
        # Values must strictly increase
        increasing = create_property_trend_rule("value", "increasing")
        # Values must be non-increasing
        non_increasing = create_property_trend_rule("value", "non-increasing")
    """

    def check_trend(seq: Sequence) -> bool:
        if len(seq) <= 1:
            return True

        numbers = []
        for obj in seq:
            try:
                if property_name not in obj.properties:
                    continue
                raw = obj.properties[property_name]
                if raw is None:
                    continue
                numbers.append(float(raw))
            except Exception:  # unreadable or non-numeric values are skipped
                continue

        if len(numbers) <= 1:
            return True

        holds = {
            "increasing": lambda a, b: a < b,
            "decreasing": lambda a, b: a > b,
            "non-increasing": lambda a, b: a >= b,
            "non-decreasing": lambda a, b: b >= a,
        }.get(trend)
        if holds is None:
            return True  # unknown trend names impose no constraint

        return all(holds(a, b) for a, b in zip(numbers, numbers[1:]))

    return DSLRule(check_trend, f"{property_name} values must be {trend}")
def create_balanced_rule(
    property_name: str, groups: Dict[Any, Set[Any]], tolerance: float = 0.1
) -> DSLRule:
    """
    Build a rule requiring property-value groups to appear in similar numbers.

    Each object's value is counted once for every group whose member set
    contains it. Objects that are missing the property, are ``None``, or raise
    TypeError during the membership test are skipped. Every group's count must
    lie within ``max(mean * tolerance, 1)`` of the mean count. The rule passes
    in these cases:
    - the sequence is empty;
    - ``groups`` is empty;
    - nothing was counted.

    Args:
        property_name: Property whose values are grouped.
        groups: Maps a group name to the set of values belonging to it.
        tolerance: Allowed deviation as a fraction of the mean count.

    Example:
        # Equal number of red and black cards (±10%)
        balanced_colors = create_balanced_rule("color", {
            "red": {"red"}, "black": {"black"}
        })
    """

    def check_balance(seq: Sequence) -> bool:
        if not seq:
            return True

        tally = dict.fromkeys(groups, 0)
        for obj in seq:
            try:
                value = obj.properties[property_name]
                if value is None:
                    continue
                # Increment as we go: a TypeError part-way keeps earlier increments.
                for name, members in groups.items():
                    if value in members:
                        tally[name] += 1
            except (KeyError, TypeError):
                continue  # Skip missing or invalid properties

        if not tally or not any(tally.values()):
            return True  # No valid values found

        mean = sum(tally.values()) / len(tally)
        allowance = max(mean * tolerance, 1)  # Allow at least 1 deviation
        return all(abs(count - mean) <= allowance for count in tally.values())

    return DSLRule(check_balance, f"{property_name} groups must be balanced")