Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/patsy/util.py : 18%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of Patsy
2# Copyright (C) 2011-2013 Nathaniel Smith <njs@pobox.com>
3# See file LICENSE.txt for license information.
5# Some generic utilities.
# Public API of this module: array-conversion helpers, numeric widening,
# pretty-repr fallbacks, the mixed-type sort key, NaN checks, pandas
# feature-detection flags, and the pandas-categorical compatibility shims.
__all__ = ["atleast_2d_column_default", "uniqueify_list",
           "widest_float", "widest_complex", "wide_dtype_for", "widen",
           "repr_pretty_delegate", "repr_pretty_impl",
           "SortAnythingKey", "safe_scalar_isnan", "safe_isnan",
           "iterable",
           "have_pandas",
           "have_pandas_categorical",
           "have_pandas_categorical_dtype",
           "pandas_Categorical_from_codes",
           "pandas_Categorical_categories",
           "pandas_Categorical_codes",
           "safe_is_pandas_categorical_dtype",
           "safe_is_pandas_categorical",
           "safe_issubdtype",
           "no_pickling",
           "assert_no_pickling",
           "safe_string_eq",
           ]
26import sys
27import numpy as np
28import six
29from six.moves import cStringIO as StringIO
30from .compat import optional_dep_ok
# Detect pandas and the categorical features we rely on. Three flags are
# exported:
#   have_pandas -- pandas is importable at all
#   have_pandas_categorical -- pandas.Categorical exists (pandas >= 0.9.0;
#       drop this guard whenever we drop support for older versions)
#   have_pandas_categorical_dtype -- Series can carry dtype="category"
# _pandas_is_categorical_dtype holds the detection callable (or None).
try:
    import pandas
except ImportError:
    have_pandas = False
else:
    have_pandas = True

have_pandas_categorical = have_pandas and hasattr(pandas, "Categorical")

_pandas_is_categorical_dtype = None
if have_pandas:
    if hasattr(pandas, "api"):
        # pandas.api.types is available starting in pandas v0.19.0.
        _pandas_is_categorical_dtype = pandas.api.types.is_categorical_dtype
    else:
        # Fallback location for pandas v0.18.0 and earlier.
        _pandas_is_categorical_dtype = getattr(pandas.core.common,
                                               "is_categorical_dtype", None)
have_pandas_categorical_dtype = _pandas_is_categorical_dtype is not None
# Passes through Series and DataFrames, call np.asarray() on everything else
def asarray_or_pandas(a, copy=False, dtype=None, subok=False):
    """Convert `a` to an ndarray, but let pandas objects stay pandas.

    Series/DataFrame inputs are rebuilt through their own constructor (so
    `copy` and `dtype` are honored) with the Series ``.name`` explicitly
    carried over; anything else goes through ``np.array``.
    """
    if have_pandas and isinstance(a, (pandas.Series, pandas.DataFrame)):
        # The .name attribute on Series is discarded when passing through
        # the constructor:
        #   https://github.com/pydata/pandas/issues/1578
        extra = {}
        if hasattr(a, "name"):
            extra["name"] = a.name
        return type(a)(a, copy=copy, dtype=dtype, **extra)
    return np.array(a, copy=copy, dtype=dtype, subok=subok)
