Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/patsy/constraint.py : 15%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of Patsy
2# Copyright (C) 2011-2012 Nathaniel Smith <njs@pobox.com>
3# See file LICENSE.txt for license information.
5# Interpreting linear constraints like "2*x1 + x2 = 0"
7from __future__ import print_function
9# These are made available in the patsy.* namespace
10__all__ = ["LinearConstraint"]
12import re
13from collections import Mapping
14import six
15import numpy as np
16from patsy import PatsyError
17from patsy.origin import Origin
18from patsy.util import (atleast_2d_column_default,
19 repr_pretty_delegate, repr_pretty_impl,
20 no_pickling, assert_no_pickling)
21from patsy.infix_parser import Token, Operator, infix_parse
22from patsy.parse_formula import _parsing_error_test
class LinearConstraint(object):
    """A linear constraint in matrix form.

    This object represents a linear constraint of the form `Ax = b`.

    Usually you won't be constructing these by hand, but instead get them as
    the return value from :meth:`DesignInfo.linear_constraint`.

    .. attribute:: coefs

       A 2-dimensional ndarray with float dtype, representing `A`.

    .. attribute:: constants

       A 2-dimensional single-column ndarray with float dtype, representing
       `b`.

    .. attribute:: variable_names

       A list of strings giving the names of the variables being
       constrained. (Used only for consistency checking.)
    """
    def __init__(self, variable_names, coefs, constants=None):
        """Validate and store the constraint.

        :arg variable_names: Any iterable of names; it is copied into a list.
        :arg coefs: Array-like, converted to a 2-d float array with one
          column per variable name.
        :arg constants: Array-like with one entry per row of `coefs`, or
          None for all zeros.
        :raises ValueError: If the shapes of `coefs`, `constants` and
          `variable_names` are inconsistent, or `coefs` has no rows.
        """
        self.variable_names = list(variable_names)
        self.coefs = np.atleast_2d(np.asarray(coefs, dtype=float))
        if constants is None:
            constants = np.zeros(self.coefs.shape[0], dtype=float)
        constants = np.asarray(constants, dtype=float)
        self.constants = atleast_2d_column_default(constants)
        if self.constants.ndim != 2 or self.constants.shape[1] != 1:
            raise ValueError("constants is not (convertible to) a column matrix")
        # Measure the stored list rather than the raw argument, so that
        # one-shot iterables (which have no len()) are accepted too.
        num_variables = len(self.variable_names)
        if self.coefs.ndim != 2 or self.coefs.shape[1] != num_variables:
            raise ValueError("wrong shape for coefs")
        if self.coefs.shape[0] == 0:
            raise ValueError("must have at least one row in constraint matrix")
        if self.coefs.shape[0] != self.constants.shape[0]:
            raise ValueError("shape mismatch between coefs and constants")

    __repr__ = repr_pretty_delegate

    def _repr_pretty_(self, p, cycle):
        assert not cycle
        return repr_pretty_impl(p, self,
                                [self.variable_names, self.coefs, self.constants])

    __getstate__ = no_pickling

    @classmethod
    def combine(cls, constraints):
        """Create a new LinearConstraint by ANDing together several existing
        LinearConstraints.

        :arg constraints: An iterable of LinearConstraint objects. Their
          :attr:`variable_names` attributes must all match.
        :returns: A new LinearConstraint object.
        :raises ValueError: If `constraints` is empty or the variable names
          of its members differ.
        """
        # Materialize first: the emptiness test and the [0] lookup below
        # would otherwise fail for generators and other one-shot iterables.
        constraints = list(constraints)
        if not constraints:
            raise ValueError("no constraints specified")
        variable_names = constraints[0].variable_names
        for constraint in constraints:
            if constraint.variable_names != variable_names:
                raise ValueError("variable names don't match")
        # np.vstack is the canonical spelling; np.row_stack is a deprecated
        # alias for it.
        coefs = np.vstack([c.coefs for c in constraints])
        constants = np.vstack([c.constants for c in constraints])
        return cls(variable_names, coefs, constants)
