#!/usr/bin/env python
# -*- coding: iso-8859-1 -*-
###############################################################################
#
# Yet another invariant/pre-/postcondition design-by-contract support module.
#
# Written by Dmitry Dvoinikov
# Distributed under MIT license.
#
# The latest version, complete with self-tests, can be downloaded from:
# http://www.targeted.org/python/recipes/ipdbc.py
#
# Sample usage:
#
# from ipdbc import ContractBase
#
# class Balloon(ContractBase):               # demonstrates class invariant
#     def invariant(self):
#         return 0 <= self.weight < 1000     # returns True/False
#     def __init__(self):
#         self.weight = 0
#     def fails(self): # upon return this throws PostInvariantViolationError
#         self.weight = 1000
#
# class GuidedBalloon(Balloon):              # demonstrates pre/post condition
#     def pre_drop(self, _weight):           # pre_ receives exact copy of arguments
#         return self.weight >= _weight      # returns True/False
#     def drop(self, _weight):
#         self.weight -= _weight
#         return self.weight                 # the result of the call is passed
#     def post_drop(self, result, _weight):  # as a second parameter to post_
#         return result >= 0                 # followed again by copy of arguments
#
# Note: GuidedBalloon().fails() still fails, since Balloon's invariant is
#       inherited.
# Note: All the dbc infused methods are inherited in the mro-correct way.
# Note: Neither classmethods nor staticmethods are decorated, only "regular"
#       instance-bound methods.
#
# (c) 2005-2006 Dmitry Dvoinikov
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################

__all__ = ["ContractBase", "ContractViolationError", "InvariantViolationError",
           "PreInvariantViolationError", "PostInvariantViolationError",
           "PreConditionViolationError", "PostConditionViolationError",
           "PreconditionViolationError", "PostconditionViolationError"]

CONTRACT_CHECKS_ENABLED = True # allows to turn contract checks off when needed

###############################################################################
# Exception hierarchy. Everything derives from ContractViolationError, which
# is itself an AssertionError, so callers may catch at any level.

class ContractViolationError(AssertionError): pass
class InvariantViolationError(ContractViolationError): pass
class PreInvariantViolationError(InvariantViolationError): pass
class PostInvariantViolationError(InvariantViolationError): pass
class PreConditionViolationError(ContractViolationError): pass
PreconditionViolationError = PreConditionViolationError # pep 316 calls it such
class PostConditionViolationError(ContractViolationError): pass
PostconditionViolationError = PostConditionViolationError # pep 316 calls it such

###############################################################################

from types import FunctionType
from sys import hexversion

have_python_24 = hexversion >= 0x2040000

################################################################################
# NOTE: this intentionally shadows the builtin any() of later Python versions;
# unlike the builtin it accepts a predicate. It is module-private (not listed
# in __all__).

def any(s, f = lambda e: bool(e)):
    """Return True if predicate f holds for at least one element of s."""
    for e in s:
        if f(e):
            return True
    return False

################################################################################

def none(s, f = lambda e: bool(e)):
    """Return True if predicate f holds for no element of s."""
    return not any(s, f)

################################################################################

def empty(s):
    """Return True if the sized container s has no elements."""
    return len(s) == 0

################################################################################

def pick_first(s, f = lambda e: bool(e)):
    """Return the first element of s for which f holds, or None if there is none."""
    for e in s:
        if f(e):
            return e
    return None

################################################################################

if not have_python_24:
    def reversed(s):
        """Fallback for interpreters without the reversed() builtin; returns a list."""
        r = list(s)
        r.reverse()
        return r

################################################################################

def merged_mro(*classes):
    """
    Returns list of all classes' bases merged and mro-correctly ordered,
    implemented as per http://www.python.org/2.3/mro.html

    Raises TypeError if any argument is not a (new-style) class, or if no
    consistent linearization exists. With no arguments an empty list is
    returned.
    """

    if any(classes, lambda c: not isinstance(c, type)):
        raise TypeError("merged_mro expects all it's parameters to be classes, got %s" %
                        pick_first(classes, lambda c: not isinstance(c, type)))

    def merge(lists):

        result = []
        # each sequence is kept as a (head, tail) pair; empty sequences carry
        # no ordering constraint and are dropped (they would otherwise fail
        # on list_[0], e.g. when there are no bases at all)
        lists = [ (list_[0], list_[1:]) for list_ in lists if not empty(list_) ]

        while not empty(lists):

            # a good head is one that appears in no other sequence's tail
            good_head, tail = pick_first(lists, lambda ht1: none(lists, lambda ht2: ht1[0] in ht2[1])) or (None, None)
            if good_head is None:
                raise TypeError("Cannot create a consistent method resolution "
                                "order (MRO) for bases %s" %
                                ", ".join([ cls.__name__ for cls in classes ]))
            result += [ good_head ]

            # strip the chosen head from every sequence that starts with it,
            # discarding the sequences that become exhausted
            remaining = []
            for head, tail in lists:
                if head == good_head:
                    if not empty(tail):
                        remaining.append(( tail[0], tail[1:] ))
                else:
                    remaining.append(( head, tail ))
            lists = remaining

        return result

    merged = [ cls.mro() for cls in classes ] + [ list(classes) ]
    return merge(merged)

