提交 22c67a69 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use direct theano.gof.graph imports in tests.tensor.test_extra_ops

上级 9abca4b8
...@@ -6,6 +6,7 @@ from tests import unittest_tools as utt ...@@ -6,6 +6,7 @@ from tests import unittest_tools as utt
from theano import function from theano import function
from theano import tensor as tt from theano import tensor as tt
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof.graph import ops as graph_ops
from theano.tensor.extra_ops import ( from theano.tensor.extra_ops import (
Bartlett, Bartlett,
BroadcastTo, BroadcastTo,
...@@ -1220,8 +1221,7 @@ def test_broadcast_shape(): ...@@ -1220,8 +1221,7 @@ def test_broadcast_shape():
arrays_are_shapes=True, arrays_are_shapes=True,
) )
assert any( assert any(
isinstance(node.op, tt.opt.Assert) isinstance(node.op, tt.opt.Assert) for node in graph_ops([x_tt, y_tt], b_tt)
for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
) )
assert np.array_equal([z.eval() for z in b_tt], b.shape) assert np.array_equal([z.eval() for z in b_tt], b.shape)
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True) b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True)
...@@ -1229,8 +1229,7 @@ def test_broadcast_shape(): ...@@ -1229,8 +1229,7 @@ def test_broadcast_shape():
# These are all constants, so there shouldn't be any asserts in the # These are all constants, so there shouldn't be any asserts in the
# resulting graph. # resulting graph.
assert not any( assert not any(
isinstance(node.op, tt.opt.Assert) isinstance(node.op, tt.opt.Assert) for node in graph_ops([x_tt, y_tt], b_tt)
for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
) )
x = np.array([1, 2, 3]) x = np.array([1, 2, 3])
...@@ -1246,7 +1245,7 @@ def test_broadcast_shape(): ...@@ -1246,7 +1245,7 @@ def test_broadcast_shape():
# implementation. # implementation.
# assert not any( # assert not any(
# isinstance(node.op, tt.opt.Assert) # isinstance(node.op, tt.opt.Assert)
# for node in tt.gof.graph.ops([x_tt, y_tt], b_tt) # for node in graph_ops([x_tt, y_tt], b_tt)
# ) # )
x = np.empty((1, 2, 3)) x = np.empty((1, 2, 3))
...@@ -1258,8 +1257,7 @@ def test_broadcast_shape(): ...@@ -1258,8 +1257,7 @@ def test_broadcast_shape():
assert b_tt[0].value == 1 assert b_tt[0].value == 1
assert np.array_equal([z.eval() for z in b_tt], b.shape) assert np.array_equal([z.eval() for z in b_tt], b.shape)
assert not any( assert not any(
isinstance(node.op, tt.opt.Assert) isinstance(node.op, tt.opt.Assert) for node in graph_ops([x_tt, y_tt], b_tt)
for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
) )
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True) b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_tt], b.shape) assert np.array_equal([z.eval() for z in b_tt], b.shape)
...@@ -1276,7 +1274,7 @@ def test_broadcast_shape(): ...@@ -1276,7 +1274,7 @@ def test_broadcast_shape():
# implementation. # implementation.
# assert not any( # assert not any(
# isinstance(node.op, tt.opt.Assert) # isinstance(node.op, tt.opt.Assert)
# for node in tt.gof.graph.ops([x_tt, y_tt], b_tt) # for node in graph_ops([x_tt, y_tt], b_tt)
# ) # )
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True) b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_tt], b.shape) assert np.array_equal([z.eval() for z in b_tt], b.shape)
...@@ -1293,7 +1291,7 @@ def test_broadcast_shape(): ...@@ -1293,7 +1291,7 @@ def test_broadcast_shape():
# implementation. # implementation.
# assert not any( # assert not any(
# isinstance(node.op, tt.opt.Assert) # isinstance(node.op, tt.opt.Assert)
# for node in tt.gof.graph.ops([x_tt, y_tt], b_tt) # for node in graph_ops([x_tt, y_tt], b_tt)
# ) # )
res = tt.as_tensor(b_tt).eval( res = tt.as_tensor(b_tt).eval(
{ {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论