def test_LinearConstraint():
    """Check LinearConstraint construction, dtype coercion and validation."""
    try:
        from numpy.testing import assert_equal
    except ImportError:
        from numpy.testing.utils import assert_equal

    # A flat coefficient list becomes a single row; constants default to 0.
    single_row = LinearConstraint(["foo", "bar"], [1, 1])
    assert single_row.variable_names == ["foo", "bar"]
    assert_equal(single_row.coefs, [[1, 1]])
    assert_equal(single_row.constants, [[0]])

    # Flat constants are turned into a column.
    two_rows = LinearConstraint(["foo", "bar"], [[1, 1], [2, 3]], [10, 20])
    assert_equal(two_rows.coefs, [[1, 1], [2, 3]])
    assert_equal(two_rows.constants, [[10], [20]])

    float_dtype = np.dtype(float)
    assert two_rows.coefs.dtype == float_dtype
    assert two_rows.constants.dtype == float_dtype

    # statsmodels wants to be able to create degenerate constraints like this,
    # see:
    #   https://github.com/pydata/patsy/issues/89
    # We used to forbid it, but I guess it's harmless, so why not.
    degenerate = LinearConstraint(["a"], [[0]])
    assert_equal(degenerate.coefs, [[0]])

    from nose.tools import assert_raises
    # Each tuple holds constructor arguments that must be rejected.
    invalid_argument_sets = [
        (["a"], [[1, 2]]),
        (["a"], [[[1]]]),
        (["a"], [[1, 2]], [3, 4]),
        (["a", "b"], [[1, 2]], [3, 4]),
        (["a"], [[1]], [[]]),
        (["a", "b"], []),
        (["a", "b"], np.zeros((0, 2))),
    ]
    for invalid_args in invalid_argument_sets:
        assert_raises(ValueError, LinearConstraint, *invalid_args)

    assert_no_pickling(degenerate)
def test_LinearConstraint_combine():
    """Check that combine() stacks constraints and rejects bad input."""
    first = LinearConstraint(["a", "b"], [1, 0])
    second = LinearConstraint(["a", "b"], [0, 1], [1])
    stacked = LinearConstraint.combine([first, second])
    assert stacked.variable_names == ["a", "b"]
    try:
        from numpy.testing import assert_equal
    except ImportError:
        from numpy.testing.utils import assert_equal
    assert_equal(stacked.coefs, [[1, 0], [0, 1]])
    assert_equal(stacked.constants, [[0], [1]])

    from nose.tools import assert_raises
    # Nothing to combine.
    assert_raises(ValueError, LinearConstraint.combine, [])
    # Constraints over different variables cannot be combined.
    mismatched = [LinearConstraint(["a"], [1]), LinearConstraint(["b"], [1])]
    assert_raises(ValueError, LinearConstraint.combine, mismatched)
# Operator table for the constraint mini-language, passed to infix_parse.
# Each row is (token, arity, third Operator argument).  The third field is
# presumably the binding precedence, higher binding tighter -- confirm
# against patsy.infix_parser.Operator.
_ops = [
    Operator(*_spec) for _spec in (
        (",", 2, -100),

        ("=", 2, 0),

        ("+", 1, 100),
        ("-", 1, 100),
        ("+", 2, 100),
        ("-", 2, 100),

        ("*", 2, 200),
        ("/", 2, 200),
    )
]

# Token types that infix_parse is told are atoms (operands).
_atomic = ["NUMBER", "VARIABLE"]
160def _token_maker(type, string):
161 def make_token(scanner, token_string):
162 if type == "__OP__":
163 actual_type = token_string
164 else:
165 actual_type = type
166 return Token(actual_type,
167 Origin(string, *scanner.match.span()),
168 token_string)
169 return make_token
def _tokenize_constraint(string, variable_names):
    """Split a constraint string into a list of Token objects.

    Recognizes parentheses, the operators in `_ops`, the given variable
    names and numeric literals; whitespace is skipped.

    :raises PatsyError: If some part of `string` matches none of these; the
      error's origin is the first unmatched character.
    """
    # Prefer long matches: trying longer names first keeps a name that is a
    # prefix of another from shadowing it.
    names_longest_first = sorted(variable_names, key=len, reverse=True)
    operator_pattern = "|".join(re.escape(op.token_type) for op in _ops)
    variable_pattern = "|".join(re.escape(name)
                                for name in names_longest_first)

    # Order matters: the scanner tries the patterns in sequence.
    lexicon = [
        (r"\(", _token_maker(Token.LPAREN, string)),
        (r"\)", _token_maker(Token.RPAREN, string)),
        (operator_pattern, _token_maker("__OP__", string)),
        (variable_pattern, _token_maker("VARIABLE", string)),
        (r"[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?",
         _token_maker("NUMBER", string)),
        (r"\s+", None),
    ]

    tokens, leftover = re.Scanner(lexicon).scan(string)
    if not leftover:
        return tokens
    # Scanning stopped early; report the first character it could not match.
    offset = len(string) - len(leftover)
    raise PatsyError("unrecognized token in constraint",
                     Origin(string, offset, offset + 1))
