Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/fields/related_lookups.py: 49%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 11:16 -0600

1from plain.models.lookups import ( 

2 Exact, 

3 GreaterThan, 

4 GreaterThanOrEqual, 

5 In, 

6 IsNull, 

7 LessThan, 

8 LessThanOrEqual, 

9) 

10 

11 

class MultiColSource:
    """Left-hand side object used by the multicolumn branches of the
    related lookups in this module.

    It stores the table alias, the target fields, the source fields and the
    relation field, and exposes the relation field as ``output_field``.
    """

    # Both flags are always False for this source.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, alias, targets, sources, field):
        self.targets = targets
        self.sources = sources
        self.field = field
        self.alias = alias
        # Lookups are resolved through the relation field itself.
        self.output_field = self.field

    def __repr__(self):
        return f"{type(self).__name__}({self.alias}, {self.field})"

    def relabeled_clone(self, relabels):
        """Return a new instance whose alias is mapped through ``relabels``.

        An alias missing from ``relabels`` is kept as-is; targets, sources
        and field are passed to the clone unchanged.
        """
        new_alias = relabels.get(self.alias, self.alias)
        return type(self)(new_alias, self.targets, self.sources, self.field)

    def get_lookup(self, lookup):
        """Delegate lookup resolution to ``output_field``."""
        return self.output_field.get_lookup(lookup)

    def resolve_expression(self, *args, **kwargs):
        """Ignore all arguments and return this object unchanged."""
        return self

38 

39 

def get_normalized_value(value, lhs):
    """Return ``value`` as a tuple for comparison against ``lhs``.

    * A tuple is returned unchanged.
    * Any other non-model value is wrapped in a 1-tuple.
    * A model instance yields one attribute value per target field of the
      last entry in ``lhs.output_field.path_infos``; if an attribute cannot
      be read, ``(value.pk,)`` is returned instead.

    Raises:
        ValueError: if ``value`` is a model instance whose ``pk`` is None.
    """
    from plain.models import Model

    if not isinstance(value, Model):
        return value if isinstance(value, tuple) else (value,)

    if value.pk is None:
        raise ValueError("Model instances passed to related filters must be saved.")

    collected = []
    for field in lhs.output_field.path_infos[-1].target_fields:
        # Follow remote_field links until we reach a field whose model the
        # instance belongs to, or a field with no remote_field.
        while not isinstance(value, field.model) and field.remote_field:
            remote = field.remote_field
            field = remote.model._meta.get_field(remote.field_name)
        try:
            attr_value = getattr(value, field.attname)
        except AttributeError:
            # A case like Restaurant.objects.filter(place=restaurant_instance),
            # where place is a OneToOneField and the primary key of Restaurant.
            return (value.pk,)
        collected.append(attr_value)
    return tuple(collected)

63 

64 

