提交 8c745911 authored 作者: projects@lgcm's avatar projects@lgcm

docstrings, hiding inplace ops in tensor

上级 d3613582
......@@ -77,7 +77,7 @@ class T_Function(unittest.TestCase):
def test_closure(self):
x, y, z = tensor.scalars('xyz')
v = tensor.value(numpy.zeros(()))
e = x + tensor.add_inplace(v, 1)
e = x + tensor._add_inplace(v, 1)
f = function([x], [e])
assert f(1.) == 2.
assert f(1.) == 3.
......@@ -109,7 +109,7 @@ class T_Function(unittest.TestCase):
def test_borrow_false_through_inplace(self):
x, y, z = tensor.scalars('xyz')
# if borrow_outputs is False, we must not reuse the temporary created for x+y
e = tensor.add_inplace(x + y, z)
e = tensor._add_inplace(x + y, z)
for linker in 'py c c|py c&py'.split():
f = function([x, y, z], [e], borrow_outputs = False, linker = linker)
res1 = f(1.0, 2.0, 3.0)
......
差异被折叠。
......@@ -3,12 +3,12 @@
import unittest
from theano import gof
from theano.tensor_opt import *
from theano import tensor
from theano.tensor import Tensor
from theano.gof import Env
from theano.elemwise import DimShuffle
import gof
from tensor_opt import *
import tensor
from tensor import Tensor
from gof import Env
from elemwise import DimShuffle
import numpy
#import scalar_opt
......@@ -43,7 +43,7 @@ def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
# def test_user_inplace(self):
# x, y, z = inputs()
# e0 = x + y
# e1 = tensor.mul_inplace(x, y)
# e1 = tensor._mul_inplace(x, y)
# g = Env([x, y], [e0, e1])
# self.failUnless(str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]")
# inplace_optimizer.optimize(g)
......@@ -52,7 +52,7 @@ def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
# def test_inplace_on_second_argument(self):
# x, y, z = inputs()
# e0 = x + y
# e1 = tensor.mul_inplace(x, z)
# e1 = tensor._mul_inplace(x, z)
# g = Env([x, y], [e0, e1])
# self.failUnless(str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, z)]")
# inplace_optimizer.optimize(g)
......@@ -98,9 +98,9 @@ class _test_dimshuffle_lift(unittest.TestCase):
from theano.tensor import *
from tensor import *
from theano.sandbox import pprint
from sandbox import pprint
class _test_greedy_distribute(unittest.TestCase):
def test_main(self):
......@@ -279,8 +279,8 @@ class _test_canonize(unittest.TestCase):
# # def test_inplace(self):
# # x, y, z = inputs()
# # #e = tensor.add_inplace(x, y + z)
# # e = x + tensor.add_inplace(y, z)
# # #e = tensor._add_inplace(x, y + z)
# # e = x + tensor._add_inplace(y, z)
# # g = Env([x, y, z], [e])
# # opt = CliqueOptimizer(through_broadcast = False,
# # scalar_optimizer = None,
......
......@@ -415,9 +415,9 @@ class _tensor_py_operators:
def __rand__(self,other): return and_(other,self)
def __ror__(self,other): return or_(other, self)
def __rxor__(self,other): return xor(other, self)
def __iand__(self, other): return and_inplace(self, other)
def __ior__(self, other): return or_inplace(self, other)
def __ixor__(self, other): return xor_inplace(self, other)
def __iand__(self, other): return _and_inplace(self, other)
def __ior__(self, other): return _or_inplace(self, other)
def __ixor__(self, other): return _xor_inplace(self, other)
#ARITHMETIC - NORMAL
def __add__(self,other): return add(self,other)
......@@ -428,11 +428,11 @@ class _tensor_py_operators:
def __mod__(self,other): return mod(self,other)
#ARITHMETIC - INPLACE
def __iadd__(self,other): return add_inplace(self,other)
def __isub__(self,other): return sub_inplace(self,other)
def __imul__(self,other): return mul_inplace(self,other)
def __idiv__(self,other): return div_inplace(self,other)
def __ipow__(self,other): return pow_inplace(self,other)
def __iadd__(self,other): return _add_inplace(self,other)
def __isub__(self,other): return _sub_inplace(self,other)
def __imul__(self,other): return _mul_inplace(self,other)
def __idiv__(self,other): return _div_inplace(self,other)
def __ipow__(self,other): return _pow_inplace(self,other)
#ARITHMETIC - RIGHT-OPERAND
def __radd__(self,other): return add(other,self)
......@@ -493,7 +493,7 @@ elemwise.TensorValue = TensorValue
def _elemwise(scalar_op, name, doc_prefix=''):
straight = elemwise.Elemwise(scalar_op, name = name)
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name = name+"_inplace")
inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name = '_'+name+"_inplace")
# don't add the inplace versions, they aren't supposed to be part of the user interface
_constructor_list.append(straight)
......@@ -664,97 +664,97 @@ def argmax(x, axis=None):
# Comparison
##########################
lt, lt_inplace = _elemwise(scal.lt, 'lt')
"""less than (elemwise)"""
lt, _lt_inplace = _elemwise(scal.lt, 'lt',
"""less than (elemwise)""")
gt, gt_inplace = _elemwise(scal.gt, 'gt')
"""greater than (elemwise)"""
gt, _gt_inplace = _elemwise(scal.gt, 'gt',
"""greater than (elemwise)""")
le, le_inplace = _elemwise(scal.le, 'le')
"""less than, or equal to (elemwise)"""
le, _le_inplace = _elemwise(scal.le, 'le',
"""less than, or equal to (elemwise)""")
ge, ge_inplace = _elemwise(scal.ge, 'ge')
"""greater than, or equal to (elemwise)"""
ge, _ge_inplace = _elemwise(scal.ge, 'ge',
"""greater than, or equal to (elemwise)""")
eq, eq_inplace = _elemwise(scal.eq, 'eq')
"""equal to (elemwise)"""
eq, _eq_inplace = _elemwise(scal.eq, 'eq',
"""equal to (elemwise)""")
neq, neq_inplace = _elemwise(scal.neq, 'neq')
"""not equal to (elemwise)"""
neq, _neq_inplace = _elemwise(scal.neq, 'neq',
"""not equal to (elemwise)""")
##########################
# Bit-wise
##########################
and_, and_inplace = _elemwise(scal.and_, 'and_')
"""bitwise AND (elemwise)"""
and_, _and_inplace = _elemwise(scal.and_, 'and_',
"""bitwise AND (elemwise)""")
or_, or_inplace = _elemwise(scal.or_, 'or_')
"""bitwise OR (elemwise)"""
or_, _or_inplace = _elemwise(scal.or_, 'or_',
"""bitwise OR (elemwise)""")
xor, xor_inplace = _elemwise(scal.xor, 'xor')
"""bitwise XOR (elemwise)"""
xor, _xor_inplace = _elemwise(scal.xor, 'xor',
"""bitwise XOR (elemwise)""")
invert, invert_inplace = _elemwise(scal.invert, 'invert')
"""bitwise NOT (elemwise)"""
invert, _invert_inplace = _elemwise(scal.invert, 'invert',
"""bitwise NOT (elemwise)""")
##########################
# Math
##########################
_abs, abs_inplace = _elemwise(scal.abs, 'abs',
_abs, _abs_inplace = _elemwise(scal.abs, 'abs',
"""absolute value (elemwise)""")
exp, exp_inplace = _elemwise(scal.exp, 'exp',
exp, _exp_inplace = _elemwise(scal.exp, 'exp',
"""exponential (elemwise)""")
neg, neg_inplace = _elemwise(scal.neg, 'neg',
neg, _neg_inplace = _elemwise(scal.neg, 'neg',
"""negative (elemwise)""")
inv, inv_inplace = _elemwise(scal.inv, 'inv',
inv, _inv_inplace = _elemwise(scal.inv, 'inv',
"""multiplicative inverse (elemwise)""")
log, log_inplace = _elemwise(scal.log, 'log',
log, _log_inplace = _elemwise(scal.log, 'log',
"""logarithm base-e (elemwise)""")
log2, log2_inplace = _elemwise(scal.log2, 'log2')
"""logarithm base-2 (elemwise)"""
log2, _log2_inplace = _elemwise(scal.log2, 'log2',
"""logarithm base-2 (elemwise)""")
sgn, sgn_inplace = _elemwise(scal.sgn, 'sgn')
"""sign (elemwise)"""
sgn, _sgn_inplace = _elemwise(scal.sgn, 'sgn',
"""sign (elemwise)""")
sqr, sqr_inplace = _elemwise(scal.sqr, 'sqr')
"""square (elemwise)"""
sqr, _sqr_inplace = _elemwise(scal.sqr, 'sqr',
"""square (elemwise)""")
sqrt, sqrt_inplace = _elemwise(scal.sqrt, 'sqrt')
"""square root (elemwise)"""
sqrt, _sqrt_inplace = _elemwise(scal.sqrt, 'sqrt',
"""square root (elemwise)""")
cos, cos_inplace = _elemwise(scal.cos, 'cos')
"""cosine (elemwise)"""
cos, _cos_inplace = _elemwise(scal.cos, 'cos',
"""cosine (elemwise)""")
sin, sin_inplace = _elemwise(scal.sin, 'sin')
"""sine (elemwise)"""
sin, _sin_inplace = _elemwise(scal.sin, 'sin',
"""sine (elemwise)""")
tan, tan_inplace = _elemwise(scal.tan, 'tan')
"""tan = sin/cos (elemwise)"""
tan, _tan_inplace = _elemwise(scal.tan, 'tan',
"""tan = sin/cos (elemwise)""")
cosh, cosh_inplace = _elemwise(scal.cosh, 'cosh')
"""hyperbolic cosine (elemwise)"""
cosh, _cosh_inplace = _elemwise(scal.cosh, 'cosh',
"""hyperbolic cosine (elemwise)""")
sinh, sinh_inplace = _elemwise(scal.sinh, 'sinh')
"""hyperbolic sine (elemwise)"""
sinh, _sinh_inplace = _elemwise(scal.sinh, 'sinh',
"""hyperbolic sine (elemwise)""")
tanh, tanh_inplace = _elemwise(scal.tanh, 'tanh')
"""hyperbolic tangent (elemwise)"""
tanh, _tanh_inplace = _elemwise(scal.tanh, 'tanh',
"""hyperbolic tangent (elemwise)""")
##########################
# Misc
##########################
fill, fill_inplace = _elemwise(scal.second, 'fill')
"""fill WRITEME (elemwise)"""
fill, _fill_inplace = _elemwise(scal.second, 'fill',
"""fill WRITEME (elemwise)""")
@constructor
def ones_like(model):
......@@ -870,22 +870,23 @@ repeat = Repeat()
# Arithmetics
##########################
add, add_inplace = _elemwise(scal.add, 'add', 'addition (element-wise)')
add, _add_inplace = _elemwise(scal.add, 'add',
"""addition (element-wise)""")
sub, sub_inplace = _elemwise(scal.sub, 'sub')
"""subtraction (elemwise)"""
sub, _sub_inplace = _elemwise(scal.sub, 'sub',
"""subtraction (elemwise)""")
mul, mul_inplace = _elemwise(scal.mul, 'mul')
"""multiplication (elemwise)"""
mul, _mul_inplace = _elemwise(scal.mul, 'mul',
"""multiplication (elemwise)""")
div, div_inplace = _elemwise(scal.div, 'div')
"""division (elemwise)"""
div, _div_inplace = _elemwise(scal.div, 'div',
"""division (elemwise)""")
mod, mod_inplace = _elemwise(scal.mod, 'mod')
"""modulo (elemwise)"""
mod, _mod_inplace = _elemwise(scal.mod, 'mod',
"""modulo (elemwise)""")
pow, pow_inplace = _elemwise(scal.pow, 'pow')
"""raise to given exponent (elemwise)"""
pow, _pow_inplace = _elemwise(scal.pow, 'pow',
"""raise to given exponent (elemwise)""")
##########################
......@@ -917,12 +918,12 @@ class TransposeInplace(Op):
def __str__(self):
return "TransposeView"
transpose_inplace = TransposeInplace()
_transpose_inplace = TransposeInplace()
"""WRITEME"""
def transpose(x, **kwargs):
"""WRITEME"""
return transpose_inplace(tensor_copy(x), **kwargs)
return _transpose_inplace(tensor_copy(x), **kwargs)
......
......@@ -20,7 +20,7 @@ def in2out(*local_opts):
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
gemm_pattern_1 = gof.PatternSub((T.sub_inplace,
gemm_pattern_1 = gof.PatternSub((T._sub_inplace,
'd',
(T.mul,
dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论