###############################################################################

class ContractFactory(type):
    """
    Metaclass that wraps every plain instance method of the class being
    created (own and inherited, except __del__, invariant, pre_* and post_*)
    into a function that checks the class invariant and the method's
    pre_/post_ conditions around the call.
    """

    def _wrap(_method, preinvariant, precondition, postcondition, postinvariant,
              _classname, _methodname):
        """
        Return a function that calls _method surrounded by the given checks.

        Any of the four check callables may be None, in which case that check
        is skipped. The checks run in the order: preinvariant(self),
        precondition(self, *args, **kwargs), the method itself,
        postcondition(self, result, *args, **kwargs), postinvariant(self).
        A check returning a false value raises the matching
        *ViolationError naming _classname._methodname.
        """

        target = (_classname, _methodname)

        def dbc_wrapper(self, *args, **kwargs):
            if preinvariant is not None and not preinvariant(self):
                raise PreInvariantViolationError(
                    "Class invariant does not hold before a call to %s.%s" % target)
            if precondition is not None and not precondition(self, *args, **kwargs):
                raise PreConditionViolationError(
                    "Precondition failed before a call to %s.%s" % target)
            result = _method(self, *args, **kwargs)
            if postcondition is not None and not postcondition(self, result, *args, **kwargs):
                raise PostConditionViolationError(
                    "Postcondition failed after a call to %s.%s" % target)
            if postinvariant is not None and not postinvariant(self):
                raise PostInvariantViolationError(
                    "Class invariant does not hold after a call to %s.%s" % target)
            return result

        if have_python_24:
            dbc_wrapper.__name__ = _methodname # not writable before 2.4
        dbc_wrapper.__doc__ = _method.__doc__  # keep the method's documentation

        return dbc_wrapper

    _wrap = staticmethod(_wrap)

    def __new__(_class, _name, _bases, _dict):
        """
        Create the class, replacing each eligible function in the combined
        namespace of the class and its bases with its contract-checking
        wrapper. Note that _dict is modified in place.
        """

        # because the mro for the class being created is not yet available
        # we'll have to build it by hand using our own mro implementation

        mro = merged_mro(*_bases) # the lack of _class itself in mro is compensated ...

        dict_with_bases = {}
        for base in reversed(mro):
            if hasattr(base, "__dict__"):
                dict_with_bases.update(base.__dict__)
        dict_with_bases.update(_dict) # ... here by explicitly adding its methods last

        invariant = dict_with_bases.get("invariant") or None

        for name, target in dict_with_bases.items():

            if not isinstance(target, FunctionType):
                continue # classmethods, staticmethods, builtins, data
            if name == "__del__" or name == "invariant": # note that __del__ is not checked at all
                continue
            if name.startswith("pre_") or name.startswith("post_"):
                continue

            pre = dict_with_bases.get("pre_%s" % name) or None
            post = dict_with_bases.get("post_%s" % name) or None

            # the invariant cannot be expected to hold before the
            # constructor has run, hence no pre-invariant for __init__
            if name == "__init__":
                preinvariant = None
            else:
                preinvariant = invariant

            _dict[name] = ContractFactory._wrap(target, preinvariant, pre, post,
                                                invariant, _name, name)

        return super(ContractFactory, _class).__new__(_class, _name, _bases, _dict)

# The base class is created by calling the metaclass directly rather than via
# a __metaclass__ attribute in a class body, because the latter is silently
# ignored by Python 3. The __metaclass__ entry is kept as a plain attribute
# for backward compatibility.

if CONTRACT_CHECKS_ENABLED:
    ContractBase = ContractFactory("ContractBase", (object,),
        { "__module__": __name__,
          "__metaclass__": ContractFactory,
          "__doc__": "Base class for classes with invariant/pre_/post_ contract checks." })
