提交 36c317bd authored 作者: James Bergstra's avatar James Bergstra

converting tensor ops to _scal_elemwise

上级 7ab5fb48
......@@ -21,9 +21,16 @@ def _numpy_checker(x, y):
Used in DualLinker to compare C version with Python version.
"""
x, y = x[0], y[0]
if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10):
if x.dtype != y.dtype or x.shape != y.shape or numpy.any(numpy.abs(x - y) > 1e-10):
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
def safe_make_node(op, *inputs):
"""Emulate the behaviour of make_node when op is a function instead of an Op instance."""
node = op(*inputs)
if isinstance(node, list):
return node[0].owner
else:
return node.owner
def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = {}):
if grad is True:
......@@ -46,7 +53,8 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_
inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs]
try:
node = self.op.make_node(*inputrs)
#node = self.op.make_node(*inputrs)
node = safe_make_node(self.op, *inputrs)
except:
type, exc_value, traceback = sys.exc_info()
err_msg = "Test %s::%s: Error occurred while making a node with inputs %s" \
......@@ -80,7 +88,8 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_
if not isinstance(expecteds, (list, tuple)):
expecteds = (expecteds, )
for i, (result, expected) in enumerate(zip(results, expecteds)):
if result.dtype != expected.dtype or result.shape != expected.shape or numpy.any(abs(result - expected) > 1e-10):
if result.dtype != expected.dtype or result.shape != expected.shape or \
numpy.any(numpy.abs(result - expected) > 1e-10):
self.fail("Test %s::%s: Output %s gave the wrong value. With inputs %s, expected %s, got %s."
% (self.op, testname, i, inputs, expected, result))
......@@ -94,7 +103,7 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_
inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs]
try:
node = self.op.make_node(*inputrs)
node = safe_make_node(self.op,*inputrs)
except:
return
self.fail("Test %s::%s: %s was successfully instantiated on the following bad inputs: %s"
......@@ -105,7 +114,7 @@ def make_tester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_
inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs]
try:
node = self.op.make_node(*inputrs)
node = safe_make_node(self.op,*inputrs)
except:
type, exc_value, traceback = sys.exc_info()
err_msg = "Test %s::%s: Error occurred while trying to make a node with inputs %s" \
......@@ -340,8 +349,8 @@ AbsTester = make_broadcast_tester(op = tensor._abs,
expected = lambda x: abs(x),
good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal)
AbsInplaceTester = make_broadcast_tester(op = tensor._abs_inplace,
expected = lambda x: abs(x),
AbsInplaceTester = make_broadcast_tester(op = tensor.__abs_inplace,
expected = lambda x: numpy.abs(x),
good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal,
inplace = True)
......@@ -519,7 +528,9 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=0.0000001, to
for test_num in xrange(n_tests):
# tensor_pt = [as_tensor(p,name='input %i'%i) for i,p in enumerate(pt)]
tensor_pt = [constant(p).type('input %i'%i) for i,p in enumerate(pt)]
o = op.make_node(*[tpt.copy() for tpt in tensor_pt])
#o = op.make_node(*[tpt.copy() for tpt in tensor_pt])
o = safe_make_node(op, *[tpt.copy() for tpt in tensor_pt])
if hasattr(o, 'outputs'):
o_outputs = o.outputs
else:
......
......@@ -586,7 +586,7 @@ class Abs(UnaryScalarOp):
return "%(z)s = fabs(%(x)s);" % locals()
#complex, other?
raise NotImplementedError('type not supported', type)
abs = Abs(same_out)
_abs = Abs(same_out)
class Sgn(UnaryScalarOp):
def impl(self, x):
......
......@@ -508,6 +508,29 @@ def _elemwise(scalar_op, name, doc_prefix=''):
return straight, inplace
def _epydoc_cheat(real_symbol_value):
"""Replace the value associated with a function symbol"""
def decorator(f):
return real_symbol_value
return decorator
def _scal_elemwise(symbol):
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
symbolname = symbol.__name__
inplace = symbolname.endswith('_inplace')
if inplace:
scalar_op = getattr(scal, symbolname[1:-len('_inplace')])
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=symbolname)
else:
scalar_op = getattr(scal, symbolname)
rval = elemwise.Elemwise(scalar_op, name=symbolname)
if getattr(symbol, '__doc__', ''):
rval.__doc__ = symbol.__doc__ + '\n' + rval.__doc__
return rval
#########################
......@@ -674,93 +697,109 @@ def _elemwise_macro_inplace(scalar_op, *args):
inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0})
return inplace(*args)
@_scal_elemwise
def lt(a, b):
"""a < b"""
return _elemwise_macro(scal.lt, a, b)
@_scal_elemwise
def _lt_inplace(a,b):
"""a < b (inplace on a)"""
return _elemwise_macro_inplace(scal.lt, a, b)
@_scal_elemwise
def gt(a, b):
"""a > b"""
return _elemwise_macro(scal.gt, a, b)
@_scal_elemwise
def _gt_inplace(a,b):
"""a > b (inplace on a)"""
return _elemwise_macro_inplace(scal.gt, a, b)
@_scal_elemwise
def le(a, b):
"""a <= b"""
return _elemwise_macro(scal.le, a, b)
@_scal_elemwise
def _le_inplace(a,b):
"""a <= b (inplace on a)"""
return _elemwise_macro_inplace(scal.le, a, b)
@_scal_elemwise
def ge(a, b):
"""a >= b"""
return _elemwise_macro(scal.ge, a, b)
@_scal_elemwise
def _ge_inplace(a,b):
"""a >= b (inplace on a)"""
return _elemwise_macro_inplace(scal.ge, a, b)
@_scal_elemwise
def eq(a, b):
"""a == b"""
return _elemwise_macro(scal.eq, a, b)
@_scal_elemwise
def _eq_inplace(a,b):
"""a == b (inplace on a)"""
return _elemwise_macro_inplace(scal.eq, a, b)
@_scal_elemwise
def neq(a, b):
"""a != b"""
return _elemwise_macro(scal.neq, a, b)
@_scal_elemwise
def _neq_inplace(a,b):
"""a != b (inplace on a)"""
return _elemwise_macro_inplace(scal.neq, a, b)
##########################
# Bit-wise
##########################
@_scal_elemwise
def and_(a,b):
"""bitwise a & b"""
return _elemwise_macro(scal.and_, a, b)
@_scal_elemwise
def _and__inplace(a,b):
"""bitwise a & b (inplace on a)"""
return _elemwise_macro_inplace(scal.and_, a, b)
@_scal_elemwise
def or_(a,b):
"""bitwise a | b"""
return _elemwise_macro(scal.or_, a, b)
@_scal_elemwise
def _or__inplace(a,b):
"""bitwise a | b (inplace on a)"""
return _elemwise_macro_inplace(scal.or_, a, b)
def xor_(a,b):
@_scal_elemwise
def xor(a,b):
"""bitwise a ^ b"""
return _elemwise_macro(scal.xor_, a, b)
def _xor__inplace(a,b):
@_scal_elemwise
def _xor_inplace(a,b):
"""bitwise a ^ b (inplace on a)"""
return _elemwise_macro_inplace(scal.xor_, a, b)
def invert_(a,b):
"""bitwise a ~ b"""
return _elemwise_macro(scal.invert, a, b)
def _invert__inplace(a,b):
"""bitwise a ~ b (inplace on a)"""
return _elemwise_macro_inplace(scal.invert, a, b)
@_scal_elemwise
def invert(a):
"""bitwise ~a"""
@_scal_elemwise
def _invert_inplace(a):
"""bitwise ~a (inplace on a)"""
##########################
# Math
##########################
_abs, _abs_inplace = _elemwise(scal.abs, 'abs',
"""absolute value (elemwise)""")
@_scal_elemwise
def _abs(*a):
"""|a|
_abs has a leading underscore because abs() is a builtin. TensorResult overloads the
__abs__ operator so that this function is called when you type abs(a).
"""
@_scal_elemwise
def __abs_inplace(a):
"""|a| (inplace on a)"""
exp, _exp_inplace = _elemwise(scal.exp, 'exp',
"""exponential (elemwise)""")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论