class RelatedIn(In):
    """``In`` lookup whose left-hand side is a relation.

    Single-column relations have their right-hand side normalized and
    prepared here; a ``MultiColSource`` left-hand side is compiled into an
    explicit multicolumn WHERE clause in ``as_sql``.
    """

    def get_prep_lookup(self):
        """Normalize ``self.rhs`` for single-column relations, then defer to
        the parent implementation."""
        if isinstance(self.lhs, MultiColSource):
            return super().get_prep_lookup()

        if self.rhs_is_direct_value():
            # Single-column relation: keep only the first component of each
            # normalized value.
            self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
            # The related field's get_prep_value() must run. Consider a
            # ForeignKey to an IntegerField given the value 'abc': the
            # ForeignKey itself does not validate non-integers, so the
            # target field has to do it.
            if hasattr(self.lhs.output_field, "path_infos"):
                # Only one target field can exist here, otherwise we would
                # not have reached the direct value branch.
                last_path = self.lhs.output_field.path_infos[-1]
                target_field = last_path.target_fields[-1]
                self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
        elif not (
            getattr(self.rhs, "has_select_fields", True)
            or getattr(self.lhs.field.target_field, "primary_key", False)
        ):
            if (
                getattr(self.lhs.output_field, "primary_key", False)
                and self.lhs.output_field.model == self.rhs.model
            ):
                # A case like
                # Restaurant.objects.filter(place__in=restaurant_qs), where
                # place is a OneToOneField and the primary key of
                # Restaurant.
                selected_name = self.lhs.field.name
            else:
                selected_name = self.lhs.field.target_field.name
            self.rhs.set_values([selected_name])
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        """Compile the lookup; multicolumn sources get a hand-built clause."""
        if not isinstance(self.lhs, MultiColSource):
            return super().as_sql(compiler, connection)

        # For multicolumn lookups we need to build a multicolumn where clause.
        # This clause is either a SubqueryConstraint (for values that need
        # to be compiled to SQL) or an OR-combined list of
        # (col1 = val1 AND col2 = val2 AND ...) clauses.
        from plain.models.sql.where import (
            AND,
            OR,
            SubqueryConstraint,
            WhereNode,
        )

        root_constraint = WhereNode(connector=OR)
        if self.rhs_is_direct_value():
            # Normalize every value up front, before any clause is built.
            normalized_rows = [
                get_normalized_value(raw, self.lhs) for raw in self.rhs
            ]
            for row in normalized_rows:
                row_constraint = WhereNode()
                columns = zip(self.lhs.sources, self.lhs.targets, row)
                for source, target, val in columns:
                    exact_lookup = target.get_lookup("exact")
                    column = target.get_col(self.lhs.alias, source)
                    row_constraint.add(exact_lookup(column, val), AND)
                root_constraint.add(row_constraint, OR)
        else:
            subquery = SubqueryConstraint(
                self.lhs.alias,
                [target.column for target in self.lhs.targets],
                [source.name for source in self.lhs.sources],
                self.rhs,
            )
            root_constraint.add(subquery, AND)
        return root_constraint.as_sql(compiler, connection)

139 

140 

class RelatedLookupMixin:
    """Mixin that adapts a lookup class to a relation left-hand side.

    It is combined with the concrete lookup classes further down in this
    module.
    """

    def get_prep_lookup(self):
        """Normalize a direct right-hand value for single-column relations,
        then defer to the next class in the MRO."""
        is_multi_column = isinstance(self.lhs, MultiColSource)
        if not (is_multi_column or hasattr(self.rhs, "resolve_expression")):
            # Single-column relation: keep only the first component of the
            # normalized value.
            self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
            # The related field's get_prep_value() must run. Consider a
            # ForeignKey to an IntegerField given the value 'abc': the
            # ForeignKey itself does not validate non-integers, so the
            # target field has to do it.
            if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
                # Only one target field can exist here, otherwise we would
                # not have reached the direct value branch.
                last_path = self.lhs.output_field.path_infos[-1]
                target_field = last_path.target_fields[-1]
                self.rhs = target_field.get_prep_value(self.rhs)

        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        """Compile the lookup; a multicolumn source becomes one
        AND-combined comparison per (target, source, value) triple."""
        if not isinstance(self.lhs, MultiColSource):
            return super().as_sql(compiler, connection)

        assert self.rhs_is_direct_value()
        self.rhs = get_normalized_value(self.rhs, self.lhs)
        from plain.models.sql.where import AND, WhereNode

        root_constraint = WhereNode()
        triples = zip(self.lhs.targets, self.lhs.sources, self.rhs)
        for target, source, val in triples:
            lookup_class = target.get_lookup(self.lookup_name)
            column = target.get_col(self.lhs.alias, source)
            root_constraint.add(lookup_class(column, val), AND)
        return root_constraint.as_sql(compiler, connection)

176 

177 

class RelatedExact(RelatedLookupMixin, Exact):
    """``Exact`` combined with ``RelatedLookupMixin``; defines nothing of its own."""

180 

181 

class RelatedLessThan(RelatedLookupMixin, LessThan):
    """``LessThan`` combined with ``RelatedLookupMixin``; defines nothing of its own."""

184 

185 

class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
    """``GreaterThan`` combined with ``RelatedLookupMixin``; defines nothing of its own."""

188 

189 

class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
    """``GreaterThanOrEqual`` combined with ``RelatedLookupMixin``; defines nothing of its own."""

192 

193 

class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
    """``LessThanOrEqual`` combined with ``RelatedLookupMixin``; defines nothing of its own."""

196 

197 

class RelatedIsNull(RelatedLookupMixin, IsNull):
    """``IsNull`` combined with ``RelatedLookupMixin``; defines nothing of its own."""