提交 11faef91 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

splitted inplace ops from tensor.basic

上级 501dc901
...@@ -4,7 +4,7 @@ theano_path = os.path.realpath("%s/.." % sys.path[0]) ...@@ -4,7 +4,7 @@ theano_path = os.path.realpath("%s/.." % sys.path[0])
sys.path[0:0] = [theano_path] sys.path[0:0] = [theano_path]
def test_module(module_path, debugmode = False): def test_module(module_path, debugmode = False):
files = commands.getoutput("find %s -name _test_*.py" % module_path) files = commands.getoutput("find %s -name test_*.py" % module_path)
suite = None suite = None
tocut = len("/".join(module_path.split("/")[:-1])) + 1 tocut = len("/".join(module_path.split("/")[:-1])) + 1
for file in files.split("\n"): for file in files.split("\n"):
......
from basic import * from basic import *
from basic import _abs
...@@ -183,7 +183,7 @@ complex_types = complex64, complex128 ...@@ -183,7 +183,7 @@ complex_types = complex64, complex128
class _scalar_py_operators: class _scalar_py_operators:
#UNARY #UNARY
def __abs__(self): return _abs(self) def __abs__(self): return abs_(self)
def __neg__(self): return neg(self) def __neg__(self): return neg(self)
#CASTS #CASTS
...@@ -587,7 +587,7 @@ class Abs(UnaryScalarOp): ...@@ -587,7 +587,7 @@ class Abs(UnaryScalarOp):
return "%(z)s = fabs(%(x)s);" % locals() return "%(z)s = fabs(%(x)s);" % locals()
#complex, other? #complex, other?
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
_abs = Abs(same_out) abs_ = Abs(same_out)
class Sgn(UnaryScalarOp): class Sgn(UnaryScalarOp):
def impl(self, x): def impl(self, x):
......
...@@ -11,7 +11,7 @@ def inputs(): ...@@ -11,7 +11,7 @@ def inputs():
return floats('xyz') return floats('xyz')
class _test_ScalarOps(unittest.TestCase): class test_ScalarOps(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
...@@ -21,7 +21,7 @@ class _test_ScalarOps(unittest.TestCase): ...@@ -21,7 +21,7 @@ class _test_ScalarOps(unittest.TestCase):
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
class _test_composite(unittest.TestCase): class test_composite(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
...@@ -57,7 +57,7 @@ class _test_composite(unittest.TestCase): ...@@ -57,7 +57,7 @@ class _test_composite(unittest.TestCase):
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5] assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
class _test_logical(unittest.TestCase): class test_logical(unittest.TestCase):
def test_gt(self): def test_gt(self):
x, y, z = inputs() x, y, z = inputs()
fn = gof.DualLinker().accept(Env([x,y], [x > y])).make_function() fn = gof.DualLinker().accept(Env([x,y], [x > y])).make_function()
......
...@@ -147,7 +147,7 @@ class T_conversion(unittest.TestCase): ...@@ -147,7 +147,7 @@ class T_conversion(unittest.TestCase):
self.failUnless(numpy.all(val[0] == [1,0,0,0,0])) self.failUnless(numpy.all(val[0] == [1,0,0,0,0]))
class _testCase_dot(unittest.TestCase): class test_dot(unittest.TestCase):
def setUp(self): def setUp(self):
numpy.random.seed(44) numpy.random.seed(44)
......
...@@ -450,7 +450,7 @@ cols, fcols, dcols, icols, lcols = _multi(col, fcol, dcol, icol, lcol) ...@@ -450,7 +450,7 @@ cols, fcols, dcols, icols, lcols = _multi(col, fcol, dcol, icol, lcol)
class _tensor_py_operators: class _tensor_py_operators:
#UNARY #UNARY
def __abs__(self): return _abs(self) def __abs__(self): return abs_(self)
def __neg__(self): return neg(self) def __neg__(self): return neg(self)
#CASTS #CASTS
...@@ -472,9 +472,9 @@ class _tensor_py_operators: ...@@ -472,9 +472,9 @@ class _tensor_py_operators:
def __rand__(self,other): return and_(other,self) def __rand__(self,other): return and_(other,self)
def __ror__(self,other): return or_(other, self) def __ror__(self,other): return or_(other, self)
def __rxor__(self,other): return xor(other, self) def __rxor__(self,other): return xor(other, self)
def __iand__(self, other): return _and_inplace(self, other) # def __iand__(self, other): return _and_inplace(self, other)
def __ior__(self, other): return _or_inplace(self, other) # def __ior__(self, other): return _or_inplace(self, other)
def __ixor__(self, other): return _xor_inplace(self, other) # def __ixor__(self, other): return _xor_inplace(self, other)
#ARITHMETIC - NORMAL #ARITHMETIC - NORMAL
def __add__(self,other): return add(self,other) def __add__(self,other): return add(self,other)
...@@ -484,12 +484,12 @@ class _tensor_py_operators: ...@@ -484,12 +484,12 @@ class _tensor_py_operators:
def __pow__(self,other): return pow(self,other) def __pow__(self,other): return pow(self,other)
def __mod__(self,other): return mod(self,other) def __mod__(self,other): return mod(self,other)
#ARITHMETIC - INPLACE # #ARITHMETIC - INPLACE
def __iadd__(self,other): return _add_inplace(self,other) # def __iadd__(self,other): return _add_inplace(self,other)
def __isub__(self,other): return _sub_inplace(self,other) # def __isub__(self,other): return _sub_inplace(self,other)
def __imul__(self,other): return _mul_inplace(self,other) # def __imul__(self,other): return _mul_inplace(self,other)
def __idiv__(self,other): return _div_inplace(self,other) # def __idiv__(self,other): return _div_inplace(self,other)
def __ipow__(self,other): return _pow_inplace(self,other) # def __ipow__(self,other): return _pow_inplace(self,other)
#ARITHMETIC - RIGHT-OPERAND #ARITHMETIC - RIGHT-OPERAND
def __radd__(self,other): return add(other,self) def __radd__(self,other): return add(other,self)
...@@ -562,7 +562,7 @@ elemwise.TensorValue = TensorValue ...@@ -562,7 +562,7 @@ elemwise.TensorValue = TensorValue
def _elemwise(scalar_op, name, doc_prefix=''): def _elemwise(scalar_op, name, doc_prefix=''):
straight = elemwise.Elemwise(scalar_op, name = name) straight = elemwise.Elemwise(scalar_op, name = name)
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0)) inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name = '_'+name+"_inplace") inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name = name+"_inplace")
# don't add the inplace versions, they aren't supposed to be part of the user interface # don't add the inplace versions, they aren't supposed to be part of the user interface
_constructor_list.append(straight) _constructor_list.append(straight)
...@@ -599,7 +599,7 @@ def _scal_elemwise(symbol): ...@@ -599,7 +599,7 @@ def _scal_elemwise(symbol):
inplace = symbolname.endswith('_inplace') inplace = symbolname.endswith('_inplace')
if inplace: if inplace:
scalar_op = getattr(scal, symbolname[1:-len('_inplace')]) scalar_op = getattr(scal, symbolname[:-len('_inplace')])
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0)) inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=symbolname) rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=symbolname)
else: else:
...@@ -780,50 +780,26 @@ def argmax(x, axis=None): ...@@ -780,50 +780,26 @@ def argmax(x, axis=None):
def lt(a, b): def lt(a, b):
"""a < b""" """a < b"""
@_scal_elemwise
def _lt_inplace(a,b):
"""a < b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def gt(a, b): def gt(a, b):
"""a > b""" """a > b"""
@_scal_elemwise
def _gt_inplace(a,b):
"""a > b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def le(a, b): def le(a, b):
"""a <= b""" """a <= b"""
@_scal_elemwise
def _le_inplace(a,b):
"""a <= b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def ge(a, b): def ge(a, b):
"""a >= b""" """a >= b"""
@_scal_elemwise
def _ge_inplace(a,b):
"""a >= b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def eq(a, b): def eq(a, b):
"""a == b""" """a == b"""
@_scal_elemwise
def _eq_inplace(a,b):
"""a == b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def neq(a, b): def neq(a, b):
"""a != b""" """a != b"""
@_scal_elemwise
def _neq_inplace(a,b):
"""a != b (inplace on a)"""
########################## ##########################
# Bit-wise # Bit-wise
...@@ -833,148 +809,86 @@ def _neq_inplace(a,b): ...@@ -833,148 +809,86 @@ def _neq_inplace(a,b):
def and_(a,b): def and_(a,b):
"""bitwise a & b""" """bitwise a & b"""
@_scal_elemwise
def _and__inplace(a,b):
"""bitwise a & b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def or_(a,b): def or_(a,b):
"""bitwise a | b""" """bitwise a | b"""
@_scal_elemwise
def _or__inplace(a,b):
"""bitwise a | b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def xor(a,b): def xor(a,b):
"""bitwise a ^ b""" """bitwise a ^ b"""
@_scal_elemwise
def _xor_inplace(a,b):
"""bitwise a ^ b (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def invert(a): def invert(a):
"""bitwise ~a""" """bitwise ~a"""
@_scal_elemwise
def _invert_inplace(a):
"""bitwise ~a (inplace on a)"""
########################## ##########################
# Math # Math
########################## ##########################
@_scal_elemwise @_scal_elemwise
def _abs(a): def abs_(a):
"""|`a`| """|`a`|
_abs has a leading underscore because abs() is a builtin. TensorResult overloads the TensorResult overloads the `TensorResult.__abs__` operator so that
`TensorResult.__abs__` operator so that this function is called when you type abs(a). this function is called when you type abs(a).
""" """
@_scal_elemwise
def __abs_inplace(a):
"""|`a`| (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def exp(a): def exp(a):
"""e^`a`""" """e^`a`"""
@_scal_elemwise
def _exp_inplace(a):
"""e^`a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def neg(a): def neg(a):
"""-a""" """-a"""
@_scal_elemwise
def _neg_inplace(a):
"""-a (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def inv(a): def inv(a):
"""1.0/a (inplace on a)""" """1.0/a (inplace on a)"""
@_scal_elemwise
def _inv_inplace(a):
"""1.0/a (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def log(a): def log(a):
"""base e logarithm of a""" """base e logarithm of a"""
@_scal_elemwise
def _log_inplace(a):
"""base e logarithm of a (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def log2(a): def log2(a):
"""base 2 logarithm of a""" """base 2 logarithm of a"""
@_scal_elemwise
def _log2_inplace(a):
"""base 2 logarithm of a (inplace on a)"""
@_scal_elemwise @_scal_elemwise
def sgn(a): def sgn(a):
"""sign of a""" """sign of a"""
@_scal_elemwise
def _sgn_inplace(a):
"""sign of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def sqr(a): def sqr(a):
"""square of a""" """square of a"""
@_scal_elemwise
def _sqr_inplace(a):
"""square of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def sqrt(a): def sqrt(a):
"""square root of a""" """square root of a"""
@_scal_elemwise
def _sqrt_inplace(a):
"""square root of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def cos(a): def cos(a):
"""cosine of a""" """cosine of a"""
@_scal_elemwise
def _cos_inplace(a):
"""cosine of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def sin(a): def sin(a):
"""sine of a""" """sine of a"""
@_scal_elemwise
def _sin_inplace(a):
"""sine of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def tan(a): def tan(a):
"""tangent of a""" """tangent of a"""
@_scal_elemwise
def _tan_inplace(a):
"""tangent of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def cosh(a): def cosh(a):
"""hyperbolic cosine of a""" """hyperbolic cosine of a"""
@_scal_elemwise
def _cosh_inplace(a):
"""hyperbolic cosine of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def sinh(a): def sinh(a):
"""hyperbolic sine of a""" """hyperbolic sine of a"""
@_scal_elemwise
def _sinh_inplace(a):
"""hyperbolic sine of `a` (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def tanh(a): def tanh(a):
"""hyperbolic tangent of a""" """hyperbolic tangent of a"""
@_scal_elemwise
def _tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)"""
########################## ##########################
...@@ -986,12 +900,8 @@ def _tanh_inplace(a): ...@@ -986,12 +900,8 @@ def _tanh_inplace(a):
@_scal_elemwise @_scal_elemwise
def second(a, b): def second(a, b):
"""Create a matrix by filling the shape of a with b""" """Create a matrix by filling the shape of a with b"""
@_scal_elemwise
def _second_inplace(a):
"""Fill `a` with `b`"""
fill = second fill = second
_fill_inplace = _second_inplace
@constructor @constructor
def ones_like(model): def ones_like(model):
...@@ -1112,44 +1022,26 @@ repeat = Repeat() ...@@ -1112,44 +1022,26 @@ repeat = Repeat()
@_scal_elemwise @_scal_elemwise
def add(a, b): def add(a, b):
"""elementwise addition""" """elementwise addition"""
@_scal_elemwise
def _add_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def sub(a, b): def sub(a, b):
"""elementwise subtraction""" """elementwise subtraction"""
@_scal_elemwise
def _sub_inplace(a, b):
"""elementwise subtraction (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def mul(a, b): def mul(a, b):
"""elementwise multiplication""" """elementwise multiplication"""
@_scal_elemwise
def _mul_inplace(a, b):
"""elementwise multiplication (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def div(a, b): def div(a, b):
"""elementwise division""" """elementwise division"""
@_scal_elemwise
def _div_inplace(a, b):
"""elementwise division (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def mod(a, b): def mod(a, b):
"""elementwise modulo""" """elementwise modulo"""
@_scal_elemwise
def _mod_inplace(a, b):
"""elementwise modulo (inplace on `a`)"""
@_scal_elemwise @_scal_elemwise
def pow(a, b): def pow(a, b):
"""elementwise power""" """elementwise power"""
@_scal_elemwise
def _pow_inplace(a, b):
"""elementwise power (inplace on `a`)"""
########################## ##########################
...@@ -1182,7 +1074,6 @@ class TransposeInplace(Op): ...@@ -1182,7 +1074,6 @@ class TransposeInplace(Op):
return "TransposeView" return "TransposeView"
_transpose_inplace = TransposeInplace() _transpose_inplace = TransposeInplace()
"""WRITEME"""
def transpose(x, **kwargs): def transpose(x, **kwargs):
"""WRITEME""" """WRITEME"""
......
from basic import _scal_elemwise, _transpose_inplace
@_scal_elemwise
def lt_inplace(a,b):
"""a < b (inplace on a)"""
@_scal_elemwise
def gt_inplace(a,b):
"""a > b (inplace on a)"""
@_scal_elemwise
def le_inplace(a,b):
"""a <= b (inplace on a)"""
@_scal_elemwise
def ge_inplace(a,b):
"""a >= b (inplace on a)"""
@_scal_elemwise
def eq_inplace(a,b):
"""a == b (inplace on a)"""
@_scal_elemwise
def neq_inplace(a,b):
"""a != b (inplace on a)"""
@_scal_elemwise
def and__inplace(a,b):
"""bitwise a & b (inplace on a)"""
@_scal_elemwise
def or__inplace(a,b):
"""bitwise a | b (inplace on a)"""
@_scal_elemwise
def xor_inplace(a,b):
"""bitwise a ^ b (inplace on a)"""
@_scal_elemwise
def invert_inplace(a):
"""bitwise ~a (inplace on a)"""
@_scal_elemwise
def abs__inplace(a):
"""|`a`| (inplace on `a`)"""
@_scal_elemwise
def exp_inplace(a):
"""e^`a` (inplace on `a`)"""
@_scal_elemwise
def neg_inplace(a):
"""-a (inplace on a)"""
@_scal_elemwise
def inv_inplace(a):
"""1.0/a (inplace on a)"""
@_scal_elemwise
def log_inplace(a):
"""base e logarithm of a (inplace on a)"""
@_scal_elemwise
def log2_inplace(a):
"""base 2 logarithm of a (inplace on a)"""
@_scal_elemwise
def sgn_inplace(a):
"""sign of `a` (inplace on `a`)"""
@_scal_elemwise
def sqr_inplace(a):
"""square of `a` (inplace on `a`)"""
@_scal_elemwise
def sqrt_inplace(a):
"""square root of `a` (inplace on `a`)"""
@_scal_elemwise
def cos_inplace(a):
"""cosine of `a` (inplace on `a`)"""
@_scal_elemwise
def sin_inplace(a):
"""sine of `a` (inplace on `a`)"""
@_scal_elemwise
def tan_inplace(a):
"""tangent of `a` (inplace on `a`)"""
@_scal_elemwise
def cosh_inplace(a):
"""hyperbolic cosine of `a` (inplace on `a`)"""
@_scal_elemwise
def sinh_inplace(a):
"""hyperbolic sine of `a` (inplace on `a`)"""
@_scal_elemwise
def tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)"""
@_scal_elemwise
def second_inplace(a):
"""Fill `a` with `b`"""
fill_inplace = second_inplace
@_scal_elemwise
def add_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
@_scal_elemwise
def sub_inplace(a, b):
"""elementwise subtraction (inplace on `a`)"""
@_scal_elemwise
def mul_inplace(a, b):
"""elementwise multiplication (inplace on `a`)"""
@_scal_elemwise
def div_inplace(a, b):
"""elementwise division (inplace on `a`)"""
@_scal_elemwise
def mod_inplace(a, b):
"""elementwise modulo (inplace on `a`)"""
@_scal_elemwise
def pow_inplace(a, b):
"""elementwise power (inplace on `a`)"""
transpose_inplace = _transpose_inplace
"""WRITEME"""
...@@ -8,6 +8,7 @@ from ..gof import opt ...@@ -8,6 +8,7 @@ from ..gof import opt
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
from .. import scalar from .. import scalar
import basic as T import basic as T
import inplace as I
import numpy as N import numpy as N
import operator import operator
import itertools import itertools
...@@ -30,7 +31,7 @@ def in2out(*local_opts): ...@@ -30,7 +31,7 @@ def in2out(*local_opts):
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c) # gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0) # Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
gemm_pattern_1 = gof.PatternSub((T._sub_inplace, gemm_pattern_1 = gof.PatternSub((I.sub_inplace,
'd', 'd',
(T.mul, (T.mul,
dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'), dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'),
......
...@@ -9,7 +9,7 @@ from theano import gof ...@@ -9,7 +9,7 @@ from theano import gof
from theano.gradient import * from theano.gradient import *
from theano import gradient from theano import gradient
class _test_grad_sources_inputs(unittest.TestCase): class test_grad_sources_inputs(unittest.TestCase):
def test_retNone1(self): def test_retNone1(self):
"""Test that it is not ok to return None from op.grad()""" """Test that it is not ok to return None from op.grad()"""
class retNone(gof.op.Op): class retNone(gof.op.Op):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论