提交 5c907045 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix optimizations to exit gracefully when env.feature_shape is not present.

Also test it. This should fix the problem experienced by Justin Bayer.
上级 40274190
......@@ -841,6 +841,9 @@ def local_useless_alloc(node):
@gof.local_optimizer([T._shape])
def local_shape_to_shape_i(node):
if node.op == T._shape:
# This optimization needs ShapeOpt and env.shape_feature
if not hasattr(node.env, 'shape_feature'):
return
shape_feature = node.env.shape_feature
return [shape_feature.make_vector_shape(node.inputs[0])]
......@@ -866,6 +869,9 @@ def local_subtensor_make_vector(node):
# [a,b,c][0:2] -> [a,b]
# we can do this for constant indexes
if isinstance(node.op, T.Subtensor):
# This optimization needs ShapeOpt and env.shape_feature
if not hasattr(node.env, 'shape_feature'):
return
shape_feature = node.env.shape_feature
x = node.inputs[0]
if x.owner and x.owner.op == make_vector:
......@@ -1173,6 +1179,9 @@ def local_useless_subtensor(node):
Remove Subtensor if it take the full input
"""
if isinstance(node.op, T.Subtensor):
# This optimization needs ShapeOpt and env.shape_feature
if not hasattr(node.env, 'shape_feature'):
return
shape_of = node.env.shape_feature.shape_of
node_input_idx = 1
for pos, idx in enumerate(node.op.idx_list):
......@@ -1778,6 +1787,9 @@ if 0:
@gof.local_optimizer([])
def local_sum_over_empty(node):
if isinstance(node.op, T.Sum):
# This optimization needs ShapeOpt and env.shape_feature
if not hasattr(node.env, 'shape_feature'):
return
y, = node.outputs
y_shape = node.env.shape_feature.shape_of[y]
......
......@@ -1697,6 +1697,16 @@ class test_shapeoptimizer(unittest.TestCase):
assert identity_noshape not in h_ops
assert identity_shape not in h_ops
def test_no_shapeopt(self):
# Test that a basic example works even when ShapeOpt is excluded
X = T.matrix()
expr = X.shape[0]
mode = theano.compile.get_default_mode().excluding('ShapeOpt')
f = theano.function([X], expr, mode=mode)
print f([[1, 2], [2, 3]])
class test_assert(unittest.TestCase):
def test0(self):
x=T.scalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论