import copy import functools import sys import unittest from test import test_support from weakref import proxy import pickle @staticmethod def PythonPartial(func, *args, **keywords): 'Pure Python approximation of partial()' def newfunc(*fargs, **fkeywords): newkeywords = keywords.copy() newkeywords.update(fkeywords) return func(*(args + fargs), **newkeywords) newfunc.func = func newfunc.args = args newfunc.keywords = keywords return newfunc def capture(*args, **kw): """capture all positional and keyword arguments""" return args, kw def signature(part): """ return the signature of a partial object """ return (part.func, part.args, part.keywords, part.__dict__) class MyTuple(tuple): pass class BadTuple(tuple): def __add__(self, other): return list(self) + list(other) class MyDict(dict): pass class TestPartial(unittest.TestCase): partial = functools.partial def test_basic_examples(self): p = self.partial(capture, 1, 2, a=10, b=20) self.assertEqual(p(3, 4, b=30, c=40), ((1, 2, 3, 4), dict(a=10, b=30, c=40))) p = self.partial(map, lambda x: x*10) self.assertEqual(p([1,2,3,4]), [10, 20, 30, 40]) def test_attributes(self): p = self.partial(capture, 1, 2, a=10, b=20) # attributes should be readable self.assertEqual(p.func, capture) self.assertEqual(p.args, (1, 2)) self.assertEqual(p.keywords, dict(a=10, b=20)) # attributes should not be writable self.assertRaises(TypeError, setattr, p, 'func', map) self.assertRaises(TypeError, setattr, p, 'args', (1, 2)) self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2)) p = self.partial(hex) try: del p.__dict__ except TypeError: pass else: self.fail('partial object allowed __dict__ to be deleted') def test_argument_checking(self): self.assertRaises(TypeError, self.partial) # need at least a func arg try: self.partial(2)() except TypeError: pass else: self.fail('First arg not checked for callability') def test_protection_of_callers_dict_argument(self): # a caller's dictionary should not be altered by partial def func(a=10, b=20): return a d = {'a':3} p = self.partial(func, a=5) self.assertEqual(p(**d), 3) self.assertEqual(d, {'a':3}) p(b=7) self.assertEqual(d, {'a':3}) def test_arg_combinations(self): # exercise special code paths for zero args in either partial # object or the caller p = self.partial(capture) self.assertEqual(p(), ((), {})) self.assertEqual(p(1,2), ((1,2), {})) p = self.partial(capture, 1, 2) self.assertEqual(p(), ((1,2), {})) self.assertEqual(p(3,4), ((1,2,3,4), {})) def test_kw_combinations(self): # exercise special code paths for no keyword args in # either the partial object or the caller p = self.partial(capture) self.assertEqual(p.keywords, {}) self.assertEqual(p(), ((), {})) self.assertEqual(p(a=1), ((), {'a':1})) p = self.partial(capture, a=1) self.assertEqual(p.keywords, {'a':1}) self.assertEqual(p(), ((), {'a':1})) self.assertEqual(p(b=2), ((), {'a':1, 'b':2})) # keyword args in the call override those in the partial object self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2})) def test_positional(self): # make sure positional arguments are captured correctly for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]: p = self.partial(capture, *args) expected = args + ('x',) got, empty = p('x') self.assertTrue(expected == got and empty == {}) def test_keyword(self): # make sure keyword arguments are captured correctly for a in ['a', 0, None, 3.5]: p = self.partial(capture, a=a) expected = {'a':a,'x':None} empty, got = p(x=None) self.assertTrue(expected == got and empty == ()) def test_no_side_effects(self): # make sure there are no side effects that affect subsequent calls p = self.partial(capture, 0, a=1) args1, kw1 = p(1, b=2) self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2}) args2, kw2 = p() self.assertTrue(args2 == (0,) and kw2 == {'a':1}) def test_error_propagation(self): def f(x, y): x // y self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0)) self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0) self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0) self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1) def test_weakref(self): f = self.partial(int, base=16) p = proxy(f) self.assertEqual(f.func, p.func) f = None self.assertRaises(ReferenceError, getattr, p, 'func') def test_with_bound_and_unbound_methods(self): data = map(str, range(10)) join = self.partial(str.join, '') self.assertEqual(join(data), '0123456789') join = self.partial(''.join) self.assertEqual(join(data), '0123456789') def test_pickle(self): f = self.partial(signature, ['asdf'], bar=[True]) f.attr = [] for proto in range(pickle.HIGHEST_PROTOCOL + 1): f_copy = pickle.loads(pickle.dumps(f, proto)) self.assertEqual(signature(f_copy), signature(f)) def test_copy(self): f = self.partial(signature, ['asdf'], bar=[True]) f.attr = [] f_copy = copy.copy(f) self.assertEqual(signature(f_copy), signature(f)) self.assertIs(f_copy.attr, f.attr) self.assertIs(f_copy.args, f.args) self.assertIs(f_copy.keywords, f.keywords) def test_deepcopy(self): f = self.partial(signature, ['asdf'], bar=[True]) f.attr = [] f_copy = copy.deepcopy(f) self.assertEqual(signature(f_copy), signature(f)) self.assertIsNot(f_copy.attr, f.attr) self.assertIsNot(f_copy.args, f.args) self.assertIsNot(f_copy.args[0], f.args[0]) self.assertIsNot(f_copy.keywords, f.keywords) self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar']) def test_setstate(self): f = self.partial(signature) f.__setstate__((capture, (1,), dict(a=10), dict(attr=[]))) self.assertEqual(signature(f), (capture, (1,), dict(a=10), dict(attr=[]))) self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) f.__setstate__((capture, (1,), dict(a=10), None)) self.assertEqual(signature(f), (capture, (1,), dict(a=10), {})) self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) f.__setstate__((capture, (1,), None, None)) #self.assertEqual(signature(f), (capture, (1,), {}, {})) self.assertEqual(f(2, b=20), ((1, 2), {'b': 20})) self.assertEqual(f(2), ((1, 2), {})) self.assertEqual(f(), ((1,), {})) f.__setstate__((capture, (), {}, None)) self.assertEqual(signature(f), (capture, (), {}, {})) self.assertEqual(f(2, b=20), ((2,), {'b': 20})) self.assertEqual(f(2), ((2,), {})) self.assertEqual(f(), ((), {})) def test_setstate_errors(self): f = self.partial(signature) self.assertRaises(TypeError, f.__setstate__, (capture, (), {})) self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None)) self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None]) self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None)) self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None)) self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None)) self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None)) def test_setstate_subclasses(self): f = self.partial(signature) f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None)) s = signature(f) self.assertEqual(s, (capture, (1,), dict(a=10), {})) self.assertIs(type(s[1]), tuple) self.assertIs(type(s[2]), dict) r = f() self.assertEqual(r, ((1,), {'a': 10})) self.assertIs(type(r[0]), tuple) self.assertIs(type(r[1]), dict) f.__setstate__((capture, BadTuple((1,)), {}, None)) s = signature(f) self.assertEqual(s, (capture, (1,), {}, {})) self.assertIs(type(s[1]), tuple) r = f(2) self.assertEqual(r, ((1, 2), {})) self.assertIs(type(r[0]), tuple) def test_recursive_pickle(self): f = self.partial(capture) f.__setstate__((f, (), {}, {})) try: for proto in range(pickle.HIGHEST_PROTOCOL + 1): with self.assertRaises(RuntimeError): pickle.dumps(f, proto) finally: f.__setstate__((capture, (), {}, {})) f = self.partial(capture) f.__setstate__((capture, (f,), {}, {})) try: for proto in range(pickle.HIGHEST_PROTOCOL + 1): f_copy = pickle.loads(pickle.dumps(f, proto)) try: self.assertIs(f_copy.args[0], f_copy) finally: f_copy.__setstate__((capture, (), {}, {})) finally: f.__setstate__((capture, (), {}, {})) f = self.partial(capture) f.__setstate__((capture, (), {'a': f}, {})) try: for proto in range(pickle.HIGHEST_PROTOCOL + 1): f_copy = pickle.loads(pickle.dumps(f, proto)) try: self.assertIs(f_copy.keywords['a'], f_copy) finally: f_copy.__setstate__((capture, (), {}, {})) finally: f.__setstate__((capture, (), {}, {})) # Issue 6083: Reference counting bug def test_setstate_refcount(self): class BadSequence: def __len__(self): return 4 def __getitem__(self, key): if key == 0: return max elif key == 1: return tuple(range(1000000)) elif key in (2, 3): return {} raise IndexError f = self.partial(object) self.assertRaises(TypeError, f.__setstate__, BadSequence()) class PartialSubclass(functools.partial): pass class TestPartialSubclass(TestPartial): partial = PartialSubclass class TestPythonPartial(TestPartial): partial = PythonPartial # the python version isn't picklable test_pickle = None test_setstate = None test_setstate_errors = None test_setstate_subclasses = None test_setstate_refcount = None test_recursive_pickle = None # the python version isn't deepcopyable test_deepcopy = None # the python version isn't a type test_attributes = None class TestUpdateWrapper(unittest.TestCase): def check_wrapper(self, wrapper, wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, updated=functools.WRAPPER_UPDATES): # Check attributes were assigned for name in assigned: self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name)) # Check attributes were updated for name in updated: wrapper_attr = getattr(wrapper, name) wrapped_attr = getattr(wrapped, name) for key in wrapped_attr: self.assertTrue(wrapped_attr[key] is wrapper_attr[key]) def _default_update(self): def f(): """This is a test""" pass f.attr = 'This is also a test' def wrapper(): pass functools.update_wrapper(wrapper, f) return wrapper, f def test_default_update(self): wrapper, f = self._default_update() self.check_wrapper(wrapper, f) self.assertEqual(wrapper.__name__, 'f') self.assertEqual(wrapper.attr, 'This is also a test') @unittest.skipIf(sys.flags.optimize >= 2, "Docstrings are omitted with -O2 and above") def test_default_update_doc(self): wrapper, f = self._default_update() self.assertEqual(wrapper.__doc__, 'This is a test') def test_no_update(self): def f(): """This is a test""" pass f.attr = 'This is also a test' def wrapper(): pass functools.update_wrapper(wrapper, f, (), ()) self.check_wrapper(wrapper, f, (), ()) self.assertEqual(wrapper.__name__, 'wrapper') self.assertEqual(wrapper.__doc__, None) self.assertFalse(hasattr(wrapper, 'attr')) def test_selective_update(self): def f(): pass f.attr = 'This is a different test' f.dict_attr = dict(a=1, b=2, c=3) def wrapper(): pass wrapper.dict_attr = {} assign = ('attr',) update = ('dict_attr',) functools.update_wrapper(wrapper, f, assign, update) self.check_wrapper(wrapper, f, assign, update) self.assertEqual(wrapper.__name__, 'wrapper') self.assertEqual(wrapper.__doc__, None) self.assertEqual(wrapper.attr, 'This is a different test') self.assertEqual(wrapper.dict_attr, f.dict_attr) @test_support.requires_docstrings def test_builtin_update(self): # Test for bug #1576241 def wrapper(): pass functools.update_wrapper(wrapper, max) self.assertEqual(wrapper.__name__, 'max') self.assertTrue(wrapper.__doc__.startswith('max(')) class TestWraps(TestUpdateWrapper): def _default_update(self): def f(): """This is a test""" pass f.attr = 'This is also a test' @functools.wraps(f) def wrapper(): pass self.check_wrapper(wrapper, f) return wrapper def test_default_update(self): wrapper = self._default_update() self.assertEqual(wrapper.__name__, 'f') self.assertEqual(wrapper.attr, 'This is also a test') @unittest.skipIf(sys.flags.optimize >= 2, "Docstrings are omitted with -O2 and above") def test_default_update_doc(self): wrapper = self._default_update() self.assertEqual(wrapper.__doc__, 'This is a test') def test_no_update(self): def f(): """This is a test""" pass f.attr = 'This is also a test' @functools.wraps(f, (), ()) def wrapper(): pass self.check_wrapper(wrapper, f, (), ()) self.assertEqual(wrapper.__name__, 'wrapper') self.assertEqual(wrapper.__doc__, None) self.assertFalse(hasattr(wrapper, 'attr')) def test_selective_update(self): def f(): pass f.attr = 'This is a different test' f.dict_attr = dict(a=1, b=2, c=3) def add_dict_attr(f): f.dict_attr = {} return f assign = ('attr',) update = ('dict_attr',) @functools.wraps(f, assign, update) @add_dict_attr def wrapper(): pass self.check_wrapper(wrapper, f, assign, update) self.assertEqual(wrapper.__name__, 'wrapper') self.assertEqual(wrapper.__doc__, None) self.assertEqual(wrapper.attr, 'This is a different test') self.assertEqual(wrapper.dict_attr, f.dict_attr) class TestReduce(unittest.TestCase): def test_reduce(self): class Squares: def __init__(self, max): self.max = max self.sofar = [] def __len__(self): return len(self.sofar) def __getitem__(self, i): if not 0 <= i < self.max: raise IndexError n = len(self.sofar) while n <= i: self.sofar.append(n*n) n += 1 return self.sofar[i] reduce = functools.reduce self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc') self.assertEqual( reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []), ['a','c','d','w'] ) self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040) self.assertEqual( reduce(lambda x, y: x*y, range(2,21), 1L), 2432902008176640000L ) self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285) self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285) self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0) self.assertRaises(TypeError, reduce) self.assertRaises(TypeError, reduce, 42, 42) self.assertRaises(TypeError, reduce, 42, 42, 42) self.assertEqual(reduce(42, "1"), "1") # func is never called with one item self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item self.assertRaises(TypeError, reduce, 42, (42, 42)) class TestCmpToKey(unittest.TestCase): def test_cmp_to_key(self): def mycmp(x, y): return y - x self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)), [4, 3, 2, 1, 0]) def test_hash(self): def mycmp(x, y): return y - x key = functools.cmp_to_key(mycmp) k = key(10) self.assertRaises(TypeError, hash(k)) class TestTotalOrdering(unittest.TestCase): def test_total_ordering_lt(self): @functools.total_ordering class A: def __init__(self, value): self.value = value def __lt__(self, other): return self.value < other.value def __eq__(self, other): return self.value == other.value self.assertTrue(A(1) < A(2)) self.assertTrue(A(2) > A(1)) self.assertTrue(A(1) <= A(2)) self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) def test_total_ordering_le(self): @functools.total_ordering class A: def __init__(self, value): self.value = value def __le__(self, other): return self.value <= other.value def __eq__(self, other): return self.value == other.value self.assertTrue(A(1) < A(2)) self.assertTrue(A(2) > A(1)) self.assertTrue(A(1) <= A(2)) self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) def test_total_ordering_gt(self): @functools.total_ordering class A: def __init__(self, value): self.value = value def __gt__(self, other): return self.value > other.value def __eq__(self, other): return self.value == other.value self.assertTrue(A(1) < A(2)) self.assertTrue(A(2) > A(1)) self.assertTrue(A(1) <= A(2)) self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) def test_total_ordering_ge(self): @functools.total_ordering class A: def __init__(self, value): self.value = value def __ge__(self, other): return self.value >= other.value def __eq__(self, other): return self.value == other.value self.assertTrue(A(1) < A(2)) self.assertTrue(A(2) > A(1)) self.assertTrue(A(1) <= A(2)) self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) def test_total_ordering_no_overwrite(self): # new methods should not overwrite existing @functools.total_ordering class A(str): pass self.assertTrue(A("a") < A("b")) self.assertTrue(A("b") > A("a")) self.assertTrue(A("a") <= A("b")) self.assertTrue(A("b") >= A("a")) self.assertTrue(A("b") <= A("b")) self.assertTrue(A("b") >= A("b")) def test_no_operations_defined(self): with self.assertRaises(ValueError): @functools.total_ordering class A: pass def test_bug_10042(self): @functools.total_ordering class TestTO: def __init__(self, value): self.value = value def __eq__(self, other): if isinstance(other, TestTO): return self.value == other.value return False def __lt__(self, other): if isinstance(other, TestTO): return self.value < other.value raise TypeError with self.assertRaises(TypeError): TestTO(8) <= () def test_bug_25732(self): @functools.total_ordering class A: def __init__(self, value): self.value = value def __gt__(self, other): return self.value > other.value def __eq__(self, other): return self.value == other.value def __hash__(self): return hash(self.value) self.assertTrue(A(1) != A(2)) self.assertFalse(A(1) != A(1)) @functools.total_ordering class A(object): def __init__(self, value): self.value = value def __gt__(self, other): return self.value > other.value def __eq__(self, other): return self.value == other.value def __hash__(self): return hash(self.value) self.assertTrue(A(1) != A(2)) self.assertFalse(A(1) != A(1)) @functools.total_ordering class A: def __init__(self, value): self.value = value def __gt__(self, other): return self.value > other.value def __eq__(self, other): return self.value == other.value def __ne__(self, other): raise RuntimeError(self, other) def __hash__(self): return hash(self.value) with self.assertRaises(RuntimeError): A(1) != A(2) with self.assertRaises(RuntimeError): A(1) != A(1) @functools.total_ordering class A(object): def __init__(self, value): self.value = value def __gt__(self, other): return self.value > other.value def __eq__(self, other): return self.value == other.value def __ne__(self, other): raise RuntimeError(self, other) def __hash__(self): return hash(self.value) with self.assertRaises(RuntimeError): A(1) != A(2) with self.assertRaises(RuntimeError): A(1) != A(1) def test_main(verbose=None): test_classes = ( TestPartial, TestPartialSubclass, TestPythonPartial, TestUpdateWrapper, TestTotalOrdering, TestWraps, TestReduce, ) test_support.run_unittest(*test_classes) # verify reference counting if verbose and hasattr(sys, "gettotalrefcount"): import gc counts = [None] * 5 for i in xrange(len(counts)): test_support.run_unittest(*test_classes) gc.collect() counts[i] = sys.gettotalrefcount() print counts if __name__ == '__main__': test_main(verbose=True)