def test_asarray_or_pandas():
    """Exercise asarray_or_pandas's ndarray/pandas conversion semantics."""
    import warnings
    # Lists become plain ndarrays; np.matrix is downgraded unless subok=True.
    assert type(asarray_or_pandas([1, 2, 3])) is np.ndarray
    with warnings.catch_warnings() as w:
        warnings.filterwarnings('ignore', 'the matrix subclass',
                                PendingDeprecationWarning)
        assert type(asarray_or_pandas(np.matrix([[1, 2, 3]]))) is np.ndarray
        assert type(asarray_or_pandas(
            np.matrix([[1, 2, 3]]), subok=True)) is np.matrix
        # catch_warnings() without record=True yields None.
        assert w is None
    # Default copy=False passes an existing ndarray straight through;
    # copy=True detaches, so mutating the copy leaves the original alone.
    a = np.array([1, 2, 3])
    assert asarray_or_pandas(a) is a
    a_copy = asarray_or_pandas(a, copy=True)
    assert np.array_equal(a, a_copy)
    a_copy[0] = 100
    assert not np.array_equal(a, a_copy)
    assert np.allclose(asarray_or_pandas([1, 2, 3], dtype=float),
                       [1.0, 2.0, 3.0])
    assert asarray_or_pandas([1, 2, 3], dtype=float).dtype == np.dtype(float)
    # Converting to the same dtype yields a view sharing the original memory.
    a_view = asarray_or_pandas(a, dtype=a.dtype)
    a_view[0] = 99
    assert a[0] == 99
    global have_pandas
    if have_pandas:
        # Series: .name and index are preserved; same view/copy semantics.
        s = pandas.Series([1, 2, 3], name="A", index=[10, 20, 30])
        s_view1 = asarray_or_pandas(s)
        assert s_view1.name == "A"
        assert np.array_equal(s_view1.index, [10, 20, 30])
        s_view1[10] = 101
        assert s[10] == 101
        s_copy = asarray_or_pandas(s, copy=True)
        assert s_copy.name == "A"
        assert np.array_equal(s_copy.index, [10, 20, 30])
        assert np.array_equal(s_copy, s)
        s_copy[10] = 100
        assert not np.array_equal(s_copy, s)
        assert asarray_or_pandas(s, dtype=float).dtype == np.dtype(float)
        s_view2 = asarray_or_pandas(s, dtype=s.dtype)
        assert s_view2.name == "A"
        assert np.array_equal(s_view2.index, [10, 20, 30])
        s_view2[10] = 99
        assert s[10] == 99

        # DataFrame: columns and index survive; same view/copy semantics.
        df = pandas.DataFrame([[1, 2, 3]],
                              columns=["A", "B", "C"],
                              index=[10])
        df_view1 = asarray_or_pandas(df)
        df_view1.loc[10, "A"] = 101
        assert np.array_equal(df_view1.columns, ["A", "B", "C"])
        assert np.array_equal(df_view1.index, [10])
        assert df.loc[10, "A"] == 101
        df_copy = asarray_or_pandas(df, copy=True)
        assert np.array_equal(df_copy, df)
        assert np.array_equal(df_copy.columns, ["A", "B", "C"])
        assert np.array_equal(df_copy.index, [10])
        df_copy.loc[10, "A"] = 100
        assert not np.array_equal(df_copy, df)
        df_converted = asarray_or_pandas(df, dtype=float)
        assert df_converted["A"].dtype == np.dtype(float)
        assert np.allclose(df_converted, df)
        assert np.array_equal(df_converted.columns, ["A", "B", "C"])
        assert np.array_equal(df_converted.index, [10])
        df_view2 = asarray_or_pandas(df, dtype=df["A"].dtype)
        assert np.array_equal(df_view2.columns, ["A", "B", "C"])
        assert np.array_equal(df_view2.index, [10])
        # This actually makes a copy, not a view, because of a pandas bug:
        #   https://github.com/pydata/pandas/issues/1572
        assert np.array_equal(df, df_view2)
        # df_view2[0][0] = 99
        # assert df[0][0] == 99

        # With have_pandas forced off, pandas objects fall through to
        # np.array() and come back as plain ndarrays.
        had_pandas = have_pandas
        try:
            have_pandas = False
            assert (type(asarray_or_pandas(pandas.Series([1, 2, 3])))
                    is np.ndarray)
            assert (type(asarray_or_pandas(pandas.DataFrame([[1, 2, 3]])))
                    is np.ndarray)
        finally:
            have_pandas = had_pandas
# Like np.atleast_2d, but this converts lower-dimensional arrays into columns,
# instead of rows. It also converts ndarray subclasses into basic ndarrays,
# which makes it easier to guarantee correctness. However, there are many
# places in the code where we want to preserve pandas indexing information if
# present, so there is also an option
def atleast_2d_column_default(a, preserve_pandas=False):
    """Return `a` with at least 2 dimensions, treating 1-d input as a column.

    With preserve_pandas=True, a Series becomes a one-column DataFrame and a
    DataFrame is returned untouched; all other inputs fall through to the
    plain-ndarray path.
    """
    if preserve_pandas and have_pandas:
        if isinstance(a, pandas.Series):
            return pandas.DataFrame(a)
        if isinstance(a, pandas.DataFrame):
            return a
        # ...anything else falls through to the generic path below.
    result = np.atleast_1d(np.asarray(a))
    if result.ndim < 2:
        result = result.reshape((-1, 1))
    assert result.ndim >= 2
    return result
