提交 11faef91 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

splitted inplace ops from tensor.basic

上级 501dc901
......@@ -4,7 +4,7 @@ theano_path = os.path.realpath("%s/.." % sys.path[0])
sys.path[0:0] = [theano_path]
def test_module(module_path, debugmode = False):
files = commands.getoutput("find %s -name _test_*.py" % module_path)
files = commands.getoutput("find %s -name test_*.py" % module_path)
suite = None
tocut = len("/".join(module_path.split("/")[:-1])) + 1
for file in files.split("\n"):
......
from basic import *
from basic import _abs
......@@ -183,7 +183,7 @@ complex_types = complex64, complex128
class _scalar_py_operators:
#UNARY
def __abs__(self): return _abs(self)
def __abs__(self): return abs_(self)
def __neg__(self): return neg(self)
#CASTS
......@@ -587,7 +587,7 @@ class Abs(UnaryScalarOp):
return "%(z)s = fabs(%(x)s);" % locals()
#complex, other?
raise NotImplementedError('type not supported', type)
_abs = Abs(same_out)
abs_ = Abs(same_out)
class Sgn(UnaryScalarOp):
def impl(self, x):
......
......@@ -11,7 +11,7 @@ def inputs():
return floats('xyz')
class _test_ScalarOps(unittest.TestCase):
class test_ScalarOps(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
......@@ -21,7 +21,7 @@ class _test_ScalarOps(unittest.TestCase):
assert fn(1.0, 2.0) == 1.5
class _test_composite(unittest.TestCase):
class test_composite(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
......@@ -57,7 +57,7 @@ class _test_composite(unittest.TestCase):
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
class _test_logical(unittest.TestCase):
class test_logical(unittest.TestCase):
def test_gt(self):
x, y, z = inputs()
fn = gof.DualLinker().accept(Env([x,y], [x > y])).make_function()
......
......@@ -147,7 +147,7 @@ class T_conversion(unittest.TestCase):
self.failUnless(numpy.all(val[0] == [1,0,0,0,0]))
class _testCase_dot(unittest.TestCase):
class test_dot(unittest.TestCase):
def setUp(self):
numpy.random.seed(44)
......
......@@ -450,7 +450,7 @@ cols, fcols, dcols, icols, lcols = _multi(col, fcol, dcol, icol, lcol)
class _tensor_py_operators:
#UNARY
def __abs__(self): return _abs(self)
def __abs__(self): return abs_(self)
def __neg__(self): return neg(self)
#CASTS
......@@ -472,9 +472,9 @@ class _tensor_py_operators:
def __rand__(self,other): return and_(other,self)
def __ror__(self,other): return or_(other, self)
def __rxor__(self,other): return xor(other, self)
def __iand__(self, other): return _and_inplace(self, other)
def __ior__(self, other): return _or_inplace(self, other)
def __ixor__(self, other): return _xor_inplace(self, other)
# def __iand__(self, other): return _and_inplace(self, other)
# def __ior__(self, other): return _or_inplace(self, other)
# def __ixor__(self, other): return _xor_inplace(self, other)
#ARITHMETIC - NORMAL
def __add__(self,other): return add(self,other)
......@@ -484,12 +484,12 @@ class _tensor_py_operators:
def __pow__(self,other): return pow(self,other)
def __mod__(self,other): return mod(self,other)
#ARITHMETIC - INPLACE
def __iadd__(self,other): return _add_inplace(self,other)
def __isub__(self,other): return _sub_inplace(self,other)
def __imul__(self,other): return _mul_inplace(self,other)
def __idiv__(self,other): return _div_inplace(self,other)
def __ipow__(self,other): return _pow_inplace(self,other)
# #ARITHMETIC - INPLACE
# def __iadd__(self,other): return _add_inplace(self,other)
# def __isub__(self,other): return _sub_inplace(self,other)
# def __imul__(self,other): return _mul_inplace(self,other)
# def __idiv__(self,other): return _div_inplace(self,other)
# def __ipow__(self,other): return _pow_inplace(self,other)
#ARITHMETIC - RIGHT-OPERAND
def __radd__(self,other): return add(other,self)
......@@ -562,7 +562,7 @@ elemwise.TensorValue = TensorValue
def _elemwise(scalar_op, name, doc_prefix=''):
straight = elemwise.Elemwise(scalar_op, name = name)
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name = '_'+name+"_inplace")
inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name = name+"_inplace")
# don't add the inplace versions, they aren't supposed to be part of the user interface
_constructor_list.append(straight)
......@@ -599,7 +599,7 @@ def _scal_elemwise(symbol):
inplace = symbolname.endswith('_inplace')
if inplace:
scalar_op = getattr(scal, symbolname[1:-len('_inplace')])
scalar_op = getattr(scal, symbolname[:-len('_inplace')])
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
rval = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name=symbolname)
else:
......@@ -780,50 +780,26 @@ def argmax(x, axis=None):
def lt(a, b):
"""a < b"""
@_scal_elemwise
def _lt_inplace(a,b):
"""a < b (inplace on a)"""
@_scal_elemwise
def gt(a, b):
"""a > b"""
@_scal_elemwise
def _gt_inplace(a,b):
"""a > b (inplace on a)"""
@_scal_elemwise
def le(a, b):
"""a <= b"""
@_scal_elemwise
def _le_inplace(a,b):
"""a <= b (inplace on a)"""
@_scal_elemwise
def ge(a, b):
"""a >= b"""
@_scal_elemwise
def _ge_inplace(a,b):
"""a >= b (inplace on a)"""
@_scal_elemwise
def eq(a, b):
"""a == b"""
@_scal_elemwise
def _eq_inplace(a,b):
"""a == b (inplace on a)"""
@_scal_elemwise
def neq(a, b):
"""a != b"""
@_scal_elemwise
def _neq_inplace(a,b):
"""a != b (inplace on a)"""
##########################
# Bit-wise
......@@ -833,148 +809,86 @@ def _neq_inplace(a,b):
def and_(a,b):
"""bitwise a & b"""
@_scal_elemwise
def _and__inplace(a,b):
"""bitwise a & b (inplace on a)"""
@_scal_elemwise
def or_(a,b):
"""bitwise a | b"""
@_scal_elemwise
def _or__inplace(a,b):
"""bitwise a | b (inplace on a)"""
@_scal_elemwise
def xor(a,b):
"""bitwise a ^ b"""
@_scal_elemwise
def _xor_inplace(a,b):
"""bitwise a ^ b (inplace on a)"""
@_scal_elemwise
def invert(a):
"""bitwise ~a"""
@_scal_elemwise
def _invert_inplace(a):
"""bitwise ~a (inplace on a)"""
##########################
# Math
##########################
@_scal_elemwise
def _abs(a):
def abs_(a):
"""|`a`|
_abs has a leading underscore because abs() is a builtin. TensorResult overloads the
`TensorResult.__abs__` operator so that this function is called when you type abs(a).
TensorResult overloads the `TensorResult.__abs__` operator so that
this function is called when you type abs(a).
"""
@_scal_elemwise
def __abs_inplace(a):
"""|`a`| (inplace on `a`)"""
@_scal_elemwise
def exp(a):
"""e^`a`"""
@_scal_elemwise
def _exp_inplace(a):
"""e^`a` (inplace on `a`)"""
@_scal_elemwise
def neg(a):
"""-a"""
@_scal_elemwise
def _neg_inplace(a):
"""-a (inplace on a)"""
@_scal_elemwise
def inv(a):
"""1.0/a (inplace on a)"""
@_scal_elemwise
def _inv_inplace(a):
"""1.0/a (inplace on a)"""
@_scal_elemwise
def log(a):
"""base e logarithm of a"""
@_scal_elemwise
def _log_inplace(a):
"""base e logarithm of a (inplace on a)"""
@_scal_elemwise
def log2(a):
"""base 2 logarithm of a"""
@_scal_elemwise
def _log2_inplace(a):
"""base 2 logarithm of a (inplace on a)"""
@_scal_elemwise
def sgn(a):
"""sign of a"""
@_scal_elemwise
def _sgn_inplace(a):
"""sign of `a` (inplace on `a`)"""
@_scal_elemwise
def sqr(a):
"""square of a"""
@_scal_elemwise
def _sqr_inplace(a):
"""square of `a` (inplace on `a`)"""
@_scal_elemwise
def sqrt(a):
"""square root of a"""
@_scal_elemwise
def _sqrt_inplace(a):
"""square root of `a` (inplace on `a`)"""
@_scal_elemwise
def cos(a):
"""cosine of a"""
@_scal_elemwise
def _cos_inplace(a):
"""cosine of `a` (inplace on `a`)"""
@_scal_elemwise
def sin(a):
"""sine of a"""
@_scal_elemwise
def _sin_inplace(a):
"""sine of `a` (inplace on `a`)"""
@_scal_elemwise
def tan(a):
"""tangent of a"""
@_scal_elemwise
def _tan_inplace(a):
"""tangent of `a` (inplace on `a`)"""
@_scal_elemwise
def cosh(a):
"""hyperbolic cosine of a"""
@_scal_elemwise
def _cosh_inplace(a):
"""hyperbolic cosine of `a` (inplace on `a`)"""
@_scal_elemwise
def sinh(a):
"""hyperbolic sine of a"""
@_scal_elemwise
def _sinh_inplace(a):
"""hyperbolic sine of `a` (inplace on `a`)"""
@_scal_elemwise
def tanh(a):
"""hyperbolic tangent of a"""
@_scal_elemwise
def _tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)"""
##########################
......@@ -986,12 +900,8 @@ def _tanh_inplace(a):
@_scal_elemwise
def second(a, b):
"""Create a matrix by filling the shape of a with b"""
@_scal_elemwise
def _second_inplace(a):
"""Fill `a` with `b`"""
fill = second
_fill_inplace = _second_inplace
@constructor
def ones_like(model):
......@@ -1112,44 +1022,26 @@ repeat = Repeat()
@_scal_elemwise
def add(a, b):
"""elementwise addition"""
@_scal_elemwise
def _add_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
@_scal_elemwise
def sub(a, b):
"""elementwise subtraction"""
@_scal_elemwise
def _sub_inplace(a, b):
"""elementwise subtraction (inplace on `a`)"""
@_scal_elemwise
def mul(a, b):
"""elementwise multiplication"""
@_scal_elemwise
def _mul_inplace(a, b):
"""elementwise multiplication (inplace on `a`)"""
@_scal_elemwise
def div(a, b):
"""elementwise division"""
@_scal_elemwise
def _div_inplace(a, b):
"""elementwise division (inplace on `a`)"""
@_scal_elemwise
def mod(a, b):
"""elementwise modulo"""
@_scal_elemwise
def _mod_inplace(a, b):
"""elementwise modulo (inplace on `a`)"""
@_scal_elemwise
def pow(a, b):
"""elementwise power"""
@_scal_elemwise
def _pow_inplace(a, b):
"""elementwise power (inplace on `a`)"""
##########################
......@@ -1182,7 +1074,6 @@ class TransposeInplace(Op):
return "TransposeView"
_transpose_inplace = TransposeInplace()
"""WRITEME"""
def transpose(x, **kwargs):
"""WRITEME"""
......
from basic import _scal_elemwise, _transpose_inplace
@_scal_elemwise
def lt_inplace(a,b):
"""a < b (inplace on a)"""
@_scal_elemwise
def gt_inplace(a,b):
"""a > b (inplace on a)"""
@_scal_elemwise
def le_inplace(a,b):
"""a <= b (inplace on a)"""
@_scal_elemwise
def ge_inplace(a,b):
"""a >= b (inplace on a)"""
@_scal_elemwise
def eq_inplace(a,b):
"""a == b (inplace on a)"""
@_scal_elemwise
def neq_inplace(a,b):
"""a != b (inplace on a)"""
@_scal_elemwise
def and__inplace(a,b):
"""bitwise a & b (inplace on a)"""
@_scal_elemwise
def or__inplace(a,b):
"""bitwise a | b (inplace on a)"""
@_scal_elemwise
def xor_inplace(a,b):
"""bitwise a ^ b (inplace on a)"""
@_scal_elemwise
def invert_inplace(a):
"""bitwise ~a (inplace on a)"""
@_scal_elemwise
def abs__inplace(a):
"""|`a`| (inplace on `a`)"""
@_scal_elemwise
def exp_inplace(a):
"""e^`a` (inplace on `a`)"""
@_scal_elemwise
def neg_inplace(a):
"""-a (inplace on a)"""
@_scal_elemwise
def inv_inplace(a):
"""1.0/a (inplace on a)"""
@_scal_elemwise
def log_inplace(a):
"""base e logarithm of a (inplace on a)"""
@_scal_elemwise
def log2_inplace(a):
"""base 2 logarithm of a (inplace on a)"""
@_scal_elemwise
def sgn_inplace(a):
"""sign of `a` (inplace on `a`)"""
@_scal_elemwise
def sqr_inplace(a):
"""square of `a` (inplace on `a`)"""
@_scal_elemwise
def sqrt_inplace(a):
"""square root of `a` (inplace on `a`)"""
@_scal_elemwise
def cos_inplace(a):
"""cosine of `a` (inplace on `a`)"""
@_scal_elemwise
def sin_inplace(a):
"""sine of `a` (inplace on `a`)"""
@_scal_elemwise
def tan_inplace(a):
"""tangent of `a` (inplace on `a`)"""
@_scal_elemwise
def cosh_inplace(a):
"""hyperbolic cosine of `a` (inplace on `a`)"""
@_scal_elemwise
def sinh_inplace(a):
"""hyperbolic sine of `a` (inplace on `a`)"""
@_scal_elemwise
def tanh_inplace(a):
"""hyperbolic tangent of `a` (inplace on `a`)"""
@_scal_elemwise
def second_inplace(a):
"""Fill `a` with `b`"""
fill_inplace = second_inplace
@_scal_elemwise
def add_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
@_scal_elemwise
def sub_inplace(a, b):
"""elementwise subtraction (inplace on `a`)"""
@_scal_elemwise
def mul_inplace(a, b):
"""elementwise multiplication (inplace on `a`)"""
@_scal_elemwise
def div_inplace(a, b):
"""elementwise division (inplace on `a`)"""
@_scal_elemwise
def mod_inplace(a, b):
"""elementwise modulo (inplace on `a`)"""
@_scal_elemwise
def pow_inplace(a, b):
"""elementwise power (inplace on `a`)"""
transpose_inplace = _transpose_inplace
"""WRITEME"""
......@@ -8,6 +8,7 @@ from ..gof import opt
from elemwise import Elemwise, DimShuffle
from .. import scalar
import basic as T
import inplace as I
import numpy as N
import operator
import itertools
......@@ -30,7 +31,7 @@ def in2out(*local_opts):
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
gemm_pattern_1 = gof.PatternSub((T._sub_inplace,
gemm_pattern_1 = gof.PatternSub((I.sub_inplace,
'd',
(T.mul,
dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'),
......
......@@ -9,7 +9,7 @@ from theano import gof
from theano.gradient import *
from theano import gradient
class _test_grad_sources_inputs(unittest.TestCase):
class test_grad_sources_inputs(unittest.TestCase):
def test_retNone1(self):
"""Test that it is not ok to return None from op.grad()"""
class retNone(gof.op.Op):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论