提交 bbd9e02f authored 作者: nouiz's avatar nouiz

Merge pull request #121 from gdesjardins/reshape_subtensor_opt

Added Subtensor(Rebroadcast(x)) => Rebroadcast(Subtensor(x)) optimization
......@@ -1311,13 +1311,16 @@ def local_subtensor_lift(node):
"""
unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
Handles the following unary ops:
elemwise(x,...)[idx] -> elemwise(x[idx],...)
when x,... are broadcasted scalar or not broadcasted at all
rebroadcast(x)[idx] => rebroadcast(x[idx])
"""
if isinstance(node.op, T.Subtensor):
u = node.inputs[0]
if not u.owner or len(u.clients) > 1:
return False
if isinstance(u.owner.op, T.Elemwise) and len(u.owner.inputs)==1:
idx = node.inputs[1:]
x_idx = node.op(u.owner.inputs[0], *idx)
......@@ -1346,6 +1349,35 @@ def local_subtensor_lift(node):
new_inputs.append(i.dimshuffle(['x']*node.outputs[0].ndim))
return [u.owner.op(*new_inputs)]
if isinstance(u.owner.op, T.Rebroadcast):
# make sure that Subtensor and Rebroadcast only have 1 input/output
assert len(node.inputs) == 1
assert len(u.owner.inputs) == 1
# Subtensor might reduce dim., adapt broadcast pattern accordingly
new_axis = []
# loop through indices being subtensor-ed
# i indexes broadcastable pattern before subtensor
# j indexes broadcastable pattern after subtensor
j = 0
for (i,x) in enumerate(node.op.idx_list):
# if it's not a slice, it will reduce the dimension, so it should
# not appear in the broadcastable dimensions
if isinstance(x, slice):
new_axis += [(j, u.broadcastable[i])]
j += 1
# now keep the broadcastable pattern of all
# items not appearing in subtensor list
for i in xrange(len(node.op.idx_list), len(u.broadcastable)):
new_axis += [(j,u.broadcastable[i])]
j += 1
subt_x = T.Subtensor(node.op.idx_list)(u.owner.inputs[0])
rbcast_subt_x = T.Rebroadcast(*new_axis)(subt_x)
return [rbcast_subt_x]
def merge_two_slices(slice1, len1, slice2, len2):
'''
......
......@@ -1370,6 +1370,57 @@ class test_local_subtensor_lift(unittest.TestCase):
assert len(prog)==2
f([1,2,3], 4) # let debugmode test something
def test7(self):
    """Test that Subtensor(Rebroadcast(x)) gets optimized into
    Rebroadcast(Subtensor(x)).

    Covers the basic case plus corner cases where the subtensor drops
    dimensions touched by the rebroadcast, where the subtensor idx_list
    is shorter than the resulting broadcast pattern, and where it is
    shorter than rebroadcast.axis.
    """
    # basic case
    x = tensor.matrix('x')
    xval = numpy.random.rand(1, 10).astype(config.floatX)
    assert x.broadcastable == (False, False)
    newx = tensor.Rebroadcast((0, True), (1, False))(x)
    assert newx.broadcastable == (True, False)

    f1 = function([x], newx[:2, :5], mode=mode_opt)
    prog = f1.maker.env.toposort()
    # after the optimization, Subtensor must run before Rebroadcast
    assert isinstance(prog[0].op, tensor.Subtensor)
    assert isinstance(prog[1].op, tensor.Rebroadcast)
    assert (f1(xval) == xval[:2, :5]).all()

    # corner case 1: rebroadcast changes dims which are dropped through
    # the subtensor
    y = tensor.tensor4('y')
    yval = numpy.random.rand(1, 10, 1, 3).astype(config.floatX)
    assert y.broadcastable == (False, False, False, False)
    newy = tensor.Rebroadcast((0, True), (2, True))(y)
    assert newy.broadcastable == (True, False, True, False)

    f2 = function([y], newy[:, 3, 0, :], mode=mode_opt)
    prog = f2.maker.env.toposort()
    assert isinstance(prog[0].op, tensor.Subtensor)
    assert isinstance(prog[1].op, tensor.Rebroadcast)
    assert (f2(yval) == yval[:, 3, 0, :]).all()

    # corner case 2: subtensor idx_list is shorter than the resulting
    # broadcast pattern
    f3 = function([y], newy[:, 3, 0], mode=mode_opt)
    prog = f3.maker.env.toposort()
    assert isinstance(prog[0].op, tensor.Subtensor)
    assert isinstance(prog[1].op, tensor.Rebroadcast)
    assert (f3(yval) == yval[:, 3, 0]).all()

    # corner case 3: subtensor idx_list is shorter than rebroadcast.axis
    z = tensor.tensor4('z')
    zval = numpy.random.rand(4, 10, 3, 1).astype(config.floatX)
    assert z.broadcastable == (False, False, False, False)
    newz = tensor.Rebroadcast((3, True))(z)
    assert newz.broadcastable == (False, False, False, True)

    f4 = function([z], newz[:, 3, 0], mode=mode_opt)
    prog = f4.maker.env.toposort()
    assert isinstance(prog[0].op, tensor.Subtensor)
    assert isinstance(prog[1].op, tensor.Rebroadcast)
    assert (f4(zval) == zval[:, 3, 0]).all()
class test_local_subtensor_merge(unittest.TestCase):
def setUp(self):
utt.seed_rng()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论