def test_atleast_2d_column_default():
    """Exercise atleast_2d_column_default's column-first reshaping rules."""
    import warnings
    # 1-d input becomes a single *column*, not a single row.
    assert np.all(atleast_2d_column_default([1, 2, 3]) == [[1], [2], [3]])

    assert atleast_2d_column_default(1).shape == (1, 1)
    assert atleast_2d_column_default([1]).shape == (1, 1)
    assert atleast_2d_column_default([[1]]).shape == (1, 1)
    assert atleast_2d_column_default([[[1]]]).shape == (1, 1, 1)

    assert atleast_2d_column_default([1, 2, 3]).shape == (3, 1)
    assert atleast_2d_column_default([[1], [2], [3]]).shape == (3, 1)

    # ndarray subclasses (np.matrix) are flattened to plain ndarrays.
    with warnings.catch_warnings() as w:
        warnings.filterwarnings('ignore', 'the matrix subclass',
                                PendingDeprecationWarning)
        assert type(atleast_2d_column_default(np.matrix(1))) == np.ndarray
        # catch_warnings() without record=True returns None.
        assert w is None

    global have_pandas
    if have_pandas:
        # Without preserve_pandas, pandas objects are converted to ndarrays;
        # with preserve_pandas=True they stay DataFrames, metadata intact.
        assert (type(atleast_2d_column_default(pandas.Series([1, 2])))
                == np.ndarray)
        assert (type(atleast_2d_column_default(pandas.DataFrame([[1], [2]])))
                == np.ndarray)
        assert (type(atleast_2d_column_default(pandas.Series([1, 2]),
                                               preserve_pandas=True))
                == pandas.DataFrame)
        assert (type(atleast_2d_column_default(pandas.DataFrame([[1], [2]]),
                                               preserve_pandas=True))
                == pandas.DataFrame)
        s = pandas.Series([10, 11, 12], name="hi", index=["a", "b", "c"])
        df = atleast_2d_column_default(s, preserve_pandas=True)
        assert isinstance(df, pandas.DataFrame)
        assert np.all(df.columns == ["hi"])
        assert np.all(df.index == ["a", "b", "c"])
    # preserve_pandas=True must not affect non-pandas input.
    with warnings.catch_warnings() as w:
        warnings.filterwarnings('ignore', 'the matrix subclass',
                                PendingDeprecationWarning)
        assert (type(atleast_2d_column_default(np.matrix(1),
                                               preserve_pandas=True))
                == np.ndarray)
        assert w is None
    assert (type(atleast_2d_column_default([1, 2, 3], preserve_pandas=True))
            == np.ndarray)

    if have_pandas:
        # With have_pandas forced off, preserve_pandas becomes a no-op.
        had_pandas = have_pandas
        try:
            have_pandas = False
            assert (type(atleast_2d_column_default(pandas.Series([1, 2]),
                                                   preserve_pandas=True))
                    == np.ndarray)
            assert (type(atleast_2d_column_default(pandas.DataFrame([[1], [2]]),
                                                   preserve_pandas=True))
                    == np.ndarray)
        finally:
            have_pandas = had_pandas
# A version of .reshape() that knows how to down-convert a 1-column
# pandas.DataFrame into a pandas.Series. Useful for code that wants to be
# agnostic between 1d and 2d data, with the pattern:
#   new_a = atleast_2d_column_default(a, preserve_pandas=True)
#   # do stuff to new_a, which can assume it's always 2 dimensional
#   return pandas_friendly_reshape(new_a, a.shape)
def pandas_friendly_reshape(a, new_shape):
    """Reshape `a`; DataFrames support only no-ops and column -> Series."""
    if not (have_pandas and isinstance(a, pandas.DataFrame)):
        return a.reshape(new_shape)
    # We have a DataFrame. The only supported reshapes are the identity and
    # collapsing a single-column frame into a Series.
    if new_shape == a.shape:
        return a
    if len(new_shape) == 1 and a.shape[1] == 1:
        if new_shape[0] != a.shape[0]:
            raise ValueError("arrays have incompatible sizes")
        return a[a.columns[0]]
    raise ValueError("cannot reshape a DataFrame with shape %s to shape %s"
                     % (a.shape, new_shape))
def test_pandas_friendly_reshape():
    """Exercise pandas_friendly_reshape on ndarrays and DataFrames."""
    from nose.tools import assert_raises
    global have_pandas
    # Plain ndarrays simply delegate to .reshape().
    assert np.allclose(pandas_friendly_reshape(np.arange(10).reshape(5, 2),
                                               (2, 5)),
                       np.arange(10).reshape(2, 5))
    if have_pandas:
        df = pandas.DataFrame({"x": [1, 2, 3]}, index=["a", "b", "c"])
        # No-op reshape returns the DataFrame with its metadata intact.
        noop = pandas_friendly_reshape(df, (3, 1))
        assert isinstance(noop, pandas.DataFrame)
        assert np.array_equal(noop.index, ["a", "b", "c"])
        assert np.array_equal(noop.columns, ["x"])
        # 1-column -> 1-d reshape produces a Series named after the column.
        squozen = pandas_friendly_reshape(df, (3,))
        assert isinstance(squozen, pandas.Series)
        assert np.array_equal(squozen.index, ["a", "b", "c"])
        assert squozen.name == "x"

        # Any other target shape is rejected.
        assert_raises(ValueError, pandas_friendly_reshape, df, (4,))
        assert_raises(ValueError, pandas_friendly_reshape, df, (1, 3))
        assert_raises(ValueError, pandas_friendly_reshape, df, (3, 3))

        had_pandas = have_pandas
        try:
            have_pandas = False
            # this will try to do a reshape directly, and DataFrames *have* no
            # reshape method
            assert_raises(AttributeError, pandas_friendly_reshape, df, (3,))
        finally:
            have_pandas = had_pandas