else:
    class ContractBase(object):
        """Base class for classes with invariant/pre_/post_ contract checks (checks disabled)."""

###############################################################################

if __name__ == "__main__": # run self-tests

    print("self-testing module ipdbc.py:")

    from time import time
    from sys import exc_info # version-neutral way to get the caught exception

    class C(ContractBase):
        def __init__(self): pass
        def f(self): pass
        def __del__(self): pass

    # self test #1, no checks at all

    try:
        C().f()
    except ContractViolationError:
        raise Exception("Self test #1 in ipdbc.py failed")

    ####### INVARIANT

    class C(ContractBase):
        def __init__(self): self.x = 0
        def invariant(self): return self.x == 0
        def __del__(self): self.x = -1
        def f(self): self.x += 1
        def g(self): pass
        def h(_class): pass
        h = classmethod(h)
        def i(): pass
        i = staticmethod(i)

    c = C()

    # self test #2, empty method does nothing

    try:
        c.g()
    except ContractViolationError:
        raise Exception("Self test #2 in ipdbc.py failed")

    # self test #3, break invariant by calling to a violent function

    try:
        c.f()
    except PostInvariantViolationError:
        pass
    else:
        raise Exception("Self test #3 in ipdbc.py failed")

    # self test #4, a call to any method of an instance with a broken invariant fails

    try:
        c.g()
    except PreInvariantViolationError:
        pass
    else:
        raise Exception("Self test #4 in ipdbc.py failed")

    # self test #5, calls to classmethods or staticmethods are not affected at all
    # even if the invariant is broken

    try:
        c.h()
        c.i()
    except ContractViolationError:
        raise Exception("Self test #5 in ipdbc.py failed")

    class C(ContractBase):
        def __init__(self): self.x = -1
        def invariant(self): return self.x == 0

    # self test #6, broken constructor

    try:
        C()
    except PostInvariantViolationError:
        pass
    else:
        raise Exception("Self test #6 in ipdbc.py failed")

    ####### PRECONDITION/POSTCONDITION

    class C(ContractBase):
        def __init__(self): super(C, self).__init__()
        def post___init__(self, result): return result is None
        def pre_SomeFunction(self, _func): return _func is not None
        def SomeFunction(self, _func): return _func()
        def post_SomeFunction(self, result, _func): return result is not None

    # self test #7, precondition breaks

    try:
        C().SomeFunction(None)
    except PreConditionViolationError:
        pass
    else:
        raise Exception("Self test #7 in ipdbc.py failed")

    # self test #8, postcondition breaks

    try:
        C().SomeFunction(lambda: None)
    except PostConditionViolationError:
        pass
    else:
        raise Exception("Self test #8 in ipdbc.py failed")

    # self test #9, both conditions satisfied

    try:
        C().SomeFunction(lambda: 0)
    except ContractViolationError:
        raise Exception("Self test #9 in ipdbc.py failed")

    # self test # 10, self type mangling test

    class C(object):
        def f(self): return str(type(self))

    strClassName = C().f()

    class C(ContractBase):
        def f(self):
            global strClassName
            assert str(type(self)) == strClassName, "Class reference has been mangled"
        def pre_g(self, param):
            global strClassName
            assert str(type(self)) == strClassName, "Class reference has been mangled"
            return True
        def g(self, param):
            global strClassName
            assert str(type(self)) == strClassName, "Class reference has been mangled"
        def h(self, param):
            global strClassName
            assert str(type(self)) == strClassName, "Class reference has been mangled"
        def post_h(self, result, param):
            global strClassName
            assert str(type(self)) == strClassName, "Class reference has been mangled"
            return True
        def pre_i(self, param):
            global strClassName
            assert str(type(self)) == strClassName, "Class reference has been mangled"
            return True
        def i(self, param):
            global strClassName
            assert str(type(self)) == strClassName, "Class reference has been mangled"
        def post_i(self, result, param):
            global strClassName
            assert str(type(self)) == strClassName, "Class reference has been mangled"
            return True
        def j(cls):
            global strClassName
            assert str(cls) == strClassName, "Class reference has been mangled"
        j = classmethod(j)
        def k(): pass
        k = staticmethod(k)

    C().f()
    C().g(None)
    C().h(None)
    C().i(None)
    C().j()
    C().k()

    # self test # 11, multiple inheritance

    class test(ContractBase, list):
        def invariant(self):
            return len(self) == 0
        def __init__(self):
            ContractBase.__init__(self)
            list.__init__(self)
        def break_self(self):
            self.append("foo")

    try:
        test().break_self()
    except PostInvariantViolationError:
        pass
    else:
        raise Exception("Self test #11 in ipdbc.py failed")

    # self test # 12, constructors and destructors

    class foo(ContractBase):
        x = 1
        def invariant(self):
            return self.x == 0
        def __init__(self):
            ContractBase.__init__(self)
            self.x = 0
        def pre___del__(self):
            raise Exception("Self test #12 in ipdbc.py failed")
        def __del__(self):
            self.x = 1
        def post___del__(self):
            raise Exception("Self test #12 in ipdbc.py failed")

    f = foo()
    del f

    # self test # 13, message readability

    class SuperficiallyEnhancedClass(ContractBase):
        def __init__(self):
            ContractBase.__init__(self)
            self.i = 0
        def invariant(self):
            return self.i <= 2
        def pre_ExternallyCheckedMethod(self, i):
            self.i = i
            return i != 2
        def ExternallyCheckedMethod(self, i):
            return i
        def post_ExternallyCheckedMethod(self, result, i):
            return result != 1

    s = SuperficiallyEnhancedClass()

    try:
        s.ExternallyCheckedMethod(1)
    except PostConditionViolationError:
        e = exc_info()[1]
        assert str(e).find("SuperficiallyEnhancedClass.ExternallyCheckedMethod") >= 0
    else:
        raise Exception("Self test #13 in ipdbc.py failed")

    try:
        s.ExternallyCheckedMethod(2)
    except PreConditionViolationError:
        e = exc_info()[1]
        assert str(e).find("SuperficiallyEnhancedClass.ExternallyCheckedMethod") >= 0
    else:
        raise Exception("Self test #13 in ipdbc.py failed")

    try:
        s.ExternallyCheckedMethod(3)
    except PostInvariantViolationError:
        e = exc_info()[1]
        assert str(e).find("SuperficiallyEnhancedClass.ExternallyCheckedMethod") >= 0
    else:
        raise Exception("Self test #13 in ipdbc.py failed")

    try:
        s.ExternallyCheckedMethod(0)
    except PreInvariantViolationError:
        e = exc_info()[1]
        assert str(e).find("SuperficiallyEnhancedClass.ExternallyCheckedMethod") >= 0
    else:
        raise Exception("Self test #13 in ipdbc.py failed")

    # self test # 14, nested classes

    class outer(ContractBase):
        class inner(ContractBase):
            def __init__(self):
                self.s = ""
            def invariant(self):
                self.s += "(inv)"
                return True
            def pre_foo(self):
                self.s += "(pre_foo)"
                return True
            def foo(self):
                self.s += "(foo)"
            def post_foo(self, result):
                self.s += "(post_foo)"
                return True

    assert issubclass(outer.inner, ContractBase)

    inr = outer.inner()
    inr.foo()
    assert inr.s == "(inv)(inv)(pre_foo)(foo)(post_foo)(inv)" # the very first (inv) is invoked after __init__

    # self test # 15, inherited dbc

    class b(ContractBase):
        def pre___init__(self):
            self.s = "(b.pre___init__)"
            return True
        def __init__(self):
            ContractBase.__init__(self)
            self.s += "(b.__init__)"
        def post___init__(self, result):
            self.s += "(b.post___init__)"
            return True
        def invariant(self):
            self.s += "(b.inv)"
            return True
        def foo(self):
            self.s += "(b.foo)"

    bb = b()
    bb.s += "|"
    bb.foo()
    assert bb.s == "(b.pre___init__)(b.__init__)(b.post___init__)(b.inv)|(b.inv)(b.foo)(b.inv)"

    class c(b):
        def invariant(self):
            self.s += "(c.inv)"
            return b.invariant(self) and True
        def __init__(self):
            b.__init__(self)
            self.s += "(c.__init__)"
        def post___init__(self, result):
            self.s += "(c.post___init__)"
            return True
        def pre_foo(self):
            self.s += "(c.pre_foo)"
            return True
        def foo(self):
            self.s += "(c.foo)"

    cc = c()
    cc.s += "|"
    cc.foo()
    assert cc.s == "(b.pre___init__)(b.__init__)(b.post___init__)(b.inv)(c.__init__)" \
                   "(c.post___init__)(c.inv)(b.inv)|(c.inv)(b.inv)(c.pre_foo)(c.foo)(c.inv)(b.inv)"

    class d(b):
        def invariant(self):
            self.s += "(d.inv)"
            return b.invariant(self) and True
        def __init__(self):
            b.__init__(self)
            self.s += "(d.__init__)"
        def post___init__(self, result):
            self.s += "(d.post___init__)"
            return b.post___init__(self, result)
        def post_foo(self, result):
            self.s += "(d.post_foo)"
            return True
        def foo(self):
            self.s += "(d.foo)"

    dd = d()
    dd.s += "|"
    dd.foo()
    assert dd.s == "(b.pre___init__)(b.__init__)(b.post___init__)(b.inv)(d.__init__)" \
                   "(d.post___init__)(b.post___init__)(d.inv)(b.inv)|(d.inv)(b.inv)" \
                   "(d.foo)(d.post_foo)(d.inv)(b.inv)"

    class e(c, d): # note that swapping two classes in the inheritance list
        def __init__(self):
            c.__init__(self)
            self.s += "(e.__init__)"

    assert e().s == "(b.pre___init__)(b.__init__)(b.post___init__)(b.inv)(c.__init__)" \
                    "(c.post___init__)(c.inv)(b.inv)(e.__init__)(c.post___init__)(c.inv)(b.inv)"

    class e(d, c): # makes the appropriate class'es invariant to be used
        def __init__(self):
            c.__init__(self)
            self.s += "(e.__init__)"

    assert e().s == "(b.pre___init__)(b.__init__)(b.post___init__)(b.inv)(c.__init__)" \
                    "(c.post___init__)(c.inv)(b.inv)(e.__init__)(d.post___init__)(b.post___init__)(d.inv)(b.inv)"

    class e(c, d):
        def __init__(self):
            c.__init__(self)
            self.s += "(e.__init__)"
        def post___init__(self, result):
            self.s += "(e.post___init__)"
            return True
        def foo(self):
            self.s += "(e.foo)"

    ee = e()
    ee.s += "|"
    ee.foo()
    assert ee.s == "(b.pre___init__)(b.__init__)(b.post___init__)(b.inv)(c.__init__)" \
                   "(c.post___init__)(c.inv)(b.inv)(e.__init__)(e.post___init__)(c.inv)(b.inv)|" \
                   "(c.inv)(b.inv)(c.pre_foo)(e.foo)(d.post_foo)(c.inv)(b.inv)"

    # self test # 16, kwargs

    class c(ContractBase):
        def pre_foo(self, *args, **kwargs):
            return dict(zip([ "arg%d" % i for i in range(len(args)) ], args)) == kwargs
        def foo(self, *args, **kwargs):
            pass
        def post_foo(self, result, *args, **kwargs):
            return self.pre_foo(*args, **kwargs)

    c().foo()
    c().foo(0, arg0 = 0)
    c().foo("zoom", [], dict(), arg0 = "zoom", arg1 = list(), arg2 = {})

    try:
        c().foo(0)
    except PreConditionViolationError:
        pass
    else:
        assert False, "Precondition should have thrown on invalid parameters"

    try:
        c().foo(arg0 = 0)
    except PreConditionViolationError:
        pass
    else:
        assert False, "Precondition should have thrown on invalid parameters"

    try:
        c().foo(0, arg0 = 1)
    except PreConditionViolationError:
        pass
    else:
        assert False, "Precondition should have thrown on invalid parameters"

    print("All tests passed ok")

    # timing tests

    print("Running benchmark, please wait...")

    def benchmark(_class, N = 200, M = 1000):
        start = time()
        for i in range(1, N + 1):
            c = _class()
            for j in range(1, M + 1):
                c.f(1)
        stop = time()
        return stop - start

    class C(ContractBase):
        def __init__(self): self.x = 0
        def __del__(self): pass
        def invariant(self): return self.x == 0 or self.x == 1
        def pre_f(self, _x): return _x == 1
        def f(self, _x): self.x ^= _x
        def post_f(self, result, _x): return result is None

    withsec = benchmark(C)
    print("200000 iterations with dbc: %.3f sec" % withsec)

    class C:
        def __init__(self): self.x = 0
        def __del__(self): pass
        def f(self, _x): self.x ^= _x

    withoutsec = benchmark(C)
    print("200000 iterations without dbc: %.3f sec" % withoutsec)

    print("Slowdown factor (should be about 8): %.3f" % (withsec / withoutsec))

    print("ok")

###############################################################################
# EOF