def test__tokenize_constraint():
    """Check token types, origins and text produced by the tokenizer."""
    code = "2 * (a + b) = q"
    tokens = _tokenize_constraint(code, ["a", "b", "q"])
    # (type, start offset, end offset, matched text)
    expected_tokens = [
        ("NUMBER", 0, 1, "2"),
        ("*", 2, 3, "*"),
        (Token.LPAREN, 4, 5, "("),
        ("VARIABLE", 5, 6, "a"),
        ("+", 7, 8, "+"),
        ("VARIABLE", 9, 10, "b"),
        (Token.RPAREN, 10, 11, ")"),
        ("=", 12, 13, "="),
        ("VARIABLE", 14, 15, "q"),
    ]
    for token, (kind, start, end, text) in zip(tokens, expected_tokens):
        assert isinstance(token, Token)
        assert token.type == kind
        assert token.origin == Origin(code, start, end)
        assert token.extra == text

    from nose.tools import assert_raises
    assert_raises(PatsyError, _tokenize_constraint, "1 + @b", ["b"])
    # Shouldn't raise an error:
    _tokenize_constraint("1 + @b", ["@b"])

    # Check we aren't confused by names which are proper prefixes of other
    # names:
    for names in (["a", "aa"], ["aa", "a"]):
        prefix_tokens = _tokenize_constraint("a aa a", names)
        assert len(prefix_tokens) == 3
        assert [tok.extra for tok in prefix_tokens] == ["a", "aa", "a"]

    # Check that embedding ops and numbers inside a variable name works
    embedded = _tokenize_constraint("2 * a[1,1],", ["a[1,1]"])
    assert len(embedded) == 4
    assert [tok.type for tok in embedded] == ["NUMBER", "*", "VARIABLE", ","]
    assert [tok.extra for tok in embedded] == ["2", "*", "a[1,1]", ","]
def parse_constraint(string, variable_names):
    """Tokenize `string` and hand the tokens to infix_parse.

    Parsing uses this module's operator table (`_ops`) and atomic token
    types (`_atomic`); the result of infix_parse is returned unchanged.
    """
    tokens = _tokenize_constraint(string, variable_names)
    return infix_parse(tokens, _ops, _atomic)
