Commit c95bea50 authored by Sina Honari, committed by Francesco

Replace some subtensors with DimShuffle first commit (#4356)

Replace subtensors on broadcastable dimensions with DimShuffle
Parent e6724e8d
...@@ -1764,6 +1764,54 @@ def local_track_shape_i(node): ...@@ -1764,6 +1764,54 @@ def local_track_shape_i(node):
return [shape_feature.shape_of[replacement][node.op.i]] return [shape_feature.shape_of[replacement][node.op.i]]
@register_specialize
@register_canonicalize
@gof.local_optimizer([Subtensor])
def local_subtensor_remove_broadcastable_index(node):
    """
    Remove broadcastable dimension with index 0 or -1
    a[:,:,:,0] -> a.dimshuffle(0,1,2), when
    a.broadcastable = (False, False, False, True)
    a[0,:,-1,:] -> a.dimshuffle(1,3), when
    a.broadcastable = (True, False, True, False)

    Only fires when every index is either a full slice ``:`` or a 0/-1
    applied to a broadcastable (length-1) dimension; any other pattern is
    left for the regular Subtensor machinery.
    """
    if not isinstance(node.op, Subtensor):
        return
    idx = node.op.idx_list

    remove_dim = []
    # Position in node.inputs of the next scalar index variable
    # (node.inputs[0] is the tensor being indexed).
    node_inputs_idx = 1
    for dim, elem in enumerate(idx):
        if isinstance(elem, scalar.Scalar):
            # The idx is a Scalar, ie a Type. This means the actual index
            # is contained in node.inputs[node_inputs_idx].
            dim_index = node.inputs[node_inputs_idx]
            if isinstance(dim_index, theano.scalar.basic.ScalarConstant):
                dim_index = dim_index.value
            if dim_index in [0, -1] and node.inputs[0].broadcastable[dim]:
                remove_dim.append(dim)
                node_inputs_idx += 1
            else:
                # Non-constant index, or an index we cannot turn into a
                # DimShuffle: give up on the whole node.
                return
        elif isinstance(elem, slice):
            # Only the trivial full slice is supported; real slicing must
            # stay a Subtensor.
            if elem != slice(None):
                return
        elif isinstance(elem, (integer_types, numpy.integer)):
            if elem in [0, -1] and node.inputs[0].broadcastable[dim]:
                remove_dim.append(dim)
            else:
                # An integer index that is not removed would make the
                # DimShuffle result's ndim disagree with the Subtensor
                # output; bail out instead of proposing an invalid
                # replacement that fgraph validation would reject.
                return
        else:
            raise TypeError('case not expected')

    if not remove_dim:
        return
    # Every non-removed dimension is kept in order; DimShuffle drops the
    # broadcastable dimensions that were selected with 0/-1.
    all_dim = range(node.inputs[0].ndim)
    remain_dim = [x for x in all_dim if x not in remove_dim]
    return [node.inputs[0].dimshuffle(tuple(remain_dim))]
@register_specialize @register_specialize
@register_canonicalize('fast_compile_gpu') @register_canonicalize('fast_compile_gpu')
@gof.local_optimizer([Subtensor, AdvancedSubtensor1]) @gof.local_optimizer([Subtensor, AdvancedSubtensor1])
......
...@@ -47,8 +47,12 @@ from theano.tensor import vector, ivector, lvector, fvector, dvector ...@@ -47,8 +47,12 @@ from theano.tensor import vector, ivector, lvector, fvector, dvector
from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix from theano.tensor import matrix, imatrix, lmatrix, fmatrix, dmatrix
from theano.tensor import scalars, vectors, matrices, fmatrices, dmatrices from theano.tensor import scalars, vectors, matrices, fmatrices, dmatrices
from theano.tensor import ( from theano.tensor import (
AdvancedSubtensor,
AdvancedSubtensor1, AdvancedSubtensor1,
as_tensor_variable, as_tensor_variable,
IncSubtensor,
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
inplace, inplace,
Join, Join,
join, join,
...@@ -1834,6 +1838,67 @@ def test_local_useless_subtensor(): ...@@ -1834,6 +1838,67 @@ def test_local_useless_subtensor():
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
def test_local_subtensor_remove_broadcastable_index():
    # Exercises the local_subtensor_remove_broadcastable_index optimization.
    #
    # Indexing a broadcastable dimension with 0 or -1 must be rewritten as a
    # DimShuffle; every other subtensor pattern must be left alone (and must
    # still compile and run without errors).
    mode = theano.compile.mode.get_default_mode()
    mode = mode.including("local_subtensor_remove_broadcastable_index")
    x = T.dmatrix('x')
    y1 = x.dimshuffle(0, 'x', 1)
    y2 = x.dimshuffle('x', 1, 0, 'x')
    y3 = x.dimshuffle('x', 1, 'x', 0, 'x')

    banned_ops = (Subtensor, AdvancedSubtensor, AdvancedSubtensor1,
                  IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)

    # Cases where the optimization should fire: only full slices and
    # 0/-1 indices on broadcastable dimensions.
    applied = [
        y1[:, 0, :],
        y1[:, -1, :],
        y2[0, :, :, -1],
        y2[0, :, :, 0],
        y2[-1, :, :, -1],
        y3[-1, :, 0, :, -1],
        y3[-1, :, -1, :, -1],
        y3[0, :, 0, :, 0],
    ]
    f = theano.function([x], applied, mode=mode)
    for apply_node in f.maker.fgraph.toposort():
        assert type(apply_node.op) not in banned_ops
    rng = numpy.random.RandomState(seed=utt.fetch_seed())
    xn = rng.rand(5, 5)
    f(xn)

    # Cases where the optimization should not fire; compiling and running
    # them verifies other subtensor usage still passes without errors.
    not_applied = [
        y1[3, 0, :],
        y1[2:4, -1, :],
        y2[0, :, 4:, -1],
        y2[:, :, 0, -1],
        y2[0, 2:4, :, 0],
        y2[0, -1, :, -1],
        y2[-1, 4:, :, -1],
        y2[-1, :, :3, -1],
        y2[-1, :, -1, -1],
        y3[-1, 2, 0, :, -1],
        y3[-1, 0, -1, :, -1],
        y3[-1, :, -1, -1, -1],
        y3[0, 0, 0, :, 0],
        y3[-1, 2:4, 0, 1:5, -1],
        y3[-1, 0, -1, 0, -1],
        y3[0, 2, 0, 4, 0],
        y3[:, 0, :, 1],
        y3[0, :, :, 2],
        y3[:, 2, 0],
        y3[:, 3],
    ]
    f2 = theano.function([x], not_applied, mode=mode)
    f2(xn)
class test_local_subtensor_make_vector(unittest.TestCase): class test_local_subtensor_make_vector(unittest.TestCase):
def test_scalar_idx(self): def test_scalar_idx(self):
x, y, z = tensor.lscalars('xyz') x, y, z = tensor.lscalars('xyz')
......
...@@ -5,6 +5,9 @@ from numpy.testing import assert_equal, assert_string_equal ...@@ -5,6 +5,9 @@ from numpy.testing import assert_equal, assert_string_equal
import theano import theano
import theano.tensor as tt import theano.tensor as tt
import theano.tests.unittest_tools as utt import theano.tests.unittest_tools as utt
from theano.tensor import (Subtensor, AdvancedSubtensor, AdvancedSubtensor1,
IncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1)
def test_numpy_method(): def test_numpy_method():
...@@ -30,3 +33,31 @@ def test_copy(): ...@@ -30,3 +33,31 @@ def test_copy():
f = theano.function([x], y) f = theano.function([x], y)
assert_equal(f(data), data) assert_equal(f(data), data)
assert_string_equal(y.name, 'y') assert_string_equal(y.name, 'y')
def test_None_dimShuffle_replace():
    # Checks that indexing with None (newaxis) is replaced by a DimShuffle.
    #
    # Whenever None is used in a subtensor expression to add a broadcastable
    # dimension, the compiled graph should contain no Subtensor op (nor any
    # of its variants) -- only a DimShuffle.
    banned_ops = (Subtensor, AdvancedSubtensor, AdvancedSubtensor1,
                  IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)

    x = tt.dmatrix('x')
    y = x[:, None, :]
    f = theano.function([x], y)
    for apply_node in f.maker.fgraph.toposort():
        assert type(apply_node.op) not in banned_ops

    x = tt.tensor3('x')
    outputs = [
        x[:, :, None, :],
        x[None, :, :, None, :],
        x[:, :, None, :, None, None],
    ]
    f = theano.function([x], outputs)
    for apply_node in f.maker.fgraph.toposort():
        assert type(apply_node.op) not in banned_ops
...@@ -524,8 +524,11 @@ class _tensor_py_operators(object): ...@@ -524,8 +524,11 @@ class _tensor_py_operators(object):
counter += 1 counter += 1
new_args.append(arg) new_args.append(arg)
view = self.dimshuffle(pattern) view = self.dimshuffle(pattern)
rval = view.__getitem__(tuple(new_args)) check_rval = [arg == slice(None, None, None) for arg in new_args]
return rval if all(check_rval) == True:
return view
else:
return view.__getitem__(tuple(new_args))
else: else:
return theano.tensor.subtensor.Subtensor(args)( return theano.tensor.subtensor.Subtensor(args)(
self, *theano.tensor.subtensor.Subtensor.collapse( self, *theano.tensor.subtensor.Subtensor.collapse(
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment