提交 11faef91 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

splitted inplace ops from tensor.basic

上级 501dc901
...@@ -4,7 +4,7 @@ theano_path = os.path.realpath("%s/.." % sys.path[0]) ...@@ -4,7 +4,7 @@ theano_path = os.path.realpath("%s/.." % sys.path[0])
sys.path[0:0] = [theano_path] sys.path[0:0] = [theano_path]
def test_module(module_path, debugmode = False): def test_module(module_path, debugmode = False):
files = commands.getoutput("find %s -name _test_*.py" % module_path) files = commands.getoutput("find %s -name test_*.py" % module_path)
suite = None suite = None
tocut = len("/".join(module_path.split("/")[:-1])) + 1 tocut = len("/".join(module_path.split("/")[:-1])) + 1
for file in files.split("\n"): for file in files.split("\n"):
......
from basic import * from basic import *
from basic import _abs
...@@ -183,7 +183,7 @@ complex_types = complex64, complex128 ...@@ -183,7 +183,7 @@ complex_types = complex64, complex128
class _scalar_py_operators: class _scalar_py_operators:
#UNARY #UNARY
def __abs__(self): return _abs(self) def __abs__(self): return abs_(self)
def __neg__(self): return neg(self) def __neg__(self): return neg(self)
#CASTS #CASTS
...@@ -587,7 +587,7 @@ class Abs(UnaryScalarOp): ...@@ -587,7 +587,7 @@ class Abs(UnaryScalarOp):
return "%(z)s = fabs(%(x)s);" % locals() return "%(z)s = fabs(%(x)s);" % locals()
#complex, other? #complex, other?
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
_abs = Abs(same_out) abs_ = Abs(same_out)
class Sgn(UnaryScalarOp): class Sgn(UnaryScalarOp):
def impl(self, x): def impl(self, x):
......
...@@ -11,7 +11,7 @@ def inputs(): ...@@ -11,7 +11,7 @@ def inputs():
return floats('xyz') return floats('xyz')
class _test_ScalarOps(unittest.TestCase): class test_ScalarOps(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
...@@ -21,7 +21,7 @@ class _test_ScalarOps(unittest.TestCase): ...@@ -21,7 +21,7 @@ class _test_ScalarOps(unittest.TestCase):
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
class _test_composite(unittest.TestCase): class test_composite(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
...@@ -57,7 +57,7 @@ class _test_composite(unittest.TestCase): ...@@ -57,7 +57,7 @@ class _test_composite(unittest.TestCase):
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5] assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
class _test_logical(unittest.TestCase): class test_logical(unittest.TestCase):
def test_gt(self): def test_gt(self):
x, y, z = inputs() x, y, z = inputs()
fn = gof.DualLinker().accept(Env([x,y], [x > y])).make_function() fn = gof.DualLinker().accept(Env([x,y], [x > y])).make_function()
......
...@@ -147,7 +147,7 @@ class T_conversion(unittest.TestCase): ...@@ -147,7 +147,7 @@ class T_conversion(unittest.TestCase):
self.failUnless(numpy.all(val[0] == [1,0,0,0,0])) self.failUnless(numpy.all(val[0] == [1,0,0,0,0]))
class _testCase_dot(unittest.TestCase): class test_dot(unittest.TestCase):
def setUp(self): def setUp(self):
numpy.random.seed(44) numpy.random.seed(44)
......
...@@ -450,7 +450,7 @@ cols, fcols, dcols, icols, lcols = _multi(col, fcol, dcol, icol, lcol) ...@@ -450,7 +450,7 @@ cols, fcols, dcols, icols, lcols = _multi(col, fcol, dcol, icol, lcol)
class _tensor_py_operators: class _tensor_py_operators:
#UNARY #UNARY
def __abs__(self): return _abs(self) def __abs__(self): return abs_(self)
def __neg__(self): return neg(self) def __neg__(self): return neg(self)
#CASTS #CASTS
...@@ -472,9 +472,9 @@ class _tensor_py_operators: ...@@ -472,9 +472,9 @@ class _tensor_py_operators:
def __rand__(self,other): return and_(other,self) def __rand__(self,other): return and_(other,self)
def __ror__(self,other): return or_(other, self) def __ror__(self,other): return or_(other, self)
def __rxor__(self,other): return xor(other, self) def __rxor__(self,other): return xor(other, self)
def __iand__(self, other): return _and_inplace(self, other) # def __iand__(self, other): return _and_inplace(self, other)
def __ior__(self, other): return _or_inplace(self, other) # def __ior__(self, other): return _or_inplace(self, other)
def __ixor__(self, other): return _xor_inplace(self, other) # def __ixor__(self, other): return _xor_inplace(self, other)
#ARITHMETIC - NORMAL #ARITHMETIC - NORMAL
def __add__(self,other): return add(self,other) def __add__(self,other): return add(self,other)
...@@ -484,12 +484,12 @@ class _tensor_py_operators: ...@@ -484,12 +484,12 @@ class _tensor_py_operators:
def __pow__(self,other): return pow(self,other) def __pow__(self,other): return pow(self,other)
def __mod__(self,other): return mod(self,other) def __mod__(self,other): return mod(self,other)
#ARITHMETIC - INPLACE # #ARITHMETIC - INPLACE
def __iadd__(self,other): return _add_inplace(self,other) # def __iadd__(self,other): return _add_inplace(self,other)
def __isub__(self,other): return _sub_inplace(self,other) # def __isub__(self,other): return _sub_inplace(self,other)
def __imul__(self,other): return _mul_inplace(self,other) # def __imul__(self,other): return _mul_inplace(self,other)
def __idiv__(self,other): return _div_inplace(self,other) # def __idiv__(self,other): return _div_inplace(self,other)
def __ipow__(self,other): return _pow_inplace(self,other) # def __ipow__(self,other): return _pow_inplace(self,other)
#ARITHMETIC - RIGHT-OPERAND #ARITHMETIC - RIGHT-OPERAND
def __radd__(self,other): return add(other,self) def __radd__(self,other): return add(other,self)
...@@ -562,7 +562,7 @@ elemwise.TensorValue = TensorValue ...@@ -562,7 +562,7 @@ elemwise.TensorValue = TensorValue
def _elemwise(scalar_op, name, doc_prefix=''): def _elemwise(scalar_op, name, doc_prefix=''):
straight = elemwise.Elemwise(scalar_op, name = name) straight = elemwise.Elemwise(scalar_op, name = name)
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0)) inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name = '_'+name+"_inplace") inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name = name+"_inplace")
# don't add the inplace versions, they aren't supposed to be part of the user interface # don't add the inplace versions, they aren't supposed to be part of the user interface
_constructor_list.append(straight) _constructor_list.append(straight)
...@@ -599,7 +599,7 @@ def _scal_elemwise(symbol): ...@@ -599,7 +599,7 @@ def _scal_elemwise(symbol):
inplace = symbolname.endswith('_inplace') inplace = symbolname.endswith('_inplace')
if inplace: if inplace:
scalar_op = getattr(scal, symbolname[1:-len('_inplace')]) scalar_op = getattr(scal, symbolname[:-len('_inplace')])
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0)) inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=symbolname) rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=symbolname)
else: else:
...@@ -780,50 +780,26 @@ def argmax(x, axis=None): ...@@ -780,50 +780,26 @@ def argmax(x, axis=None):
def lt(a, b): def lt(a, b):
"""a < b""" """a < b"""
@_scal_elemwise
def _lt_inplace(a,b):
"""a < b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def gt(a, b): def gt(a, b):
"""a > b""" """a > b"""
@_scal_elemwise
def _gt_inplace(a,b):
"""a > b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def le(a, b): def le(a, b):
"""a <= b""" """a <= b"""
@_scal_elemwise
def _le_inplace(a,b):
"""a <= b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def ge(a, b): def ge(a, b):
"""a >= b""" """a >= b"""
@_scal_elemwise
def _ge_inplace(a,b):
"""a >= b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def eq(a, b): def eq(a, b):
"""a == b""" """a == b"""
@_scal_elemwise
def _eq_inplace(a,b):
"""a == b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def neq(a, b): def neq(a, b):
"""a != b""" """a != b"""
@_scal_elemwise
def _neq_inplace(a,b):
"""a != b (inplace on a)"""
########################## ##########################
# Bit-wise # Bit-wise
...@@ -833,148 +809,86 @@ def _neq_inplace(a,b): ...@@ -833,148 +809,86 @@ def _neq_inplace(a,b):
def and_(a,b): def and_(a,b):
"""bitwise a & b""" """bitwise a & b"""
@_scal_elemwise
def _and__inplace(a,b):
"""bitwise a & b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def or_(a,b): def or_(a,b):
"""bitwise a | b""" """bitwise a | b"""
@_scal_elemwise
def _or__inplace(a,b):
"""bitwise a | b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def xor(a,b): def xor(a,b):
"""bitwise a ^ b""" """bitwise a ^ b"""
@_scal_elemwise
def _xor_inplace(a,b):
"""bitwise a ^ b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def invert(a): def invert(a):
"""bitwise ~a""" """bitwise ~a"""
@_scal_elemwise
def _invert_inplace(a):
"""bitwise ~a (inplace on a)"""
########################## ##########################
# Math # Math
########################## ##########################
@_scal_elemwise @_scal_elemwise
def _abs(a): def abs_(a):
"""|`a`| """|`a`|
_abs has a leading underscore because abs() is a builtin. TensorResult overloads the TensorResult overloads the `TensorResult.__abs__` operator so that
`TensorResult.__abs__` operator so that this function is called when you type abs(a). this function is called when you type abs(a).
""" """
@_scal_elemwise
def __abs_inplace(a):
"""|`a`| (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def exp(a): def exp(a):
"""e^`a`""" """e^`a`"""
@_scal_elemwise
def _exp_inplace(a):
"""e^`a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def neg(a): def neg(a):
"""-a""" """-a"""
@_scal_elemwise
def _neg_inplace(a):
"""-a (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def inv(a): def inv(a):
"""1.0/a (inplace on a)""" """1.0/a (inplace on a)"""
@_scal_elemwise
def _inv_inplace(a):
"""1.0/a (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def log(a): def log(a):
"""base e logarithm of a""" """base e logarithm of a"""
@_scal_elemwise
def _log_inplace(a):
"""base e logarithm of a (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def log2(a): def log2(a):
"""base 2 logarithm of a""" """base 2 logarithm of a"""
@_scal_elemwise
def _log2_inplace(a):
"""base 2 logarithm of a (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def sgn(a): def sgn(a):
"""sign of a""" """sign of a"""
@_scal_elemwise
def _sgn_inplace(a):
"""sign of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def sqr(a): def sqr(a):
"""square of a""" """square of a"""
@_scal_elemwise
def _sqr_inplace(a):
"""square of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def sqrt(a): def sqrt(a):
"""square root of a""" """square root of a"""
@_scal_elemwise
def _sqrt_inplace(a):
"""square root of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def cos(a): def cos(a):
"""cosine of a""" """cosine of a"""
@_scal_elemwise
def _cos_inplace(a):
"""cosine of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def sin(a): def sin(a):
"""sine of a""" """sine of a"""
@_scal_elemwise
def _sin_inplace(a):
"""sine of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def tan(a): def tan(a):
"""tangent of a""" """tangent of a"""
@_scal_elemwise
def _tan_inplace(a):
"""tangent of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def cosh(a): def cosh(a):
"""hyperbolic cosine of a""" """hyperbolic cosine of a"""
@_scal_elemwise
def _cosh_inplace(a):
"""hyperbolic cosine of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def sinh(a): def sinh(a):
"""hyperbolic sine of a""" """hyperbolic sine of a"""
@_scal_elemwise
def _sinh_inplace(a):
"""hyperbolic sine of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def tanh(a): def tanh(a):
"""hyperbolic tangent of a""" """hyperbolic tangent of a"""
@_scal_elemwise
def _tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)"""
########################## ##########################
...@@ -986,12 +900,8 @@ def _tanh_inplace(a): ...@@ -986,12 +900,8 @@ def _tanh_inplace(a):
@_scal_elemwise @_scal_elemwise
def second(a, b): def second(a, b):
"""Create a matrix by filling the shape of a with b""" """Create a matrix by filling the shape of a with b"""
@_scal_elemwise
def _second_inplace(a):
"""Fill `a` with `b`"""
fill = second fill = second
_fill_inplace = _second_inplace
@constructor @constructor
def ones_like(model): def ones_like(model):
...@@ -1112,44 +1022,26 @@ repeat = Repeat() ...@@ -1112,44 +1022,26 @@ repeat = Repeat()
@_scal_elemwise @_scal_elemwise
def add(a, b): def add(a, b):
"""elementwise addition""" """elementwise addition"""
@_scal_elemwise
def _add_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def sub(a, b): def sub(a, b):
"""elementwise subtraction""" """elementwise subtraction"""
@_scal_elemwise
def _sub_inplace(a, b):
"""elementwise subtraction (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def mul(a, b): def mul(a, b):
"""elementwise multiplication""" """elementwise multiplication"""
@_scal_elemwise
def _mul_inplace(a, b):
"""elementwise multiplication (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def div(a, b): def div(a, b):
"""elementwise division""" """elementwise division"""
@_scal_elemwise
def _div_inplace(a, b):
"""elementwise division (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def mod(a, b): def mod(a, b):
"""elementwise modulo""" """elementwise modulo"""
@_scal_elemwise
def _mod_inplace(a, b):
"""elementwise modulo (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def pow(a, b): def pow(a, b):
"""elementwise power""" """elementwise power"""
@_scal_elemwise
def _pow_inplace(a, b):
"""elementwise power (inplace on `a`)"""
########################## ##########################
...@@ -1182,7 +1074,6 @@ class TransposeInplace(Op): ...@@ -1182,7 +1074,6 @@ class TransposeInplace(Op):
return "TransposeView" return "TransposeView"
_transpose_inplace = TransposeInplace() _transpose_inplace = TransposeInplace()
"""WRITEME"""
def transpose(x, **kwargs): def transpose(x, **kwargs):
"""WRITEME""" """WRITEME"""
......
from basic import _scal_elemwise, _transpose_inplace
@_scal_elemwise
def lt_inplace(a,b):
"""a < b (inplace on a)"""
@_scal_elemwise
def gt_inplace(a,b):
"""a > b (inplace on a)"""
@_scal_elemwise
def le_inplace(a,b):
"""a <= b (inplace on a)"""
@_scal_elemwise
def ge_inplace(a,b):
"""a >= b (inplace on a)"""
@_scal_elemwise
def eq_inplace(a,b):
"""a == b (inplace on a)"""
@_scal_elemwise
def neq_inplace(a,b):
"""a != b (inplace on a)"""
@_scal_elemwise
def and__inplace(a,b):
"""bitwise a & b (inplace on a)"""
@_scal_elemwise
def or__inplace(a,b):
"""bitwise a | b (inplace on a)"""
@_scal_elemwise
def xor_inplace(a,b):
"""bitwise a ^ b (inplace on a)"""
@_scal_elemwise
def invert_inplace(a):
"""bitwise ~a (inplace on a)"""
@_scal_elemwise
def abs__inplace(a):
"""|`a`| (inplace on `a`)"""
@_scal_elemwise
def exp_inplace(a):
"""e^`a` (inplace on `a`)"""
@_scal_elemwise
def neg_inplace(a):
"""-a (inplace on a)"""
@_scal_elemwise
def inv_inplace(a):
"""1.0/a (inplace on a)"""
@_scal_elemwise
def log_inplace(a):
"""base e logarithm of a (inplace on a)"""
@_scal_elemwise
def log2_inplace(a):
"""base 2 logarithm of a (inplace on a)"""
@_scal_elemwise
def sgn_inplace(a):
"""sign of `a` (inplace on `a`)"""
@_scal_elemwise
def sqr_inplace(a):
"""square of `a` (inplace on `a`)"""
@_scal_elemwise
def sqrt_inplace(a):
"""square root of `a` (inplace on `a`)"""
@_scal_elemwise
def cos_inplace(a):
"""cosine of `a` (inplace on `a`)"""
@_scal_elemwise
def sin_inplace(a):
"""sine of `a` (inplace on `a`)"""
@_scal_elemwise
def tan_inplace(a):
"""tangent of `a` (inplace on `a`)"""
@_scal_elemwise
def cosh_inplace(a):
"""hyperbolic cosine of `a` (inplace on `a`)"""
@_scal_elemwise
def sinh_inplace(a):
"""hyperbolic sine of `a` (inplace on `a`)"""
@_scal_elemwise
def tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)"""
@_scal_elemwise
def second_inplace(a):
"""Fill `a` with `b`"""
fill_inplace = second_inplace
@_scal_elemwise
def add_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
@_scal_elemwise
def sub_inplace(a, b):
"""elementwise subtraction (inplace on `a`)"""
@_scal_elemwise
def mul_inplace(a, b):
"""elementwise multiplication (inplace on `a`)"""
@_scal_elemwise
def div_inplace(a, b):
"""elementwise division (inplace on `a`)"""
@_scal_elemwise
def mod_inplace(a, b):
"""elementwise modulo (inplace on `a`)"""
@_scal_elemwise
def pow_inplace(a, b):
"""elementwise power (inplace on `a`)"""
transpose_inplace = _transpose_inplace
"""WRITEME"""
...@@ -8,6 +8,7 @@ from ..gof import opt ...@@ -8,6 +8,7 @@ from ..gof import opt
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
from .. import scalar from .. import scalar
import basic as T import basic as T
import inplace as I
import numpy as N import numpy as N
import operator import operator
import itertools import itertools
...@@ -30,7 +31,7 @@ def in2out(*local_opts): ...@@ -30,7 +31,7 @@ def in2out(*local_opts):
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c) # gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0) # Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
gemm_pattern_1 = gof.PatternSub((T._sub_inplace, gemm_pattern_1 = gof.PatternSub((I.sub_inplace,
'd', 'd',
(T.mul, (T.mul,
dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'), dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'),
......
...@@ -3,6 +3,7 @@ import operator ...@@ -3,6 +3,7 @@ import operator
from theano.tensor import * from theano.tensor import *
from theano.tensor import basic as tensor # for hidden symbols from theano.tensor import basic as tensor # for hidden symbols
from theano.tensor import inplace
import unittest import unittest
from copy import copy from copy import copy
...@@ -232,7 +233,9 @@ AddTester = make_broadcast_restet(op = add, ...@@ -232,7 +233,9 @@ AddTester = make_broadcast_restet(op = add,
**_good_broadcast_binary_normal), **_good_broadcast_binary_normal),
bad_build = _bad_build_broadcast_binary_normal, bad_build = _bad_build_broadcast_binary_normal,
bad_runtime = _bad_runtime_broadcast_binary_normal) bad_runtime = _bad_runtime_broadcast_binary_normal)
AddInplaceTester = make_broadcast_restet(op = tensor._add_inplace,
AddInplaceTester = make_broadcast_restet(op = inplace.add_inplace,
expected = lambda x, y: x + y, expected = lambda x, y: x + y,
good = _good_broadcast_binary_normal, good = _good_broadcast_binary_normal,
bad_build = _bad_build_broadcast_binary_normal, bad_build = _bad_build_broadcast_binary_normal,
...@@ -246,7 +249,7 @@ SubTester = make_broadcast_restet(op = sub, ...@@ -246,7 +249,7 @@ SubTester = make_broadcast_restet(op = sub,
bad_runtime = _bad_runtime_broadcast_binary_normal, bad_runtime = _bad_runtime_broadcast_binary_normal,
grad = _grad_broadcast_binary_normal) grad = _grad_broadcast_binary_normal)
SubInplaceTester = make_broadcast_restet(op = tensor._sub_inplace, SubInplaceTester = make_broadcast_restet(op = inplace.sub_inplace,
expected = lambda x, y: x - y, expected = lambda x, y: x - y,
good = _good_broadcast_binary_normal, good = _good_broadcast_binary_normal,
bad_build = _bad_build_broadcast_binary_normal, bad_build = _bad_build_broadcast_binary_normal,
...@@ -264,7 +267,7 @@ MulTester = make_broadcast_restet(op = mul, ...@@ -264,7 +267,7 @@ MulTester = make_broadcast_restet(op = mul,
grad = dict(three_inputs_same_shapes = (rand(2, 3), rand(2, 3), rand(2, 3)), grad = dict(three_inputs_same_shapes = (rand(2, 3), rand(2, 3), rand(2, 3)),
four_inputs_broadcast = (rand(2, 3), rand(1, 3), rand(2, 1), rand(1, 1)), four_inputs_broadcast = (rand(2, 3), rand(1, 3), rand(2, 1), rand(1, 1)),
**_grad_broadcast_binary_normal)) **_grad_broadcast_binary_normal))
MulInplaceTester = make_broadcast_restet(op = tensor._mul_inplace, MulInplaceTester = make_broadcast_restet(op = inplace.mul_inplace,
expected = lambda x, y: x * y, expected = lambda x, y: x * y,
good = _good_broadcast_binary_normal, good = _good_broadcast_binary_normal,
bad_build = _bad_build_broadcast_binary_normal, bad_build = _bad_build_broadcast_binary_normal,
...@@ -290,7 +293,7 @@ DivTester = make_broadcast_restet(op = div, ...@@ -290,7 +293,7 @@ DivTester = make_broadcast_restet(op = div,
scalar = (rand(2, 3), rand(1, 1)), scalar = (rand(2, 3), rand(1, 1)),
row = (rand(2, 3), rand(1, 3)), row = (rand(2, 3), rand(1, 3)),
column = (rand(2, 3), rand(2, 1)))) column = (rand(2, 3), rand(2, 1))))
DivInplaceTester = make_broadcast_restet(op = tensor._div_inplace, DivInplaceTester = make_broadcast_restet(op = inplace.div_inplace,
expected = lambda x, y: x / y, expected = lambda x, y: x / y,
good = dict(same_shapes = (rand(2, 3), rand(2, 3)), good = dict(same_shapes = (rand(2, 3), rand(2, 3)),
scalar = (rand(2, 3), rand(1, 1)), scalar = (rand(2, 3), rand(1, 1)),
...@@ -320,7 +323,7 @@ ModTester = make_broadcast_restet(op = mod, ...@@ -320,7 +323,7 @@ ModTester = make_broadcast_restet(op = mod,
# dtype_mixup_1 = (rand(2, 3), randint_nonzero(2, 3)), # dtype_mixup_1 = (rand(2, 3), randint_nonzero(2, 3)),
# dtype_mixup_2 = (randint_nonzero(2, 3), rand(2, 3))), # dtype_mixup_2 = (randint_nonzero(2, 3), rand(2, 3))),
) )
ModInplaceTester = make_broadcast_restet(op = tensor._mod_inplace, ModInplaceTester = make_broadcast_restet(op = inplace.mod_inplace,
expected = lambda x, y: x % y, expected = lambda x, y: x % y,
good = dict(same_shapes = (rand(2, 3), rand(2, 3)), good = dict(same_shapes = (rand(2, 3), rand(2, 3)),
scalar = (rand(2, 3), rand(1, 1)), scalar = (rand(2, 3), rand(1, 1)),
...@@ -343,7 +346,7 @@ PowTester = make_broadcast_restet(op = pow, ...@@ -343,7 +346,7 @@ PowTester = make_broadcast_restet(op = pow,
row = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (1, 3))), row = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (1, 3))),
column = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (2, 1)))) column = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (2, 1))))
) )
PowInplaceTester = make_broadcast_restet(op = tensor._pow_inplace, PowInplaceTester = make_broadcast_restet(op = inplace.pow_inplace,
expected = lambda x, y: x ** y, expected = lambda x, y: x ** y,
good = dict(same_shapes = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (2, 3))), good = dict(same_shapes = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (2, 3))),
scalar = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (1, 1))), scalar = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (1, 1))),
...@@ -364,11 +367,11 @@ _good_broadcast_unary_normal = dict(normal = (rand_ranged(-5, 5, (2, 3)),), ...@@ -364,11 +367,11 @@ _good_broadcast_unary_normal = dict(normal = (rand_ranged(-5, 5, (2, 3)),),
_grad_broadcast_unary_normal = dict(normal = (rand_ranged(-5, 5, (2, 3)),)) _grad_broadcast_unary_normal = dict(normal = (rand_ranged(-5, 5, (2, 3)),))
AbsTester = make_broadcast_restet(op = tensor._abs, AbsTester = make_broadcast_restet(op = tensor.abs_,
expected = lambda x: abs(x), expected = lambda x: abs(x),
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
AbsInplaceTester = make_broadcast_restet(op = tensor.__abs_inplace, AbsInplaceTester = make_broadcast_restet(op = inplace.abs__inplace,
expected = lambda x: numpy.abs(x), expected = lambda x: numpy.abs(x),
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
...@@ -378,7 +381,7 @@ NegTester = make_broadcast_restet(op = neg, ...@@ -378,7 +381,7 @@ NegTester = make_broadcast_restet(op = neg,
expected = lambda x: -x, expected = lambda x: -x,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
NegInplaceTester = make_broadcast_restet(op = tensor._neg_inplace, NegInplaceTester = make_broadcast_restet(op = inplace.neg_inplace,
expected = lambda x: -x, expected = lambda x: -x,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
...@@ -387,7 +390,7 @@ NegInplaceTester = make_broadcast_restet(op = tensor._neg_inplace, ...@@ -387,7 +390,7 @@ NegInplaceTester = make_broadcast_restet(op = tensor._neg_inplace,
SgnTester = make_broadcast_restet(op = sgn, SgnTester = make_broadcast_restet(op = sgn,
expected = numpy.sign, expected = numpy.sign,
good = _good_broadcast_unary_normal) good = _good_broadcast_unary_normal)
SgnInplaceTester = make_broadcast_restet(op = tensor._sgn_inplace, SgnInplaceTester = make_broadcast_restet(op = inplace.sgn_inplace,
expected = numpy.sign, expected = numpy.sign,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
inplace = True) inplace = True)
...@@ -396,7 +399,7 @@ SqrTester = make_broadcast_restet(op = sqr, ...@@ -396,7 +399,7 @@ SqrTester = make_broadcast_restet(op = sqr,
expected = numpy.square, expected = numpy.square,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
SqrInplaceTester = make_broadcast_restet(op = tensor._sqr_inplace, SqrInplaceTester = make_broadcast_restet(op = inplace.sqr_inplace,
expected = numpy.square, expected = numpy.square,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
...@@ -406,7 +409,7 @@ ExpTester = make_broadcast_restet(op = exp, ...@@ -406,7 +409,7 @@ ExpTester = make_broadcast_restet(op = exp,
expected = numpy.exp, expected = numpy.exp,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
ExpInplaceTester = make_broadcast_restet(op = tensor._exp_inplace, ExpInplaceTester = make_broadcast_restet(op = inplace.exp_inplace,
expected = numpy.exp, expected = numpy.exp,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
...@@ -422,7 +425,7 @@ LogTester = make_broadcast_restet(op = log, ...@@ -422,7 +425,7 @@ LogTester = make_broadcast_restet(op = log,
expected = numpy.log, expected = numpy.log,
good = _good_broadcast_unary_positive, good = _good_broadcast_unary_positive,
grad = _grad_broadcast_unary_positive) grad = _grad_broadcast_unary_positive)
LogInplaceTester = make_broadcast_restet(op = tensor._log_inplace, LogInplaceTester = make_broadcast_restet(op = inplace.log_inplace,
expected = numpy.log, expected = numpy.log,
good = _good_broadcast_unary_positive, good = _good_broadcast_unary_positive,
grad = _grad_broadcast_unary_positive, grad = _grad_broadcast_unary_positive,
...@@ -432,7 +435,7 @@ Log2Tester = make_broadcast_restet(op = log2, ...@@ -432,7 +435,7 @@ Log2Tester = make_broadcast_restet(op = log2,
expected = numpy.log2, expected = numpy.log2,
good = _good_broadcast_unary_positive, good = _good_broadcast_unary_positive,
grad = _grad_broadcast_unary_positive) grad = _grad_broadcast_unary_positive)
Log2InplaceTester = make_broadcast_restet(op = tensor._log2_inplace, Log2InplaceTester = make_broadcast_restet(op = inplace.log2_inplace,
expected = numpy.log2, expected = numpy.log2,
good = _good_broadcast_unary_positive, good = _good_broadcast_unary_positive,
grad = _grad_broadcast_unary_positive, grad = _grad_broadcast_unary_positive,
...@@ -442,7 +445,7 @@ SqrtTester = make_broadcast_restet(op = sqrt, ...@@ -442,7 +445,7 @@ SqrtTester = make_broadcast_restet(op = sqrt,
expected = numpy.sqrt, expected = numpy.sqrt,
good = _good_broadcast_unary_positive, good = _good_broadcast_unary_positive,
grad = _grad_broadcast_unary_positive) grad = _grad_broadcast_unary_positive)
SqrtInplaceTester = make_broadcast_restet(op = tensor._sqrt_inplace, SqrtInplaceTester = make_broadcast_restet(op = inplace.sqrt_inplace,
expected = numpy.sqrt, expected = numpy.sqrt,
good = _good_broadcast_unary_positive, good = _good_broadcast_unary_positive,
grad = _grad_broadcast_unary_positive, grad = _grad_broadcast_unary_positive,
...@@ -460,7 +463,7 @@ SinTester = make_broadcast_restet(op = sin, ...@@ -460,7 +463,7 @@ SinTester = make_broadcast_restet(op = sin,
expected = numpy.sin, expected = numpy.sin,
good = _good_broadcast_unary_wide, good = _good_broadcast_unary_wide,
grad = _grad_broadcast_unary_wide) grad = _grad_broadcast_unary_wide)
SinInplaceTester = make_broadcast_restet(op = tensor._sin_inplace, SinInplaceTester = make_broadcast_restet(op = inplace.sin_inplace,
expected = numpy.sin, expected = numpy.sin,
good = _good_broadcast_unary_wide, good = _good_broadcast_unary_wide,
grad = _grad_broadcast_unary_wide, grad = _grad_broadcast_unary_wide,
...@@ -470,7 +473,7 @@ CosTester = make_broadcast_restet(op = cos, ...@@ -470,7 +473,7 @@ CosTester = make_broadcast_restet(op = cos,
expected = numpy.cos, expected = numpy.cos,
good = _good_broadcast_unary_wide, good = _good_broadcast_unary_wide,
grad = _grad_broadcast_unary_wide) grad = _grad_broadcast_unary_wide)
CosInplaceTester = make_broadcast_restet(op = tensor._cos_inplace, CosInplaceTester = make_broadcast_restet(op = inplace.cos_inplace,
expected = numpy.cos, expected = numpy.cos,
good = _good_broadcast_unary_wide, good = _good_broadcast_unary_wide,
grad = _grad_broadcast_unary_wide, grad = _grad_broadcast_unary_wide,
...@@ -482,7 +485,7 @@ TanTester = make_broadcast_restet(op = tan, ...@@ -482,7 +485,7 @@ TanTester = make_broadcast_restet(op = tan,
shifted = (rand_ranged(3.15, 6.28, (2, 3)),)), shifted = (rand_ranged(3.15, 6.28, (2, 3)),)),
grad = dict(normal = (rand_ranged(-3.14, 3.14, (2, 3)),), grad = dict(normal = (rand_ranged(-3.14, 3.14, (2, 3)),),
shifted = (rand_ranged(3.15, 6.28, (2, 3)),))) shifted = (rand_ranged(3.15, 6.28, (2, 3)),)))
TanInplaceTester = make_broadcast_restet(op = tensor._tan_inplace, TanInplaceTester = make_broadcast_restet(op = inplace.tan_inplace,
expected = numpy.tan, expected = numpy.tan,
good = dict(normal = (rand_ranged(-3.14, 3.14, (2, 3)),), good = dict(normal = (rand_ranged(-3.14, 3.14, (2, 3)),),
shifted = (rand_ranged(3.15, 6.28, (2, 3)),)), shifted = (rand_ranged(3.15, 6.28, (2, 3)),)),
...@@ -495,7 +498,7 @@ CoshTester = make_broadcast_restet(op = cosh, ...@@ -495,7 +498,7 @@ CoshTester = make_broadcast_restet(op = cosh,
expected = numpy.cosh, expected = numpy.cosh,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
CoshInplaceTester = make_broadcast_restet(op = tensor._cosh_inplace, CoshInplaceTester = make_broadcast_restet(op = inplace.cosh_inplace,
expected = numpy.cosh, expected = numpy.cosh,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
...@@ -505,7 +508,7 @@ SinhTester = make_broadcast_restet(op = sinh, ...@@ -505,7 +508,7 @@ SinhTester = make_broadcast_restet(op = sinh,
expected = numpy.sinh, expected = numpy.sinh,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
SinhInplaceTester = make_broadcast_restet(op = tensor._sinh_inplace, SinhInplaceTester = make_broadcast_restet(op = inplace.sinh_inplace,
expected = numpy.sinh, expected = numpy.sinh,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
...@@ -515,7 +518,7 @@ TanhTester = make_broadcast_restet(op = tanh, ...@@ -515,7 +518,7 @@ TanhTester = make_broadcast_restet(op = tanh,
expected = numpy.tanh, expected = numpy.tanh,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal)
TanhInplaceTester = make_broadcast_restet(op = tensor._tanh_inplace, TanhInplaceTester = make_broadcast_restet(op = inplace.tanh_inplace,
expected = numpy.tanh, expected = numpy.tanh,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
...@@ -663,7 +666,7 @@ class T_transpose(unittest.TestCase): ...@@ -663,7 +666,7 @@ class T_transpose(unittest.TestCase):
def test0(self): def test0(self):
n = as_tensor(numpy.ones(())) n = as_tensor(numpy.ones(()))
t = transpose(n) t = transpose(n)
self.failUnless(t.owner.op == tensor._transpose_inplace) self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t) f = function([n], t)
tval = f(n.data) tval = f(n.data)
self.failUnless(tval.shape == n.data.shape) self.failUnless(tval.shape == n.data.shape)
...@@ -675,7 +678,7 @@ class T_transpose(unittest.TestCase): ...@@ -675,7 +678,7 @@ class T_transpose(unittest.TestCase):
def test1(self): def test1(self):
n = as_tensor(numpy.ones(5)) n = as_tensor(numpy.ones(5))
t = transpose(n) t = transpose(n)
self.failUnless(t.owner.op == tensor._transpose_inplace) self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t) f = function([n], t)
tval = f(n.data) tval = f(n.data)
self.failUnless(tval.shape == n.data.shape) self.failUnless(tval.shape == n.data.shape)
...@@ -686,7 +689,7 @@ class T_transpose(unittest.TestCase): ...@@ -686,7 +689,7 @@ class T_transpose(unittest.TestCase):
def test2(self): def test2(self):
n = as_tensor(numpy.ones((5,3))) n = as_tensor(numpy.ones((5,3)))
t = transpose(n) t = transpose(n)
self.failUnless(t.owner.op == tensor._transpose_inplace) self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t) f = function([n], t)
tval = f(n.data) tval = f(n.data)
self.failUnless(tval.shape == (3,5)) self.failUnless(tval.shape == (3,5))
...@@ -697,8 +700,8 @@ class T_transpose(unittest.TestCase): ...@@ -697,8 +700,8 @@ class T_transpose(unittest.TestCase):
def test3(self): def test3(self):
"""Test transpose of tensor, inplace version""" """Test transpose of tensor, inplace version"""
n = as_tensor(numpy.ones((5,3,2))) n = as_tensor(numpy.ones((5,3,2)))
t = tensor._transpose_inplace(n) t = inplace.transpose_inplace(n)
self.failUnless(t.owner.op == tensor._transpose_inplace) self.failUnless(t.owner.op == inplace.transpose_inplace)
f = function([n], t) f = function([n], t)
tval = f(n.data) tval = f(n.data)
self.failUnless(tval.shape == (2,3,5)) self.failUnless(tval.shape == (2,3,5))
...@@ -706,8 +709,8 @@ class T_transpose(unittest.TestCase): ...@@ -706,8 +709,8 @@ class T_transpose(unittest.TestCase):
tval += 55.0 tval += 55.0
self.failUnless(n.data[0,0,0] == 56.0) self.failUnless(n.data[0,0,0] == 56.0)
def test_grad(self): def test_grad(self):
verify_grad(self, tensor._transpose_inplace, [numpy.random.rand(2, 3)]) verify_grad(self, inplace.transpose_inplace, [numpy.random.rand(2, 3)])
verify_grad(self, tensor._transpose_inplace, [numpy.ones(3)]) verify_grad(self, inplace.transpose_inplace, [numpy.ones(3)])
class T_subtensor(unittest.TestCase): class T_subtensor(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -1011,7 +1014,7 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -1011,7 +1014,7 @@ class T_Join_and_Split(unittest.TestCase):
verify_grad(self, lambda a, b: join(1,a,b), [v, 2*v]) verify_grad(self, lambda a, b: join(1,a,b), [v, 2*v])
class _test_comparison(unittest.TestCase): class test_comparison(unittest.TestCase):
def test_gt(self): def test_gt(self):
x, y = fvector(), fvector() x, y = fvector(), fvector()
fn = function([x,y], x > y) fn = function([x,y], x > y)
...@@ -1060,7 +1063,7 @@ class _test_comparison(unittest.TestCase): ...@@ -1060,7 +1063,7 @@ class _test_comparison(unittest.TestCase):
v = fn(l, r) v = fn(l, r)
self.failUnless(numpy.all(v == (l != r)), (v, (l!=r))) self.failUnless(numpy.all(v == (l != r)), (v, (l!=r)))
class _test_bitwise(unittest.TestCase): class test_bitwise(unittest.TestCase):
def test_or(self): def test_or(self):
x, y = bvector(), bvector() x, y = bvector(), bvector()
fn = function([x,y], x|y) fn = function([x,y], x|y)
...@@ -1073,7 +1076,7 @@ class _test_bitwise(unittest.TestCase): ...@@ -1073,7 +1076,7 @@ class _test_bitwise(unittest.TestCase):
x, y = bvector(), bvector() x, y = bvector(), bvector()
fn = function([x,y], x^y) fn = function([x,y], x^y)
ix = x ix = x
ix ^= y ix = inplace.xor_inplace(ix, y)
gn = function([x,y], ix) gn = function([x,y], ix)
l = numpy.asarray([0,0,1,1], dtype = 'int8') l = numpy.asarray([0,0,1,1], dtype = 'int8')
r = numpy.asarray([0,1,0,1], dtype = 'int8') r = numpy.asarray([0,1,0,1], dtype = 'int8')
...@@ -1131,7 +1134,7 @@ class T_exp(unittest.TestCase): ...@@ -1131,7 +1134,7 @@ class T_exp(unittest.TestCase):
numpy.asarray([[ 1.5089518 , 1.48439076, -4.7820262 ], numpy.asarray([[ 1.5089518 , 1.48439076, -4.7820262 ],
[ 2.04832468, 0.50791564, -1.58892269]])]) [ 2.04832468, 0.50791564, -1.58892269]])])
def test_grad_1(self): def test_grad_1(self):
verify_grad(self, tensor._exp_inplace, [ verify_grad(self, inplace.exp_inplace, [
numpy.asarray([[ 1.5089518 , 1.48439076, -4.7820262 ], numpy.asarray([[ 1.5089518 , 1.48439076, -4.7820262 ],
[ 2.04832468, 0.50791564, -1.58892269]])]) [ 2.04832468, 0.50791564, -1.58892269]])])
...@@ -1299,7 +1302,7 @@ class T_exp(unittest.TestCase): ...@@ -1299,7 +1302,7 @@ class T_exp(unittest.TestCase):
# def test_col(self): # def test_col(self):
# verify_grad(self, Pow, [numpy.random.rand(3, 5), numpy.random.rand(3, 1)]) # verify_grad(self, Pow, [numpy.random.rand(3, 5), numpy.random.rand(3, 1)])
class _testCase_matinv(unittest.TestCase): class test_matinv(unittest.TestCase):
def setUp(self): def setUp(self):
numpy.random.seed(1) numpy.random.seed(1)
...@@ -1499,7 +1502,7 @@ class t_gemm(unittest.TestCase): ...@@ -1499,7 +1502,7 @@ class t_gemm(unittest.TestCase):
Z = as_tensor(self.rand(2,2)) Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2)) A = as_tensor(self.rand(2,2))
try: try:
gemm(Z, 1.0, A, tensor._transpose_inplace(Z), 1.0) gemm(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
except ValueError, e: except ValueError, e:
if e[0] == Gemm.E_z_uniq: if e[0] == Gemm.E_z_uniq:
return return
...@@ -1509,7 +1512,7 @@ class t_gemm(unittest.TestCase): ...@@ -1509,7 +1512,7 @@ class t_gemm(unittest.TestCase):
Z = as_tensor(self.rand(2,2)) Z = as_tensor(self.rand(2,2))
A = as_tensor(self.rand(2,2)) A = as_tensor(self.rand(2,2))
try: try:
gemm(Z, 1.0, tensor._transpose_inplace(Z), A, 1.0) gemm(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0)
except ValueError, e: except ValueError, e:
if e[0] == Gemm.E_z_uniq: if e[0] == Gemm.E_z_uniq:
return return
...@@ -1739,7 +1742,7 @@ class T_tensorfromscalar(unittest.TestCase): ...@@ -1739,7 +1742,7 @@ class T_tensorfromscalar(unittest.TestCase):
# self.failUnless(t.data is not tt.data) # self.failUnless(t.data is not tt.data)
class _test_grad(unittest.TestCase): class test_grad(unittest.TestCase):
class O(gof.op.Op): class O(gof.op.Op):
def __init__(self): def __init__(self):
self.gval0 = scalar('e') self.gval0 = scalar('e')
...@@ -1753,13 +1756,13 @@ class _test_grad(unittest.TestCase): ...@@ -1753,13 +1756,13 @@ class _test_grad(unittest.TestCase):
def test_1param(self): def test_1param(self):
"""grad: Test passing a single result param""" """grad: Test passing a single result param"""
o = _test_grad.O() o = test_grad.O()
a1 = o.make_node() a1 = o.make_node()
self.failUnless(o.gval0 is grad(a1.outputs[0], a1.inputs[0])) self.failUnless(o.gval0 is grad(a1.outputs[0], a1.inputs[0]))
def test_Nparam(self): def test_Nparam(self):
"""grad: Test passing multiple result params""" """grad: Test passing multiple result params"""
o = _test_grad.O() o = test_grad.O()
a1 = o.make_node() a1 = o.make_node()
g0,g1 = grad(a1.outputs[0], a1.inputs) g0,g1 = grad(a1.outputs[0], a1.inputs)
self.failUnless(o.gval0 is g0) self.failUnless(o.gval0 is g0)
...@@ -1767,7 +1770,7 @@ class _test_grad(unittest.TestCase): ...@@ -1767,7 +1770,7 @@ class _test_grad(unittest.TestCase):
def test_1None_rval(self): def test_1None_rval(self):
"""grad: Test returning a single None from grad""" """grad: Test returning a single None from grad"""
o = _test_grad.O() o = test_grad.O()
a1 = o.make_node() a1 = o.make_node()
g = grad(a1.outputs[0], a1.outputs[1]) g = grad(a1.outputs[0], a1.outputs[1])
self.failUnless(isinstance(g, TensorConstant)) self.failUnless(isinstance(g, TensorConstant))
...@@ -1780,7 +1783,7 @@ class _test_grad(unittest.TestCase): ...@@ -1780,7 +1783,7 @@ class _test_grad(unittest.TestCase):
def test_NNone_rval(self): def test_NNone_rval(self):
"""grad: Test returning some Nones from grad""" """grad: Test returning some Nones from grad"""
o = _test_grad.O() o = test_grad.O()
a1 = o.make_node() a1 = o.make_node()
g0,g1,g2 = grad(a1.outputs[0], a1.inputs + [scalar('z')]) g0,g1,g2 = grad(a1.outputs[0], a1.inputs + [scalar('z')])
self.failUnless(o.gval0 is g0) self.failUnless(o.gval0 is g0)
......
...@@ -9,7 +9,7 @@ from theano import gof ...@@ -9,7 +9,7 @@ from theano import gof
from theano.gradient import * from theano.gradient import *
from theano import gradient from theano import gradient
class _test_grad_sources_inputs(unittest.TestCase): class test_grad_sources_inputs(unittest.TestCase):
def test_retNone1(self): def test_retNone1(self):
"""Test that it is not ok to return None from op.grad()""" """Test that it is not ok to return None from op.grad()"""
class retNone(gof.op.Op): class retNone(gof.op.Op):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论