240class _EvalConstraint(object):
241 def __init__(self, variable_names):
242 self._variable_names = variable_names
243 self._N = len(variable_names)
245 self._dispatch = {
246 ("VARIABLE", 0): self._eval_variable,
247 ("NUMBER", 0): self._eval_number,
248 ("+", 1): self._eval_unary_plus,
249 ("-", 1): self._eval_unary_minus,
250 ("+", 2): self._eval_binary_plus,
251 ("-", 2): self._eval_binary_minus,
252 ("*", 2): self._eval_binary_multiply,
253 ("/", 2): self._eval_binary_div,
254 ("=", 2): self._eval_binary_eq,
255 (",", 2): self._eval_binary_comma,
256 }
258 # General scheme: there are 2 types we deal with:
259 # - linear combinations ("lincomb"s) of variables and constants,
260 # represented as ndarrays with size N+1
261 # The last entry is the constant, so [10, 20, 30] means 10x + 20y +
262 # 30.
263 # - LinearConstraint objects
265 def is_constant(self, coefs):
266 return np.all(coefs[:self._N] == 0)
268 def _eval_variable(self, tree):
269 var = tree.token.extra
270 coefs = np.zeros((self._N + 1,), dtype=float)
271 coefs[self._variable_names.index(var)] = 1
272 return coefs
274 def _eval_number(self, tree):
275 coefs = np.zeros((self._N + 1,), dtype=float)
276 coefs[-1] = float(tree.token.extra)
277 return coefs
279 def _eval_unary_plus(self, tree):
280 return self.eval(tree.args[0])
282 def _eval_unary_minus(self, tree):
283 return -1 * self.eval(tree.args[0])
285 def _eval_binary_plus(self, tree):
286 return self.eval(tree.args[0]) + self.eval(tree.args[1])
288 def _eval_binary_minus(self, tree):
289 return self.eval(tree.args[0]) - self.eval(tree.args[1])
291 def _eval_binary_div(self, tree):
292 left = self.eval(tree.args[0])
293 right = self.eval(tree.args[1])
294 if not self.is_constant(right):
295 raise PatsyError("Can't divide by a variable in a linear "
296 "constraint", tree.args[1])
297 return left / right[-1]
299 def _eval_binary_multiply(self, tree):
300 left = self.eval(tree.args[0])
301 right = self.eval(tree.args[1])
302 if self.is_constant(left):
303 return left[-1] * right
304 elif self.is_constant(right):
305 return left * right[-1]
306 else:
307 raise PatsyError("Can't multiply one variable by another "
308 "in a linear constraint", tree)
310 def _eval_binary_eq(self, tree):
311 # Handle "a1 = a2 = a3", which is parsed as "(a1 = a2) = a3"
312 args = list(tree.args)
313 constraints = []
314 for i, arg in enumerate(args):
315 if arg.type == "=":
316 constraints.append(self.eval(arg, constraint=True))
317 # make our left argument be their right argument, or
318 # vice-versa
319 args[i] = arg.args[1 - i]
320 left = self.eval(args[0])
321 right = self.eval(args[1])
322 coefs = left[:self._N] - right[:self._N]
323 if np.all(coefs == 0):
324 raise PatsyError("no variables appear in constraint", tree)
325 constant = -left[-1] + right[-1]
326 constraint = LinearConstraint(self._variable_names, coefs, constant)
327 constraints.append(constraint)
328 return LinearConstraint.combine(constraints)
330 def _eval_binary_comma(self, tree):
331 left = self.eval(tree.args[0], constraint=True)
332 right = self.eval(tree.args[1], constraint=True)
333 return LinearConstraint.combine([left, right])
335 def eval(self, tree, constraint=False):
336 key = (tree.type, len(tree.args))
337 assert key in self._dispatch
338 val = self._dispatch[key](tree)
339 if constraint:
340 # Force it to be a constraint
341 if isinstance(val, LinearConstraint):
342 return val
343 else:
344 assert val.size == self._N + 1
345 if np.all(val[:self._N] == 0):
346 raise PatsyError("term is constant, with no variables",
347 tree)
348 return LinearConstraint(self._variable_names,
349 val[:self._N],
350 -val[-1])
351 else:
352 # Force it to *not* be a constraint
353 if isinstance(val, LinearConstraint):
354 raise PatsyError("unexpected constraint object", tree)
355 return val
def linear_constraint(constraint_like, variable_names):
    """This is the internal interface implementing
    DesignInfo.linear_constraint, see there for docs.

    `constraint_like` may be:

    * a LinearConstraint, returned as-is after checking its variable names;
    * a mapping from variable name or integer index to a value, giving one
      equality constraint per entry;
    * a string, or a list of strings, in the constraint mini-language;
    * a ``(coefs, constants)`` tuple;
    * anything else, which is converted to a float array of coefficients
      with all-zero constants.

    :raises ValueError: For mismatched variable names, unknown or duplicated
      mapping keys, non-string entries in a list of strings, tuples whose
      length is not 2, or shapes that LinearConstraint rejects.
    """
    # Normalize to a list: LinearConstraint.variable_names is always a list,
    # so comparing it against e.g. a tuple would never be equal.
    variable_names = list(variable_names)

    if isinstance(constraint_like, LinearConstraint):
        if constraint_like.variable_names != variable_names:
            raise ValueError("LinearConstraint has wrong variable_names "
                             "(got %r, expected %r)"
                             % (constraint_like.variable_names,
                                variable_names))
        return constraint_like

    if isinstance(constraint_like, Mapping):
        # Simple conjunction-of-equality constraints can be specified as
        # dicts. {"x": 1, "y": 2} -> tests x = 1 and y = 2. Keys can be
        # either variable names, or variable indices.
        num_variables = len(variable_names)
        coefs = np.zeros((len(constraint_like), num_variables),
                         dtype=float)
        constants = np.zeros(len(constraint_like))
        used = set()
        for i, (name, value) in enumerate(six.iteritems(constraint_like)):
            if name in variable_names:
                idx = variable_names.index(name)
            elif isinstance(name, six.integer_types):
                idx = name
                # Map valid negative indices onto their positive
                # equivalents so that the duplicate check below also
                # catches e.g. {"b": 1, -1: 2}.  Out-of-range indices are
                # left alone and fail at the assignment below, as before.
                if -num_variables <= idx < 0:
                    idx += num_variables
            else:
                raise ValueError("unrecognized variable name/index %r"
                                 % (name,))
            if idx in used:
                raise ValueError("duplicated constraint on %r"
                                 % (variable_names[idx],))
            used.add(idx)
            coefs[i, idx] = 1
            constants[i] = value
        return LinearConstraint(variable_names, coefs, constants)

    if isinstance(constraint_like, str):
        constraint_like = [constraint_like]
        # fall-through

    if (isinstance(constraint_like, list)
            and constraint_like
            and isinstance(constraint_like[0], str)):
        constraints = []
        for code in constraint_like:
            if not isinstance(code, str):
                raise ValueError("expected a string, not %r" % (code,))
            tree = parse_constraint(code, variable_names)
            evaluator = _EvalConstraint(variable_names)
            constraints.append(evaluator.eval(tree, constraint=True))
        return LinearConstraint.combine(constraints)

    if isinstance(constraint_like, tuple):
        if len(constraint_like) != 2:
            raise ValueError("constraint tuple must have length 2")
        coef, constants = constraint_like
        return LinearConstraint(variable_names, coef, constants)

    # assume a raw ndarray
    coefs = np.asarray(constraint_like, dtype=float)
    return LinearConstraint(variable_names, coefs)