def uniqueify_list(seq):
    """Return the items of `seq` with duplicates removed, order preserved.

    The first occurrence of each item wins. Items must be hashable.
    """
    seen = set()
    result = []
    for item in seq:
        if item in seen:
            continue
        seen.add(item)
        result.append(item)
    return result
def test_to_uniqueify_list():
    """uniqueify_list keeps the first occurrence of each duplicate."""
    # Each pair is (input, expected first-occurrence ordering).
    cases = [
        ([1, 2, 3], [1, 2, 3]),
        ([1, 3, 3, 2, 3, 1], [1, 3, 2]),
        ([3, 2, 1, 4, 1, 2, 3], [3, 2, 1, 4]),
    ]
    for given, expected in cases:
        assert uniqueify_list(given) == expected
# Pick the widest floating-point and complex types this numpy build offers.
# At least one candidate in each tuple always exists (float64/complex128),
# so the else-branches are unreachable.
for float_type in ("float128", "float96", "float64"):
    widest_float = getattr(np, float_type, None)
    if widest_float is not None:
        break
else:  # pragma: no cover
    assert False
for complex_type in ("complex256", "complex196", "complex128"):
    widest_complex = getattr(np, complex_type, None)
    if widest_complex is not None:
        break
else:  # pragma: no cover
    assert False
def wide_dtype_for(arr):
    """Return the widest dtype of the same numeric kind as `arr`'s dtype.

    Integer and float input widens to `widest_float`; complex input widens
    to `widest_complex`. Raises ValueError for non-numeric dtypes.
    """
    dtype = np.asarray(arr).dtype
    is_real = (safe_issubdtype(dtype, np.integer)
               or safe_issubdtype(dtype, np.floating))
    if is_real:
        return widest_float
    if safe_issubdtype(dtype, np.complexfloating):
        return widest_complex
    raise ValueError("cannot widen a non-numeric type %r" % (dtype,))
def widen(arr):
    """Return `arr` as an ndarray converted to its widest numeric dtype."""
    target = wide_dtype_for(arr)
    return np.asarray(arr, dtype=target)
def test_wide_dtype_for_and_widen():
    """Ints and floats widen to widest_float; complex to widest_complex."""
    assert np.allclose(widen([1, 2, 3]), [1, 2, 3])
    assert widen([1, 2, 3]).dtype == widest_float
    assert np.allclose(widen([1.0, 2.0, 3.0]), [1, 2, 3])
    assert widen([1.0, 2.0, 3.0]).dtype == widest_float
    assert np.allclose(widen([1+0j, 2, 3]), [1, 2, 3])
    assert widen([1+0j, 2, 3]).dtype == widest_complex
    # Non-numeric data cannot be widened.
    from nose.tools import assert_raises
    assert_raises(ValueError, widen, ["hi"])
class PushbackAdapter(object):
    """Iterator wrapper that supports push-back, peek, and has_more.

    Items pushed back are yielded (most recently pushed first) before the
    underlying iterator is consulted again.
    """
    def __init__(self, it):
        self._it = it
        self._pushed = []

    def __iter__(self):
        return self

    def push_back(self, obj):
        """Schedule `obj` to be the next item yielded."""
        self._pushed.append(obj)

    def next(self):
        """Return the next item: pushed-back items first, in LIFO order.

        May raise StopIteration when everything is exhausted.
        """
        if self._pushed:
            return self._pushed.pop()
        else:
            # The builtin next() is exactly equivalent to
            # six.advance_iterator on every supported Python (2.6+/3.x),
            # so we use it directly and avoid the six dependency here.
            return next(self._it)
    __next__ = next

    def peek(self):
        """Return the next item without consuming it.

        Raises ValueError (not StopIteration) when no more data is
        available, so callers can peek inside loops safely.
        """
        try:
            obj = next(self)
        except StopIteration:
            raise ValueError("no more data")
        self.push_back(obj)
        return obj

    def has_more(self):
        """Return True iff another item is available."""
        try:
            self.peek()
        except ValueError:
            return False
        else:
            return True
