提交 2bd4be2d authored 作者: projects@lgcm's avatar projects@lgcm

merged

...@@ -21,9 +21,16 @@ def _numpy_checker(x, y): ...@@ -21,9 +21,16 @@ def _numpy_checker(x, y):
Used in DualLinker to compare C version with Python version. Used in DualLinker to compare C version with Python version.
""" """
x, y = x[0], y[0] x, y = x[0], y[0]
if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10): if x.dtype != y.dtype or x.shape != y.shape or numpy.any(numpy.abs(x - y) > 1e-10):
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y}) raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
def safe_make_node(op, *inputs):
"""Emulate the behaviour of make_node when op is a function instead of an Op instance."""
node = op(*inputs)
if isinstance(node, list):
return node[0].owner
else:
return node.owner
def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = {}): def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = {}):
if grad is True: if grad is True:
...@@ -46,7 +53,8 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_ ...@@ -46,7 +53,8 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs] inputrs = [value(input) for input in inputs]
try: try:
node = self.op.make_node(*inputrs) #node = self.op.make_node(*inputrs)
node = safe_make_node(self.op, *inputrs)
except: except:
type, exc_value, traceback = sys.exc_info() type, exc_value, traceback = sys.exc_info()
err_msg = "Test %s::%s: Error occurred while making a node with inputs %s" \ err_msg = "Test %s::%s: Error occurred while making a node with inputs %s" \
...@@ -80,7 +88,8 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_ ...@@ -80,7 +88,8 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_
if not isinstance(expecteds, (list, tuple)): if not isinstance(expecteds, (list, tuple)):
expecteds = (expecteds, ) expecteds = (expecteds, )
for i, (result, expected) in enumerate(zip(results, expecteds)): for i, (result, expected) in enumerate(zip(results, expecteds)):
if result.dtype != expected.dtype or result.shape != expected.shape or numpy.any(abs(result - expected) > 1e-10): if result.dtype != expected.dtype or result.shape != expected.shape or \
numpy.any(numpy.abs(result - expected) > 1e-10):
self.fail("Test %s::%s: Output %s gave the wrong value. With inputs %s, expected %s, got %s." self.fail("Test %s::%s: Output %s gave the wrong value. With inputs %s, expected %s, got %s."
% (self.op, testname, i, inputs, expected, result)) % (self.op, testname, i, inputs, expected, result))
...@@ -94,7 +103,7 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_ ...@@ -94,7 +103,7 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs] inputrs = [value(input) for input in inputs]
try: try:
node = self.op.make_node(*inputrs) node = safe_make_node(self.op,*inputrs)
except: except:
return return
self.fail("Test %s::%s: %s was successfully instantiated on the following bad inputs: %s" self.fail("Test %s::%s: %s was successfully instantiated on the following bad inputs: %s"
...@@ -105,7 +114,7 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_ ...@@ -105,7 +114,7 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs] inputrs = [value(input) for input in inputs]
try: try:
node = self.op.make_node(*inputrs) node = safe_make_node(self.op,*inputrs)
except: except:
type, exc_value, traceback = sys.exc_info() type, exc_value, traceback = sys.exc_info()
err_msg = "Test %s::%s: Error occurred while trying to make a node with inputs %s" \ err_msg = "Test %s::%s: Error occurred while trying to make a node with inputs %s" \
...@@ -340,8 +349,8 @@ AbsTester = make_broadcast_tester(op = tensor._abs, ...@@ -340,8 +349,8 @@ AbsTester = make_broadcast_tester(op = tensor._abs,
expected = lambda x: abs(x), expected = lambda x: abs(x),
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
AbsInplaceTester = make_broadcast_tester(op = tensor._abs_inplace, AbsInplaceTester = make_broadcast_tester(op = tensor.__abs_inplace,
expected = lambda x: abs(x), expected = lambda x: numpy.abs(x),
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
inplace = True) inplace = True)
...@@ -519,7 +528,9 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=0.0000001, to ...@@ -519,7 +528,9 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=0.0000001, to
for test_num in xrange(n_tests): for test_num in xrange(n_tests):
# tensor_pt = [as_tensor(p,name='input %i'%i) for i,p in enumerate(pt)] # tensor_pt = [as_tensor(p,name='input %i'%i) for i,p in enumerate(pt)]
tensor_pt = [constant(p).type('input %i'%i) for i,p in enumerate(pt)] tensor_pt = [constant(p).type('input %i'%i) for i,p in enumerate(pt)]
o = op.make_node(*[tpt.copy() for tpt in tensor_pt]) #o = op.make_node(*[tpt.copy() for tpt in tensor_pt])
o = safe_make_node(op, *[tpt.copy() for tpt in tensor_pt])
if hasattr(o, 'outputs'): if hasattr(o, 'outputs'):
o_outputs = o.outputs o_outputs = o.outputs
else: else:
......
...@@ -586,7 +586,7 @@ class Abs(UnaryScalarOp): ...@@ -586,7 +586,7 @@ class Abs(UnaryScalarOp):
return "%(z)s = fabs(%(x)s);" % locals() return "%(z)s = fabs(%(x)s);" % locals()
#complex, other? #complex, other?
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
abs = Abs(same_out) _abs = Abs(same_out)
class Sgn(UnaryScalarOp): class Sgn(UnaryScalarOp):
def impl(self, x): def impl(self, x):
......
...@@ -508,6 +508,29 @@ def _elemwise(scalar_op, name, doc_prefix=''): ...@@ -508,6 +508,29 @@ def _elemwise(scalar_op, name, doc_prefix=''):
return straight, inplace return straight, inplace
def _epydoc_cheat(real_symbol_value):
"""Replace the value associated with a function symbol"""
def decorator(f):
return real_symbol_value
return decorator
def _scal_elemwise(symbol):
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
symbolname = symbol.__name__
inplace = symbolname.endswith('_inplace')
if inplace:
scalar_op = getattr(scal, symbolname[1:-len('_inplace')])
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=symbolname)
else:
scalar_op = getattr(scal, symbolname)
rval = elemwise.Elemwise(scalar_op, name=symbolname)
if getattr(symbol, '__doc__', ''):
rval.__doc__ = symbol.__doc__ + '\n' + rval.__doc__
return rval
######################### #########################
...@@ -664,47 +687,119 @@ def argmax(x, axis=None): ...@@ -664,47 +687,119 @@ def argmax(x, axis=None):
# Comparison # Comparison
########################## ##########################
lt, _lt_inplace = _elemwise(scal.lt, 'lt', def _elemwise_macro(scalar_op, *args):
"""less than (elemwise)""") straight = elemwise.Elemwise(scalar_op)
return straight(*args)
def _elemwise_macro_inplace(scalar_op, *args):
#construct an inplace version of the scalar op
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0})
return inplace(*args)
@_scal_elemwise
def lt(a, b):
"""a < b"""
return _elemwise_macro(scal.lt, a, b)
@_scal_elemwise
def _lt_inplace(a,b):
"""a < b (inplace on a)"""
return _elemwise_macro_inplace(scal.lt, a, b)
@_scal_elemwise
def gt(a, b):
"""a > b"""
@_scal_elemwise
def _gt_inplace(a,b):
"""a > b (inplace on a)"""
@_scal_elemwise
def le(a, b):
"""a <= b"""
gt, _gt_inplace = _elemwise(scal.gt, 'gt', @_scal_elemwise
"""greater than (elemwise)""") def _le_inplace(a,b):
"""a <= b (inplace on a)"""
le, _le_inplace = _elemwise(scal.le, 'le', @_scal_elemwise
"""less than, or equal to (elemwise)""") def ge(a, b):
"""a >= b"""
ge, _ge_inplace = _elemwise(scal.ge, 'ge', @_scal_elemwise
"""greater than, or equal to (elemwise)""") def _ge_inplace(a,b):
"""a >= b (inplace on a)"""
eq, _eq_inplace = _elemwise(scal.eq, 'eq', @_scal_elemwise
"""equal to (elemwise)""") def eq(a, b):
"""a == b"""
neq, _neq_inplace = _elemwise(scal.neq, 'neq', @_scal_elemwise
"""not equal to (elemwise)""") def _eq_inplace(a,b):
"""a == b (inplace on a)"""
@_scal_elemwise
def neq(a, b):
"""a != b"""
@_scal_elemwise
def _neq_inplace(a,b):
"""a != b (inplace on a)"""
########################## ##########################
# Bit-wise # Bit-wise
########################## ##########################
and_, _and_inplace = _elemwise(scal.and_, 'and_', @_scal_elemwise
"""bitwise AND (elemwise)""") def and_(a,b):
"""bitwise a & b"""
@_scal_elemwise
def _and__inplace(a,b):
"""bitwise a & b (inplace on a)"""
@_scal_elemwise
def or_(a,b):
"""bitwise a | b"""
@_scal_elemwise
def _or__inplace(a,b):
"""bitwise a | b (inplace on a)"""
or_, _or_inplace = _elemwise(scal.or_, 'or_', @_scal_elemwise
"""bitwise OR (elemwise)""") def xor(a,b):
"""bitwise a ^ b"""
xor, _xor_inplace = _elemwise(scal.xor, 'xor', @_scal_elemwise
"""bitwise XOR (elemwise)""") def _xor_inplace(a,b):
"""bitwise a ^ b (inplace on a)"""
invert, _invert_inplace = _elemwise(scal.invert, 'invert', @_scal_elemwise
"""bitwise NOT (elemwise)""") def invert(a):
"""bitwise ~a"""
@_scal_elemwise
def _invert_inplace(a):
"""bitwise ~a (inplace on a)"""
########################## ##########################
# Math # Math
########################## ##########################
_abs, _abs_inplace = _elemwise(scal.abs, 'abs', @_scal_elemwise
"""absolute value (elemwise)""") def _abs(*a):
"""|a|
_abs has a leading underscore because abs() is a builtin. TensorResult overloads the
__abs__ operator so that this function is called when you type abs(a).
"""
@_scal_elemwise
def __abs_inplace(a):
"""|a| (inplace on a)"""
exp, _exp_inplace = _elemwise(scal.exp, 'exp', exp, _exp_inplace = _elemwise(scal.exp, 'exp',
"""exponential (elemwise)""") """exponential (elemwise)""")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论