提交 c95bea50 authored 作者: Sina Honari's avatar Sina Honari 提交者: Francesco

Replace some subtensors with DimShuffle first commit (#4356)

Replace subtensors on broadcastable dimensions with DimShuffle
上级 e6724e8d
......@@ -1764,6 +1764,54 @@ def local_track_shape_i(node):
return [shape_feature.shape_of[replacement][node.op.i]]
@register_specialize
@register_canonicalize
@gof.local_optimizer([Subtensor])
def local_subtensor_remove_broadcastable_index(node):
    """
    Remove broadcastable dimensions that are indexed with 0 or -1 by
    replacing the Subtensor with an equivalent DimShuffle.

    a[:, :, :, 0] -> a.dimshuffle(0, 1, 2), when
        a.broadcastable = (False, False, False, True)
    a[0, :, -1, :] -> a.dimshuffle(1, 3), when
        a.broadcastable = (True, False, True, False)

    On a broadcastable dimension (length 1), indices 0 and -1 both select
    the only element, so dropping the dimension is always safe.
    """
    if isinstance(node.op, Subtensor):
        idx = node.op.idx_list
    else:
        return
    remove_dim = []
    # node.inputs[0] is the indexed variable; the scalar index values
    # (for non-constant entries of idx_list) start at node.inputs[1].
    node_inputs_idx = 1
    for dim, elem in enumerate(idx):
        if isinstance(elem, scalar.Scalar):
            # The idx entry is a Scalar, i.e. a Type: the actual index is
            # carried in node.inputs[node_inputs_idx].
            dim_index = node.inputs[node_inputs_idx]
            if isinstance(dim_index, theano.scalar.basic.ScalarConstant):
                dim_index = dim_index.value
            if dim_index in (0, -1) and node.inputs[0].broadcastable[dim]:
                remove_dim.append(dim)
                node_inputs_idx += 1
            else:
                # Symbolic (non-constant) index, index other than 0/-1, or
                # a non-broadcastable dimension: optimization not applicable.
                return
        elif isinstance(elem, slice):
            # Only a full slice leaves the dimension untouched; any other
            # slice may change the shape, so bail out.
            if elem != slice(None):
                return
        elif isinstance(elem, (integer_types, numpy.integer)):
            # Constant integer baked directly into idx_list.
            if elem in (0, -1) and node.inputs[0].broadcastable[dim]:
                remove_dim.append(dim)
        else:
            raise TypeError('case not expected')

    if len(remove_dim) == 0:
        # No broadcastable dimension indexed by 0/-1: nothing to do.
        return
    else:
        all_dim = range(node.inputs[0].ndim)
        remain_dim = [x for x in all_dim if x not in remove_dim]
        return [node.inputs[0].dimshuffle(tuple(remain_dim))]
@register_specialize
@register_canonicalize('fast_compile_gpu')
@gof.local_optimizer([Subtensor, AdvancedSubtensor1])
......
......@@ -47,8 +47,12 @@ from theano.tensor import vector, ivector, lvector, fvector, dvector
from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix
from theano.tensor import scalars, vectors, matrices, fmatrices, dmatrices
from theano.tensor import (
AdvancedSubtensor,
AdvancedSubtensor1,
as_tensor_variable,
IncSubtensor,
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
inplace,
Join,
join,
......@@ -1834,6 +1838,67 @@ def test_local_useless_subtensor():
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
def test_local_subtensor_remove_broadcastable_index():
    # Test the local_subtensor_remove_broadcastable_index optimization.
    #
    # Broadcastable dimensions indexed with 0 or -1 should be removed
    # (replaced by a DimShuffle); in every other case the optimization
    # must not break the computation. Besides inspecting the compiled
    # graph, we also compare the computed values against the numpy
    # equivalents, so a value-corrupting rewrite cannot slip through.
    mode = theano.compile.mode.get_default_mode()
    mode = mode.including("local_subtensor_remove_broadcastable_index")
    x = T.dmatrix('x')
    y1 = x.dimshuffle(0, 'x', 1)
    y2 = x.dimshuffle('x', 1, 0, 'x')
    y3 = x.dimshuffle('x', 1, 'x', 0, 'x')

    # Cases where the optimization should be applied.
    z1 = y1[:, 0, :]
    z2 = y1[:, -1, :]
    z3 = y2[0, :, :, -1]
    z4 = y2[0, :, :, 0]
    z5 = y2[-1, :, :, -1]
    z6 = y3[-1, :, 0, :, -1]
    z7 = y3[-1, :, -1, :, -1]
    z8 = y3[0, :, 0, :, 0]
    f = theano.function([x], [z1, z2, z3, z4, z5, z6, z7, z8], mode=mode)
    for elem in f.maker.fgraph.toposort():
        assert type(elem.op) not in [Subtensor, AdvancedSubtensor,
                                     AdvancedSubtensor1, IncSubtensor,
                                     AdvancedIncSubtensor,
                                     AdvancedIncSubtensor1]
    rng = numpy.random.RandomState(seed=utt.fetch_seed())
    xn = rng.rand(5, 5)
    # Numpy equivalents of y1, y2, y3 (dimshuffle('x', 1, 0, 'x') is a
    # transpose wrapped in two broadcastable axes).
    yn1 = xn[:, numpy.newaxis, :]
    yn2 = xn.T[numpy.newaxis, :, :, numpy.newaxis]
    yn3 = xn.T[numpy.newaxis, :, numpy.newaxis, :, numpy.newaxis]
    expected = [yn1[:, 0, :], yn1[:, -1, :], yn2[0, :, :, -1],
                yn2[0, :, :, 0], yn2[-1, :, :, -1], yn3[-1, :, 0, :, -1],
                yn3[-1, :, -1, :, -1], yn3[0, :, 0, :, 0]]
    for out, exp in zip(f(xn), expected):
        assert numpy.allclose(out, exp)

    # Cases where the optimization should not be applied: verify that
    # other subtensor usages still compile and compute correct values.
    w1 = y1[3, 0, :]
    w2 = y1[2:4, -1, :]
    w3 = y2[0, :, 4:, -1]
    w4 = y2[:, :, 0, -1]
    w5 = y2[0, 2:4, :, 0]
    w6 = y2[0, -1, :, -1]
    w7 = y2[-1, 4:, :, -1]
    w8 = y2[-1, :, :3, -1]
    w9 = y2[-1, :, -1, -1]
    w10 = y3[-1, 2, 0, :, -1]
    w11 = y3[-1, 0, -1, :, -1]
    w12 = y3[-1, :, -1, -1, -1]
    w13 = y3[0, 0, 0, :, 0]
    w14 = y3[-1, 2:4, 0, 1:5, -1]
    w15 = y3[-1, 0, -1, 0, -1]
    w16 = y3[0, 2, 0, 4, 0]
    w17 = y3[:, 0, :, 1]
    w18 = y3[0, :, :, 2]
    w19 = y3[:, 2, 0]
    w20 = y3[:, 3]
    f2 = theano.function([x], [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11,
                               w12, w13, w14, w15, w16, w17, w18, w19, w20],
                         mode=mode)
    expected2 = [yn1[3, 0, :], yn1[2:4, -1, :], yn2[0, :, 4:, -1],
                 yn2[:, :, 0, -1], yn2[0, 2:4, :, 0], yn2[0, -1, :, -1],
                 yn2[-1, 4:, :, -1], yn2[-1, :, :3, -1], yn2[-1, :, -1, -1],
                 yn3[-1, 2, 0, :, -1], yn3[-1, 0, -1, :, -1],
                 yn3[-1, :, -1, -1, -1], yn3[0, 0, 0, :, 0],
                 yn3[-1, 2:4, 0, 1:5, -1], yn3[-1, 0, -1, 0, -1],
                 yn3[0, 2, 0, 4, 0], yn3[:, 0, :, 1], yn3[0, :, :, 2],
                 yn3[:, 2, 0], yn3[:, 3]]
    for out, exp in zip(f2(xn), expected2):
        assert numpy.allclose(out, exp)
class test_local_subtensor_make_vector(unittest.TestCase):
def test_scalar_idx(self):
x, y, z = tensor.lscalars('xyz')
......
......@@ -5,6 +5,9 @@ from numpy.testing import assert_equal, assert_string_equal
import theano
import theano.tensor as tt
import theano.tests.unittest_tools as utt
from theano.tensor import (Subtensor, AdvancedSubtensor, AdvancedSubtensor1,
IncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1)
def test_numpy_method():
......@@ -30,3 +33,31 @@ def test_copy():
f = theano.function([x], y)
assert_equal(f(data), data)
assert_string_equal(y.name, 'y')
def test_None_dimShuffle_replace():
    # Test that indexing with None (newaxis) is implemented via DimShuffle.
    #
    # Whenever None is used in a subtensor expression to insert a
    # broadcastable dimension, it should be replaced by a DimShuffle. If the
    # replacement is done properly, no Subtensor op (or any of its variants)
    # remains in the compiled graph. We also check the computed values, not
    # only the graph structure.
    subtensor_ops = [Subtensor, AdvancedSubtensor, AdvancedSubtensor1,
                     IncSubtensor, AdvancedIncSubtensor,
                     AdvancedIncSubtensor1]

    x = tt.dmatrix('x')
    y = x[:, None, :]
    f = theano.function([x], y)
    for elem in f.maker.fgraph.toposort():
        assert type(elem.op) not in subtensor_ops
    data = [[1., 2., 3.], [4., 5., 6.]]
    assert_equal(f(data), [[[1., 2., 3.]], [[4., 5., 6.]]])

    x = tt.tensor3('x')
    y1 = x[:, :, None, :]
    y2 = x[None, :, :, None, :]
    y3 = x[:, :, None, :, None, None]
    f = theano.function([x], [y1, y2, y3])
    for elem in f.maker.fgraph.toposort():
        assert type(elem.op) not in subtensor_ops
    data3 = [[[1., 2.], [3., 4.]]]
    out1, out2, out3 = f(data3)
    # The inserted axes must appear exactly where None was placed...
    assert out1.shape == (1, 2, 1, 2)
    assert out2.shape == (1, 1, 2, 1, 2)
    assert out3.shape == (1, 2, 1, 2, 1, 1)
    # ...and the data itself must be untouched.
    assert_equal(out1.flatten(), [1., 2., 3., 4.])
    assert_equal(out2.flatten(), [1., 2., 3., 4.])
    assert_equal(out3.flatten(), [1., 2., 3., 4.])
......@@ -524,8 +524,11 @@ class _tensor_py_operators(object):
counter += 1
new_args.append(arg)
view = self.dimshuffle(pattern)
rval = view.__getitem__(tuple(new_args))
return rval
check_rval = [arg == slice(None, None, None) for arg in new_args]
if all(check_rval) == True:
return view
else:
return view.__getitem__(tuple(new_args))
else:
return theano.tensor.subtensor.Subtensor(args)(
self, *theano.tensor.subtensor.Subtensor.collapse(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论