def test_PushbackAdapter():
    """Pushed-back items come out LIFO, ahead of the underlying iterator."""
    it = PushbackAdapter(iter([1, 2, 3, 4]))
    assert it.has_more()
    assert six.advance_iterator(it) == 1
    it.push_back(0)
    assert six.advance_iterator(it) == 0
    assert six.advance_iterator(it) == 2
    # peek() reports the next item without consuming it.
    assert it.peek() == 3
    it.push_back(10)
    assert it.peek() == 10
    it.push_back(20)
    assert it.peek() == 20
    assert it.has_more()
    assert list(it) == [20, 10, 3, 4]
    assert not it.has_more()
384# The IPython pretty-printer gives very nice output that is difficult to get
385# otherwise, e.g., look how much more readable this is than if it were all
386# smooshed onto one line:
387#
388# ModelDesc(input_code='y ~ x*asdf',
389# lhs_terms=[Term([EvalFactor('y')])],
390# rhs_terms=[Term([]),
391# Term([EvalFactor('x')]),
392# Term([EvalFactor('asdf')]),
393# Term([EvalFactor('x'), EvalFactor('asdf')])],
394# )
395#
396# But, we don't want to assume it always exists; nor do we want to be
397# re-writing every repr function twice, once for regular repr and once for
398# the pretty printer. So, here's an ugly fallback implementation that can be
399# used unconditionally to implement __repr__ in terms of _pretty_repr_.
400#
401# Pretty printer docs:
402# http://ipython.org/ipython-doc/dev/api/generated/IPython.lib.pretty.html
404class _MiniPPrinter(object):
405 def __init__(self):
406 self._out = StringIO()
407 self.indentation = 0
409 def text(self, text):
410 self._out.write(text)
412 def breakable(self, sep=" "):
413 self._out.write(sep)
415 def begin_group(self, _, text):
416 self.text(text)
418 def end_group(self, _, text):
419 self.text(text)
421 def pretty(self, obj):
422 if hasattr(obj, "_repr_pretty_"):
423 obj._repr_pretty_(self, False)
424 else:
425 self.text(repr(obj))
427 def getvalue(self):
428 return self._out.getvalue()
def _mini_pretty(obj):
    """Render `obj` with the fallback mini pretty-printer (no IPython)."""
    pp = _MiniPPrinter()
    pp.pretty(obj)
    return pp.getvalue()
def repr_pretty_delegate(obj):
    """repr() an object via IPython's pretty-printer when it is available.

    If IPython is already loaded we might as well use it (most commonly we
    are inside an IPython session but someone called repr() directly, e.g.
    when printing a container like a namedtuple that IPython has no special
    handling for). But if IPython is *not* already imported, we deliberately
    do not import it ourselves: that keeps patsy fast to import (loading
    IPython measured ~4 seconds on a cold cache as of Nov. 2012), avoids
    IPython spawning child processes as an import side effect, and avoids
    confusing pandas's am-I-running-under-IPython detection in their test
    suite (see patsy bug #12).
    """
    if optional_dep_ok and "IPython" in sys.modules:
        from IPython.lib.pretty import pretty
        return pretty(obj)
    return _mini_pretty(obj)
def repr_pretty_impl(p, obj, args, kwargs=()):
    """Render ``ClassName(arg, ..., key=value, ...)`` through printer `p`.

    `p` is an IPython pretty-printer (or _MiniPPrinter); `args` is a
    sequence of positional values and `kwargs` a sequence of
    (label, value) pairs. The `kwargs` default is an immutable tuple
    rather than the previous mutable-list default (a classic Python
    pitfall); both iterate identically, so callers are unaffected.
    """
    name = obj.__class__.__name__
    p.begin_group(len(name) + 1, "%s(" % (name,))
    # Mutable cell so the closure can record "first item already emitted"
    # without relying on Python-3-only `nonlocal`.
    started = [False]
    def new_item():
        # Emit a separator before every item except the first.
        if started[0]:
            p.text(",")
            p.breakable()
        started[0] = True
    for arg in args:
        new_item()
        p.pretty(arg)
    for label, value in kwargs:
        new_item()
        p.begin_group(len(label) + 1, "%s=" % (label,))
        p.pretty(value)
        p.end_group(len(label) + 1, "")
    p.end_group(len(name) + 1, ")")
