提交 ab5a218b authored 作者: James Bergstra's avatar James Bergstra

tensor opt - added local_subtensor_unary canonicalization

上级 e91f68d0
......@@ -293,6 +293,18 @@ register_canonicalize(local_fill_lift, 'fill_lift')
# Subtensor opts #
##################
@register_canonicalize
@gof.local_optimizer([])
def local_subtensor_unary(node):
    """
    unary(x)[idx] -> unary(x[idx])

    Push a Subtensor through a single-input Elemwise op, so the
    elementwise computation is applied to the sliced (smaller) input.
    Returns a one-element replacement list, or None when the pattern
    does not match.
    """
    if not isinstance(node.op, T.Subtensor):
        return
    inner = node.inputs[0]
    if inner.owner is None:
        return
    if not isinstance(inner.owner.op, T.Elemwise):
        return
    if len(inner.owner.inputs) != 1:
        return
    # re-apply the same subtensor indices directly to the elemwise input
    sliced = node.op(inner.owner.inputs[0], *node.inputs[1:])
    return [inner.owner.op(sliced)]
@gof.local_optimizer([None, None])
def local_subtensor_make_vector(node):
......
......@@ -6,7 +6,8 @@ import numpy
import theano
from theano import gof
from theano.tensor.opt import *
from theano import tensor
from theano import tensor #do not use, there is an import * below that hides it
from theano import tensor as TT #ugly but works for now...
from theano.tensor import TensorType, inplace
from theano.gof import Env
from theano.tensor.elemwise import DimShuffle
......@@ -79,6 +80,7 @@ def test_add_canonizer_problem0():
f = function([label], r)
from theano.tensor import *
# Why is there TWO 'import *' in this file???
class test_greedy_distribute(unittest.TestCase):
def test_main(self):
......@@ -940,6 +942,42 @@ def test_log1p():
f = function([z], T.log(1+(z)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.log1p]
class test_local_subtensor_unary():
    """Tests for the local_subtensor_unary canonicalization:
    unary(x)[idx] -> unary(x[idx]).
    """

    def _mode(self):
        # FAST_COMPILE does not apply this optimization; upgrade to
        # FAST_RUN so the canonicalization actually runs.
        mode = theano.config.mode
        if mode == 'FAST_COMPILE':
            mode = 'FAST_RUN'
        return mode

    def test0(self):
        # basic test that the Op works: the Subtensor should be moved
        # below the exp() in the optimized graph.
        x = TT.matrix()
        f = function([x], TT.exp(x)[0], mode=self._mode())
        prog = f.maker.env.toposort()
        assert isinstance(prog[0].op, TT.Subtensor)  # first subtensor
        assert prog[1].op == TT.exp
        f([[0, 1], [2, 3]])  # let debugmode test something

    def test1(self):
        # basic test that the optimization doesn't work with broadcasting
        # ... It *could* be extended to,
        # ... but right now it doesn't, so it shouldn't try.
        x = TT.matrix()
        y = TT.vector()
        f = function([x, y], TT.exp(x + y)[0], mode=self._mode())
        prog = f.maker.env.toposort()
        # the optimization works through exp() but not add()
        assert isinstance(prog[0].op, TT.DimShuffle)
        assert prog[1].op == TT.add
        assert isinstance(prog[2].op, TT.Subtensor)  # first subtensor
        assert prog[3].op == inplace.exp_inplace
        f([[0, 1], [2, 3]], [4, 5])  # let debugmode test something
if __name__ == '__main__':
    # unittest.main()
    # NOTE(review): runs only one memory-leak check from the test_fusion
    # class (defined elsewhere in this file) rather than the full suite.
    # The 'tes_' spelling presumably keeps it out of automatic test
    # discovery -- confirm before renaming.
    test_fusion().tes_memory_leak()
......
Markdown 格式
0%
您向此讨论添加了 0 人。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论