提交 22c67a69 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use direct theano.gof.graph imports in tests.tensor.test_extra_ops

上级 9abca4b8
......@@ -6,6 +6,7 @@ from tests import unittest_tools as utt
from theano import function
from theano import tensor as tt
from theano.configdefaults import config
from theano.gof.graph import ops as graph_ops
from theano.tensor.extra_ops import (
Bartlett,
BroadcastTo,
......@@ -1220,8 +1221,7 @@ def test_broadcast_shape():
arrays_are_shapes=True,
)
assert any(
isinstance(node.op, tt.opt.Assert)
for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
isinstance(node.op, tt.opt.Assert) for node in graph_ops([x_tt, y_tt], b_tt)
)
assert np.array_equal([z.eval() for z in b_tt], b.shape)
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True)
......@@ -1229,8 +1229,7 @@ def test_broadcast_shape():
# These are all constants, so there shouldn't be any asserts in the
# resulting graph.
assert not any(
isinstance(node.op, tt.opt.Assert)
for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
isinstance(node.op, tt.opt.Assert) for node in graph_ops([x_tt, y_tt], b_tt)
)
x = np.array([1, 2, 3])
......@@ -1246,7 +1245,7 @@ def test_broadcast_shape():
# implementation.
# assert not any(
# isinstance(node.op, tt.opt.Assert)
# for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
# for node in graph_ops([x_tt, y_tt], b_tt)
# )
x = np.empty((1, 2, 3))
......@@ -1258,8 +1257,7 @@ def test_broadcast_shape():
assert b_tt[0].value == 1
assert np.array_equal([z.eval() for z in b_tt], b.shape)
assert not any(
isinstance(node.op, tt.opt.Assert)
for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
isinstance(node.op, tt.opt.Assert) for node in graph_ops([x_tt, y_tt], b_tt)
)
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_tt], b.shape)
......@@ -1276,7 +1274,7 @@ def test_broadcast_shape():
# implementation.
# assert not any(
# isinstance(node.op, tt.opt.Assert)
# for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
# for node in graph_ops([x_tt, y_tt], b_tt)
# )
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_tt], b.shape)
......@@ -1293,7 +1291,7 @@ def test_broadcast_shape():
# implementation.
# assert not any(
# isinstance(node.op, tt.opt.Assert)
# for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
# for node in graph_ops([x_tt, y_tt], b_tt)
# )
res = tt.as_tensor(b_tt).eval(
{
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论