def test_repr_pretty():
    """Smoke-test the non-IPython pretty-repr fallback path."""
    assert repr_pretty_delegate("asdf") == "'asdf'"
    printer = _MiniPPrinter()
    class MyClass(object):
        pass
    # Positional args then keyword args, comma-separated inside Name(...).
    repr_pretty_impl(printer, MyClass(),
                     ["a", 1], [("foo", "bar"), ("asdf", "asdf")])
    assert printer.getvalue() == "MyClass('a', 1, foo='bar', asdf='asdf')"
482# In Python 3, objects of different types are not generally comparable, so a
# list of heterogeneous types cannot be sorted. This implements a Python 2
484# style comparison for arbitrary types. (It works on Python 2 too, but just
485# gives you the built-in ordering.) To understand why this is tricky, consider
486# this example:
487# a = 1 # type 'int'
488# b = 1.5 # type 'float'
489# class gggg:
490# pass
491# c = gggg()
492# sorted([a, b, c])
493# The fallback ordering sorts by class name, so according to the fallback
494# ordering, we have b < c < a. But, of course, a and b are comparable (even
495# though they're of different types), so we also have a < b. This is
496# inconsistent. There is no general solution to this problem (which I guess is
497# why Python 3 stopped trying), but the worst offender is all the different
498# "numeric" classes (int, float, complex, decimal, rational...), so as a
499# special-case, we sort all numeric objects to the start of the list.
500# (In Python 2, there is also a similar special case for str and unicode, but
501# we don't have to worry about that for Python 3.)
class SortAnythingKey(object):
    """Sort key giving a Python-2-style total order over arbitrary objects.

    Native ``<`` is used whenever the two objects support it. Otherwise:
    number-like objects (anything comparable with 0) sort before everything
    else, equal objects are treated as ties, and remaining ties are broken
    by class name and then memory address. (See the comment above for why a
    fully consistent ordering is impossible in general.)
    """
    def __init__(self, obj):
        self.obj = obj

    def _python_lt(self, other_obj):
        # On Py2, < never raises an error, so this is just <. (Actually it
        # does raise a TypeError for comparing complex to numeric, but not
        # for comparisons of complex to other types. Sigh. Whatever.)
        # On Py3, this returns a bool if available, and otherwise returns
        # NotImplemented.
        try:
            return self.obj < other_obj
        except TypeError:
            return NotImplemented

    def __lt__(self, other):
        assert isinstance(other, SortAnythingKey)
        native = self._python_lt(other.obj)
        if native is not NotImplemented:
            return native
        # Native comparison failed; fall back. Anything number-like (i.e.,
        # comparable with 0) sorts before everything that is not.
        if self._python_lt(0) is not NotImplemented:
            return True
        if other._python_lt(0) is not NotImplemented:
            return False
        # Respect == when it is defined, so the ordering stays consistent
        # with equality for otherwise-unorderable objects.
        if self.obj == other.obj:
            return False
        # Final tie-break: class name, then memory position.
        mine = (self.obj.__class__.__name__, id(self.obj))
        theirs = (other.obj.__class__.__name__, id(other.obj))
        return mine < theirs
def test_SortAnythingKey():
    """Numbers sort first and among themselves; others fall back by class."""
    assert sorted([20, 10, 0, 15], key=SortAnythingKey) == [0, 10, 15, 20]
    assert sorted([10, -1.5], key=SortAnythingKey) == [-1.5, 10]
    # Mixed numeric/non-numeric: all numbers come before all strings.
    assert sorted([10, "a", 20.5, "b"], key=SortAnythingKey) == [10, 20.5, "a", "b"]
    class a(object):
        pass
    class b(object):
        pass
    class z(object):
        pass
    a_obj = a()
    b_obj = b()
    z_obj = z()
    o_obj = object()
    # Unorderable instances fall back to sorting by class name.
    assert (sorted([z_obj, a_obj, 1, b_obj, o_obj], key=SortAnythingKey)
            == [1, a_obj, b_obj, o_obj, z_obj])
# NaN checking functions that work on arbitrary objects, on old Python
# versions (math.isnan is only in 2.6+), etc.
def safe_scalar_isnan(x):
    """Return True iff `x` is a scalar value that coerces to a float NaN."""
    try:
        value = float(x)
    except (TypeError, ValueError, NotImplementedError):
        # Not float-coercible (strings, None, containers, multi-element
        # arrays, ...) -- by definition not a scalar NaN.
        return False
    return np.isnan(value)