def _check_lincon(input, varnames, coefs, constants):
    """Assert that linear_constraint(input, varnames) gives the expected result.

    The expected value is ``LinearConstraint(varnames, coefs, constants)``.
    Both objects are printed before their names, matrices and dtypes are
    compared.
    """
    try:
        from numpy.testing import assert_equal
    except ImportError:
        from numpy.testing.utils import assert_equal
    actual = linear_constraint(input, varnames)
    print("got", actual)
    reference = LinearConstraint(varnames, coefs, constants)
    print("expected", reference)
    for attribute in ("variable_names", "coefs", "constants"):
        assert_equal(getattr(actual, attribute),
                     getattr(reference, attribute))
    # Both matrices must be float, whatever dtype the input had.
    float_dtype = np.dtype(float)
    for matrix in (actual.coefs, actual.constants):
        assert_equal(matrix.dtype, float_dtype)
def test_linear_constraint():
    """Exercise every input form accepted by linear_constraint()."""
    from nose.tools import assert_raises
    from patsy.compat import OrderedDict
    check = _check_lincon

    def rejects(constraint_like, names):
        assert_raises(ValueError, linear_constraint, constraint_like, names)

    ab = ["a", "b"]

    # Existing LinearConstraint objects: names must match exactly.
    check(LinearConstraint(["a", "b"], [2, 3]), ab, [[2, 3]], [[0]])
    rejects(LinearConstraint(["b", "a"], [2, 3]), ab)

    # Mappings keyed by variable name.
    check({"a": 2}, ab, [[1, 0]], [[2]])
    check(OrderedDict([("a", 2), ("b", 3)]),
          ab, [[1, 0], [0, 1]], [[2], [3]])
    check(OrderedDict([("a", 2), ("b", 3)]),
          ["b", "a"], [[0, 1], [1, 0]], [[2], [3]])

    # Mappings keyed by variable index.
    check({0: 2}, ab, [[1, 0]], [[2]])
    check(OrderedDict([(0, 2), (1, 3)]), ab, [[1, 0], [0, 1]], [[2], [3]])

    # Names and indices mixed in one mapping.
    check(OrderedDict([("a", 2), (1, 3)]),
          ab, [[1, 0], [0, 1]], [[2], [3]])

    # Unknown key, and the same variable constrained twice.
    rejects({"q": 1}, ab)
    rejects({"a": 1, 0: 2}, ab)

    check(np.array([2, 3]), ab, [[2, 3]], [[0]])
    check(np.array([[2, 3], [4, 5]]), ab, [[2, 3], [4, 5]], [[0], [0]])

    # Constraint strings: (code, expected coefs, expected constants).
    simple_string_cases = [
        ("a = 2", [[1, 0]], [[2]]),
        ("a - 2", [[1, 0]], [[2]]),
        ("a + 1 = 3", [[1, 0]], [[2]]),
        ("a + b = 3", [[1, 1]], [[3]]),
        ("a = 2, b = 3", [[1, 0], [0, 1]], [[2], [3]]),
        ("b = 3, a = 2", [[0, 1], [1, 0]], [[3], [2]]),
    ]
    for code, want_coefs, want_constants in simple_string_cases:
        check(code, ab, want_coefs, want_constants)

    check(["a = 2", "b = 3"], ab, [[1, 0], [0, 1]], [[2], [3]])

    rejects(["a", {"b": 0}], ab)

    # Actual evaluator tests
    evaluator_cases = [
        ("2 * (a + b/3) + b + 2*3/4 = 1 + 2*3",
         [[2, 2.0/3 + 1]], [[7 - 6.0/4]]),
        ("+2 * -a", [[-2, 0]], [[0]]),
        ("a - b, a + b = 2", [[1, -1], [1, 1]], [[0], [2]]),
        ("a = 1, a = 2, a = 3",
         [[1, 0], [1, 0], [1, 0]], [[1], [2], [3]]),
        ("a * 2", [[2, 0]], [[0]]),
        ("-a = 1", [[-1, 0]], [[1]]),
        ("(2 + a - a) * b", [[0, 2]], [[0]]),
    ]
    for code, want_coefs, want_constants in evaluator_cases:
        check(code, ab, want_coefs, want_constants)

    # Chained equalities.
    check("a = 1 = b", ab, [[1, 0], [0, -1]], [[1], [-1]])
    check("a = (1 = b)", ab, [[0, -1], [1, 0]], [[-1], [1]])
    check("a = 1, a = b = c", ["a", "b", "c"],
          [[1, 0, 0], [1, -1, 0], [0, 1, -1]], [[1], [0], [0]])

    # One should never do this of course, but test that it works anyway...
    check("a + 1 = 2", ["a", "a + 1"], [[0, 1]], [[2]])

    # (coefs, constants) tuples.
    check(([10, 20], [30]), ab, [[10, 20]], [[30]])
    check(([[10, 20], [20, 40]], [[30], [35]]), ab,
          [[10, 20], [20, 40]], [[30], [35]])
    # wrong-length tuple
    rejects(([1, 0], [0], [0]), ab)
    rejects(([1, 0],), ab)

    # Bare coefficient lists and arrays.
    check([10, 20], ab, [[10, 20]], [[0]])
    check([[10, 20], [20, 40]], ab, [[10, 20], [20, 40]], [[0], [0]])
    check(np.array([10, 20]), ab, [[10, 20]], [[0]])
    check(np.array([[10, 20], [20, 40]]), ab,
          [[10, 20], [20, 40]], [[0], [0]])

    # unknown object type
    rejects(None, ab)
511_parse_eval_error_tests = [
512 # Bad token
513 "a + <f>oo",
514 # No pure constant equalities
515 "a = 1, <1 = 1>, b = 1",
516 "a = 1, <b * 2 - b + (-2/2 * b)>",
517 "a = 1, <1>, b = 2",
518 "a = 1, <2 * b = b + b>, c",
519 # No non-linearities
520 "a + <a * b> + c",
521 "a + 2 / <b> + c",
522 # Constraints are not numbers
523 "a = 1, 2 * <(a = b)>, c",
524 "a = 1, a + <(a = b)>, c",
525 "a = 1, <(a, b)> + 2, c",
526]
def test_eval_errors():
    """Run every entry of _parse_eval_error_tests through linear_constraint."""
    names = ["a", "b", "c"]

    def evaluate(bad_code):
        return linear_constraint(bad_code, names)

    _parsing_error_test(evaluate, _parse_eval_error_tests)