提交 724412fa authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix test now that we replace a reshape with a dimshuffle

上级 70e23f42
...@@ -5114,12 +5114,15 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin): ...@@ -5114,12 +5114,15 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin):
self.ignore_topo = ignore_topo self.ignore_topo = ignore_topo
super(T_reshape, self).__init__(name) super(T_reshape, self).__init__(name)
def function(self, inputs, outputs): def function(self, inputs, outputs, ignore_empty=False):
f = function(inputs, outputs, mode=self.mode) f = function(inputs, outputs, mode=self.mode)
if self.mode is not None or theano.config.mode != "FAST_COMPILE": if self.mode is not None or theano.config.mode != "FAST_COMPILE":
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
topo_ = [node for node in topo if not isinstance(node.op, topo_ = [node for node in topo if not isinstance(node.op,
self.ignore_topo)] self.ignore_topo)]
if ignore_empty:
assert len(topo_) <= 1, topo_
else:
assert len(topo_) == 1, topo_ assert len(topo_) == 1, topo_
assert type(topo_[0].op) is self.op assert type(topo_[0].op) is self.op
return f return f
...@@ -5194,12 +5197,23 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin): ...@@ -5194,12 +5197,23 @@ class T_reshape(utt.InferShapeTester, utt.TestOptimizationMixin):
# test broadcast flag for constant value of 1 # test broadcast flag for constant value of 1
c = reshape(b, (b.shape[0], b.shape[1], 1)) c = reshape(b, (b.shape[0], b.shape[1], 1))
f = self.function([b], c) # That reshape may get replaced with a dimshuffle, with is ignored,
# so we pass "ignore_empty=True"
f = self.function([b], c, ignore_empty=True)
assert numpy.all(f(numpy.asarray([[0, 1, 2], [3, 4, 5]])) == assert numpy.all(f(numpy.asarray([[0, 1, 2], [3, 4, 5]])) ==
numpy.asarray([[[0], [1], [2]], [[3], [4], [5]]])) numpy.asarray([[[0], [1], [2]], [[3], [4], [5]]]))
assert (f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable == assert (f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable ==
(False, False, True)) (False, False, True))
# test broadcast flag for constant value of 1 if it cannot be
# replaced with dimshuffle
c = reshape(b, (b.shape[1], b.shape[0], 1))
f = self.function([b], c, ignore_empty=True)
assert numpy.all(f(numpy.asarray([[0, 1, 2], [3, 4, 5]])) ==
numpy.asarray([[[0], [1]], [[2], [3]], [[4], [5]]]))
assert (f.maker.fgraph.toposort()[-1].outputs[0].type.broadcastable ==
(False, False, True))
def test_m1(self): def test_m1(self):
t = tensor3() t = tensor3()
rng = numpy.random.RandomState(seed=utt.fetch_seed()) rng = numpy.random.RandomState(seed=utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论