# Elementwise version: works on object arrays / heterogeneous sequences.
safe_isnan = np.vectorize(safe_scalar_isnan, otypes=[bool])
def test_safe_scalar_isnan():
    """Only genuine scalar NaNs count; everything else is simply False."""
    assert not safe_scalar_isnan(True)
    assert not safe_scalar_isnan(None)
    assert not safe_scalar_isnan("sadf")
    assert not safe_scalar_isnan((1, 2, 3))
    assert not safe_scalar_isnan(np.asarray([1, 2, 3]))
    # A *container* holding NaN is not itself a scalar NaN.
    assert not safe_scalar_isnan([np.nan])
    assert safe_scalar_isnan(np.nan)
    assert safe_scalar_isnan(np.float32(np.nan))
    assert safe_scalar_isnan(float(np.nan))
def test_safe_isnan():
    """Vectorized version handles heterogeneous object arrays elementwise."""
    assert np.array_equal(safe_isnan([1, True, None, np.nan, "asdf"]),
                          [False, False, False, True, False])
    # Scalar input yields a 0-d boolean result.
    assert safe_isnan(np.nan).ndim == 0
    assert safe_isnan(np.nan)
    assert not safe_isnan(None)
    # raw isnan raises a *different* error for strings than for objects:
    assert not safe_isnan("asdf")
def iterable(obj):
    """Return True iff `obj` can be iterated (i.e. iter(obj) succeeds)."""
    try:
        iter(obj)
    except Exception:
        # Deliberately broad: any failure to obtain an iterator counts as
        # "not iterable".
        return False
    else:
        return True
def test_iterable():
    """Strings and containers are iterable; ints and functions are not."""
    assert iterable("asdf")
    assert iterable([])
    assert iterable({"a": 1})
    assert not iterable(1)
    assert not iterable(iterable)
597##### Handling Pandas's categorical stuff is horrible and hateful
599# Basically they decided that they didn't like how numpy does things, so their
600# categorical stuff is *kinda* like how numpy would do it (e.g. they have a
601# special ".dtype" attribute to mark categorical data), so by default you'll
602# find yourself using the same code paths to handle pandas categorical data
603# and other non-categorical data. BUT, all the idioms for detecting
604# categorical data blow up with errors if you try them with real numpy dtypes,
605# and all numpy's idioms for detecting non-categorical types blow up with
606# errors if you try them with pandas categorical stuff. So basically they have
607# just poisoned all code that touches dtypes; the old numpy stuff is unsafe,
608# and you must use special code like below.
609#
610# Also there are hoops to jump through to handle both the old style
611# (Categorical objects) and new-style (Series with dtype="category").
# Needed to support pandas < 0.15
def pandas_Categorical_from_codes(codes, categories):
    """Build a pandas.Categorical from integer codes plus category labels.

    Uses Categorical.from_codes when it exists, otherwise falls back to the
    old Categorical(codes, categories) constructor.
    """
    assert have_pandas_categorical

    # Old versions of pandas sometimes fail to coerce this to an array and
    # just return it directly from .labels (?!).
    codes = np.asarray(codes)
    factory = getattr(pandas.Categorical, "from_codes", None)
    if factory is None:
        return pandas.Categorical(codes, categories)
    return factory(codes, categories)
def test_pandas_Categorical_from_codes():
    """Codes map onto categories; code -1 round-trips as NaN (missing)."""
    if not have_pandas_categorical:
        return
    c = pandas_Categorical_from_codes([1, 1, 0, -1], ["a", "b"])
    assert np.all(np.asarray(c)[:-1] == ["b", "b", "a"])
    assert np.isnan(np.asarray(c)[-1])
# Needed to support pandas < 0.15
def pandas_Categorical_categories(cat):
    """Return the category labels of a Categorical or categorical Series.

    In pandas 0.15+ a categorical Series exposes a `.cat` accessor that acts
    like a Categorical object; Categorical objects themselves carry
    `.categories` (which was called `.levels` before 0.15).
    """
    cat = getattr(cat, "cat", cat)
    try:
        return cat.categories
    except AttributeError:
        return cat.levels
# Needed to support pandas < 0.15
def pandas_Categorical_codes(cat):
    """Return the integer codes of a Categorical or categorical Series.

    In pandas 0.15+ a categorical Series exposes a `.cat` accessor that acts
    like a Categorical object; Categorical objects themselves carry `.codes`
    (which was called `.labels` before 0.15).
    """
    cat = getattr(cat, "cat", cat)
    try:
        return cat.codes
    except AttributeError:
        return cat.labels
def test_pandas_Categorical_accessors():
    """Accessors work on both Categorical objects and categorical Series."""
    if not have_pandas_categorical:
        return
    c = pandas_Categorical_from_codes([1, 1, 0, -1], ["a", "b"])
    assert np.all(pandas_Categorical_categories(c) == ["a", "b"])
    assert np.all(pandas_Categorical_codes(c) == [1, 1, 0, -1])

    if have_pandas_categorical_dtype:
        # Wrapping in a Series (dtype="category") must give the same answers.
        s = pandas.Series(c)
        assert np.all(pandas_Categorical_categories(s) == ["a", "b"])
        assert np.all(pandas_Categorical_codes(s) == [1, 1, 0, -1])
