提交 8c745911 authored 作者: projects@lgcm's avatar projects@lgcm

docstrings, hiding inplace ops in tensor

上级 d3613582
...@@ -77,7 +77,7 @@ class T_Function(unittest.TestCase): ...@@ -77,7 +77,7 @@ class T_Function(unittest.TestCase):
def test_closure(self): def test_closure(self):
x, y, z = tensor.scalars('xyz') x, y, z = tensor.scalars('xyz')
v = tensor.value(numpy.zeros(())) v = tensor.value(numpy.zeros(()))
e = x + tensor.add_inplace(v, 1) e = x + tensor._add_inplace(v, 1)
f = function([x], [e]) f = function([x], [e])
assert f(1.) == 2. assert f(1.) == 2.
assert f(1.) == 3. assert f(1.) == 3.
...@@ -109,7 +109,7 @@ class T_Function(unittest.TestCase): ...@@ -109,7 +109,7 @@ class T_Function(unittest.TestCase):
def test_borrow_false_through_inplace(self): def test_borrow_false_through_inplace(self):
x, y, z = tensor.scalars('xyz') x, y, z = tensor.scalars('xyz')
# if borrow_outputs is False, we must not reuse the temporary created for x+y # if borrow_outputs is False, we must not reuse the temporary created for x+y
e = tensor.add_inplace(x + y, z) e = tensor._add_inplace(x + y, z)
for linker in 'py c c|py c&py'.split(): for linker in 'py c c|py c&py'.split():
f = function([x, y, z], [e], borrow_outputs = False, linker = linker) f = function([x, y, z], [e], borrow_outputs = False, linker = linker)
res1 = f(1.0, 2.0, 3.0) res1 = f(1.0, 2.0, 3.0)
......
差异被折叠。
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
import unittest import unittest
from theano import gof import gof
from theano.tensor_opt import * from tensor_opt import *
from theano import tensor import tensor
from theano.tensor import Tensor from tensor import Tensor
from theano.gof import Env from gof import Env
from theano.elemwise import DimShuffle from elemwise import DimShuffle
import numpy import numpy
#import scalar_opt #import scalar_opt
...@@ -43,7 +43,7 @@ def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)): ...@@ -43,7 +43,7 @@ def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
# def test_user_inplace(self): # def test_user_inplace(self):
# x, y, z = inputs() # x, y, z = inputs()
# e0 = x + y # e0 = x + y
# e1 = tensor.mul_inplace(x, y) # e1 = tensor._mul_inplace(x, y)
# g = Env([x, y], [e0, e1]) # g = Env([x, y], [e0, e1])
# self.failUnless(str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]") # self.failUnless(str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]")
# inplace_optimizer.optimize(g) # inplace_optimizer.optimize(g)
...@@ -52,7 +52,7 @@ def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)): ...@@ -52,7 +52,7 @@ def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
# def test_inplace_on_second_argument(self): # def test_inplace_on_second_argument(self):
# x, y, z = inputs() # x, y, z = inputs()
# e0 = x + y # e0 = x + y
# e1 = tensor.mul_inplace(x, z) # e1 = tensor._mul_inplace(x, z)
# g = Env([x, y], [e0, e1]) # g = Env([x, y], [e0, e1])
# self.failUnless(str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, z)]") # self.failUnless(str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, z)]")
# inplace_optimizer.optimize(g) # inplace_optimizer.optimize(g)
...@@ -98,9 +98,9 @@ class _test_dimshuffle_lift(unittest.TestCase): ...@@ -98,9 +98,9 @@ class _test_dimshuffle_lift(unittest.TestCase):
from theano.tensor import * from tensor import *
from theano.sandbox import pprint from sandbox import pprint
class _test_greedy_distribute(unittest.TestCase): class _test_greedy_distribute(unittest.TestCase):
def test_main(self): def test_main(self):
...@@ -279,8 +279,8 @@ class _test_canonize(unittest.TestCase): ...@@ -279,8 +279,8 @@ class _test_canonize(unittest.TestCase):
# # def test_inplace(self): # # def test_inplace(self):
# # x, y, z = inputs() # # x, y, z = inputs()
# # #e = tensor.add_inplace(x, y + z) # # #e = tensor._add_inplace(x, y + z)
# # e = x + tensor.add_inplace(y, z) # # e = x + tensor._add_inplace(y, z)
# # g = Env([x, y, z], [e]) # # g = Env([x, y, z], [e])
# # opt = CliqueOptimizer(through_broadcast = False, # # opt = CliqueOptimizer(through_broadcast = False,
# # scalar_optimizer = None, # # scalar_optimizer = None,
......
...@@ -415,9 +415,9 @@ class _tensor_py_operators: ...@@ -415,9 +415,9 @@ class _tensor_py_operators:
def __rand__(self,other): return and_(other,self) def __rand__(self,other): return and_(other,self)
def __ror__(self,other): return or_(other, self) def __ror__(self,other): return or_(other, self)
def __rxor__(self,other): return xor(other, self) def __rxor__(self,other): return xor(other, self)
def __iand__(self, other): return and_inplace(self, other) def __iand__(self, other): return _and_inplace(self, other)
def __ior__(self, other): return or_inplace(self, other) def __ior__(self, other): return _or_inplace(self, other)
def __ixor__(self, other): return xor_inplace(self, other) def __ixor__(self, other): return _xor_inplace(self, other)
#ARITHMETIC - NORMAL #ARITHMETIC - NORMAL
def __add__(self,other): return add(self,other) def __add__(self,other): return add(self,other)
...@@ -428,11 +428,11 @@ class _tensor_py_operators: ...@@ -428,11 +428,11 @@ class _tensor_py_operators:
def __mod__(self,other): return mod(self,other) def __mod__(self,other): return mod(self,other)
#ARITHMETIC - INPLACE #ARITHMETIC - INPLACE
def __iadd__(self,other): return add_inplace(self,other) def __iadd__(self,other): return _add_inplace(self,other)
def __isub__(self,other): return sub_inplace(self,other) def __isub__(self,other): return _sub_inplace(self,other)
def __imul__(self,other): return mul_inplace(self,other) def __imul__(self,other): return _mul_inplace(self,other)
def __idiv__(self,other): return div_inplace(self,other) def __idiv__(self,other): return _div_inplace(self,other)
def __ipow__(self,other): return pow_inplace(self,other) def __ipow__(self,other): return _pow_inplace(self,other)
#ARITHMETIC - RIGHT-OPERAND #ARITHMETIC - RIGHT-OPERAND
def __radd__(self,other): return add(other,self) def __radd__(self,other): return add(other,self)
...@@ -493,7 +493,7 @@ elemwise.TensorValue = TensorValue ...@@ -493,7 +493,7 @@ elemwise.TensorValue = TensorValue
def _elemwise(scalar_op, name, doc_prefix=''): def _elemwise(scalar_op, name, doc_prefix=''):
straight = elemwise.Elemwise(scalar_op, name = name) straight = elemwise.Elemwise(scalar_op, name = name)
inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0)) inplace_scalar_op = scalar_op.__class__(scal.transfer_type(0))
inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name = name+"_inplace") inplace = elemwise.Elemwise(inplace_scalar_op, {0: 0}, name = '_'+name+"_inplace")
# don't add the inplace versions, they aren't supposed to be part of the user interface # don't add the inplace versions, they aren't supposed to be part of the user interface
_constructor_list.append(straight) _constructor_list.append(straight)
...@@ -664,97 +664,97 @@ def argmax(x, axis=None): ...@@ -664,97 +664,97 @@ def argmax(x, axis=None):
# Comparison # Comparison
########################## ##########################
lt, lt_inplace = _elemwise(scal.lt, 'lt') lt, _lt_inplace = _elemwise(scal.lt, 'lt',
"""less than (elemwise)""" """less than (elemwise)""")
gt, gt_inplace = _elemwise(scal.gt, 'gt') gt, _gt_inplace = _elemwise(scal.gt, 'gt',
"""greater than (elemwise)""" """greater than (elemwise)""")
le, le_inplace = _elemwise(scal.le, 'le') le, _le_inplace = _elemwise(scal.le, 'le',
"""less than, or equal to (elemwise)""" """less than, or equal to (elemwise)""")
ge, ge_inplace = _elemwise(scal.ge, 'ge') ge, _ge_inplace = _elemwise(scal.ge, 'ge',
"""greater than, or equal to (elemwise)""" """greater than, or equal to (elemwise)""")
eq, eq_inplace = _elemwise(scal.eq, 'eq') eq, _eq_inplace = _elemwise(scal.eq, 'eq',
"""equal to (elemwise)""" """equal to (elemwise)""")
neq, neq_inplace = _elemwise(scal.neq, 'neq') neq, _neq_inplace = _elemwise(scal.neq, 'neq',
"""not equal to (elemwise)""" """not equal to (elemwise)""")
########################## ##########################
# Bit-wise # Bit-wise
########################## ##########################
and_, and_inplace = _elemwise(scal.and_, 'and_') and_, _and_inplace = _elemwise(scal.and_, 'and_',
"""bitwise AND (elemwise)""" """bitwise AND (elemwise)""")
or_, or_inplace = _elemwise(scal.or_, 'or_') or_, _or_inplace = _elemwise(scal.or_, 'or_',
"""bitwise OR (elemwise)""" """bitwise OR (elemwise)""")
xor, xor_inplace = _elemwise(scal.xor, 'xor') xor, _xor_inplace = _elemwise(scal.xor, 'xor',
"""bitwise XOR (elemwise)""" """bitwise XOR (elemwise)""")
invert, invert_inplace = _elemwise(scal.invert, 'invert') invert, _invert_inplace = _elemwise(scal.invert, 'invert',
"""bitwise NOT (elemwise)""" """bitwise NOT (elemwise)""")
########################## ##########################
# Math # Math
########################## ##########################
_abs, abs_inplace = _elemwise(scal.abs, 'abs', _abs, _abs_inplace = _elemwise(scal.abs, 'abs',
"""absolute value (elemwise)""") """absolute value (elemwise)""")
exp, exp_inplace = _elemwise(scal.exp, 'exp', exp, _exp_inplace = _elemwise(scal.exp, 'exp',
"""exponential (elemwise)""") """exponential (elemwise)""")
neg, neg_inplace = _elemwise(scal.neg, 'neg', neg, _neg_inplace = _elemwise(scal.neg, 'neg',
"""negative (elemwise)""") """negative (elemwise)""")
inv, inv_inplace = _elemwise(scal.inv, 'inv', inv, _inv_inplace = _elemwise(scal.inv, 'inv',
"""multiplicative inverse (elemwise)""") """multiplicative inverse (elemwise)""")
log, log_inplace = _elemwise(scal.log, 'log', log, _log_inplace = _elemwise(scal.log, 'log',
"""logarithm base-e (elemwise)""") """logarithm base-e (elemwise)""")
log2, log2_inplace = _elemwise(scal.log2, 'log2') log2, _log2_inplace = _elemwise(scal.log2, 'log2',
"""logarithm base-2 (elemwise)""" """logarithm base-2 (elemwise)""")
sgn, sgn_inplace = _elemwise(scal.sgn, 'sgn') sgn, _sgn_inplace = _elemwise(scal.sgn, 'sgn',
"""sign (elemwise)""" """sign (elemwise)""")
sqr, sqr_inplace = _elemwise(scal.sqr, 'sqr') sqr, _sqr_inplace = _elemwise(scal.sqr, 'sqr',
"""square (elemwise)""" """square (elemwise)""")
sqrt, sqrt_inplace = _elemwise(scal.sqrt, 'sqrt') sqrt, _sqrt_inplace = _elemwise(scal.sqrt, 'sqrt',
"""square root (elemwise)""" """square root (elemwise)""")
cos, cos_inplace = _elemwise(scal.cos, 'cos') cos, _cos_inplace = _elemwise(scal.cos, 'cos',
"""cosine (elemwise)""" """cosine (elemwise)""")
sin, sin_inplace = _elemwise(scal.sin, 'sin') sin, _sin_inplace = _elemwise(scal.sin, 'sin',
"""sine (elemwise)""" """sine (elemwise)""")
tan, tan_inplace = _elemwise(scal.tan, 'tan') tan, _tan_inplace = _elemwise(scal.tan, 'tan',
"""tan = sin/cos (elemwise)""" """tan = sin/cos (elemwise)""")
cosh, cosh_inplace = _elemwise(scal.cosh, 'cosh') cosh, _cosh_inplace = _elemwise(scal.cosh, 'cosh',
"""hyperbolic cosine (elemwise)""" """hyperbolic cosine (elemwise)""")
sinh, sinh_inplace = _elemwise(scal.sinh, 'sinh') sinh, _sinh_inplace = _elemwise(scal.sinh, 'sinh',
"""hyperbolic sine (elemwise)""" """hyperbolic sine (elemwise)""")
tanh, tanh_inplace = _elemwise(scal.tanh, 'tanh') tanh, _tanh_inplace = _elemwise(scal.tanh, 'tanh',
"""hyperbolic tangent (elemwise)""" """hyperbolic tangent (elemwise)""")
########################## ##########################
# Misc # Misc
########################## ##########################
fill, fill_inplace = _elemwise(scal.second, 'fill') fill, _fill_inplace = _elemwise(scal.second, 'fill',
"""fill WRITEME (elemwise)""" """fill WRITEME (elemwise)""")
@constructor @constructor
def ones_like(model): def ones_like(model):
...@@ -870,22 +870,23 @@ repeat = Repeat() ...@@ -870,22 +870,23 @@ repeat = Repeat()
# Arithmetics # Arithmetics
########################## ##########################
add, add_inplace = _elemwise(scal.add, 'add', 'addition (element-wise)') add, _add_inplace = _elemwise(scal.add, 'add',
"""addition (element-wise)""")
sub, sub_inplace = _elemwise(scal.sub, 'sub') sub, _sub_inplace = _elemwise(scal.sub, 'sub',
"""subtraction (elemwise)""" """subtraction (elemwise)""")
mul, mul_inplace = _elemwise(scal.mul, 'mul') mul, _mul_inplace = _elemwise(scal.mul, 'mul',
"""multiplication (elemwise)""" """multiplication (elemwise)""")
div, div_inplace = _elemwise(scal.div, 'div') div, _div_inplace = _elemwise(scal.div, 'div',
"""division (elemwise)""" """division (elemwise)""")
mod, mod_inplace = _elemwise(scal.mod, 'mod') mod, _mod_inplace = _elemwise(scal.mod, 'mod',
"""modulo (elemwise)""" """modulo (elemwise)""")
pow, pow_inplace = _elemwise(scal.pow, 'pow') pow, _pow_inplace = _elemwise(scal.pow, 'pow',
"""raise to given exponent (elemwise)""" """raise to given exponent (elemwise)""")
########################## ##########################
...@@ -917,12 +918,12 @@ class TransposeInplace(Op): ...@@ -917,12 +918,12 @@ class TransposeInplace(Op):
def __str__(self): def __str__(self):
return "TransposeView" return "TransposeView"
transpose_inplace = TransposeInplace() _transpose_inplace = TransposeInplace()
"""WRITEME""" """WRITEME"""
def transpose(x, **kwargs): def transpose(x, **kwargs):
"""WRITEME""" """WRITEME"""
return transpose_inplace(tensor_copy(x), **kwargs) return _transpose_inplace(tensor_copy(x), **kwargs)
......
...@@ -20,7 +20,7 @@ def in2out(*local_opts): ...@@ -20,7 +20,7 @@ def in2out(*local_opts):
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c) # gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0) # Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
gemm_pattern_1 = gof.PatternSub((T.sub_inplace, gemm_pattern_1 = gof.PatternSub((T._sub_inplace,
'd', 'd',
(T.mul, (T.mul,
dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'), dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论