# Needed to support pandas >= 0.15 (!)
def safe_is_pandas_categorical_dtype(dt):
    """Return whether `dt` is a pandas categorical dtype.

    Always False when this pandas build has no categorical-dtype support.
    """
    return have_pandas_categorical_dtype and _pandas_is_categorical_dtype(dt)
# Needed to support pandas >= 0.15 (!)
def safe_is_pandas_categorical(data):
    """Return whether `data` is pandas categorical data.

    True for Categorical objects and for anything (e.g. a Series) whose
    `.dtype` is a categorical dtype; always False without pandas support.
    """
    if not have_pandas_categorical:
        return False
    if isinstance(data, pandas.Categorical):
        return True
    return (hasattr(data, "dtype")
            and safe_is_pandas_categorical_dtype(data.dtype))
def test_safe_is_pandas_categorical():
    """ndarrays are never categorical; Categorical and category-Series are."""
    assert not safe_is_pandas_categorical(np.arange(10))

    if have_pandas_categorical:
        c_obj = pandas.Categorical(["a", "b"])
        assert safe_is_pandas_categorical(c_obj)

    if have_pandas_categorical_dtype:
        s_obj = pandas.Series(["a", "b"], dtype="category")
        assert safe_is_pandas_categorical(s_obj)
# Needed to support pandas >= 0.15 (!)
# Calling np.issubdtype on a pandas categorical blows up. The officially
# recommended workaround is to replace every
#   np.issubdtype(foo.dtype, bool)
# with either
#   isinstance(foo.dtype, np.dtype) and np.issubdtype(foo.dtype, bool)
# or
#   not pandas.is_categorical_dtype(foo.dtype) and issubdtype(foo.dtype, bool)
# We do the latter (with extra hoops) because the isinstance check is not
# safe. See
#   https://github.com/pydata/pandas/issues/9581
#   https://github.com/pydata/pandas/issues/9581#issuecomment-77099564
def safe_issubdtype(dt1, dt2):
    """np.issubdtype, but a categorical dtype is simply never a subtype."""
    return (not safe_is_pandas_categorical_dtype(dt1)
            and np.issubdtype(dt1, dt2))
def test_safe_issubdtype():
    """Behaves like np.issubdtype, except categorical dtypes are 'no'."""
    assert safe_issubdtype(int, np.integer)
    assert safe_issubdtype(np.dtype(float), np.floating)
    assert not safe_issubdtype(int, np.floating)
    assert not safe_issubdtype(np.dtype(float), np.integer)

    if have_pandas_categorical_dtype:
        # A categorical Series must not crash -- just answer False.
        bad_dtype = pandas.Series(["a", "b"], dtype="category")
        assert not safe_issubdtype(bad_dtype, np.integer)
def no_pickling(*args, **kwargs):
    """Drop-in __reduce__/__getstate__ implementation that refuses to pickle."""
    msg = ("Sorry, pickling not yet supported. "
           "See https://github.com/pydata/patsy/issues/26 if you want to "
           "help.")
    raise NotImplementedError(msg)
def assert_no_pickling(obj):
    """Assert that pickling `obj` raises NotImplementedError.

    Previously this delegated to nose.tools.assert_raises; nose is
    unmaintained and unavailable on modern Pythons, so the check is done
    inline with the same contract: silent success when pickling fails with
    NotImplementedError, AssertionError when pickling succeeds, and any
    other exception propagates unchanged.
    """
    import pickle
    try:
        pickle.dumps(obj)
    except NotImplementedError:
        return
    raise AssertionError("pickling %r unexpectedly succeeded" % (obj,))
# Use like:
#   if safe_string_eq(constraints, "center"):
#       ...
# where 'constraints' might be a string or an array. (If it's an array, then
# we can't use == because it might broadcast and ugh.)
def safe_string_eq(obj, value):
    """Return obj == value, but only when `obj` is actually a string."""
    return isinstance(obj, six.string_types) and obj == value
def test_safe_string_eq():
    """String comparisons behave like ==; non-strings never broadcast."""
    assert safe_string_eq("foo", "foo")
    assert not safe_string_eq("foo", "bar")

    # On Python 2, unicode strings count as strings too.
    if not six.PY3:
        assert safe_string_eq(unicode("foo"), "foo")

    # Array input must yield plain False, not an elementwise comparison.
    assert not safe_string_eq(np.empty((2, 2)), "foo")