提交 d7a02fa1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Check for Constants in OpFromGraph constructor

上级 10bb77bf
...@@ -11,6 +11,7 @@ from aesara.configdefaults import config ...@@ -11,6 +11,7 @@ from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, Rop, grad from aesara.gradient import DisconnectedType, Rop, grad
from aesara.graph.basic import ( from aesara.graph.basic import (
Apply, Apply,
Constant,
Variable, Variable,
clone_replace, clone_replace,
graph_inputs, graph_inputs,
...@@ -331,13 +332,21 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -331,13 +332,21 @@ class OpFromGraph(Op, HasInnerGraph):
name=None, name=None,
**kwargs, **kwargs,
): ):
if not isinstance(outputs, list):
raise TypeError(f"outputs must be list, got {type(outputs)}") if not (isinstance(inputs, list) and isinstance(outputs, list)):
raise TypeError("Inputs and outputs must be lists")
for i in inputs + outputs: for i in inputs + outputs:
if not isinstance(i, Variable): if not isinstance(i, Variable):
raise TypeError("inputs and outputs must be Variable instances", i) raise TypeError(
f"Inputs and outputs must be Variable instances; got {i}"
)
if i in inputs and isinstance(i, Constant):
raise TypeError(f"Constants not allowed as inputs; {i}")
if "updates" in kwargs or "givens" in kwargs: if "updates" in kwargs or "givens" in kwargs:
raise TypeError("updates and givens are not allowed here") raise NotImplementedError("Updates and givens are not allowed here")
self.is_inline = inline self.is_inline = inline
# To correctly support shared variables the inner fct should # To correctly support shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient. # not see them. Otherwise there is a problem with the gradient.
......
...@@ -19,12 +19,28 @@ from aesara.tensor.math import round as aet_round ...@@ -19,12 +19,28 @@ from aesara.tensor.math import round as aet_round
from aesara.tensor.math import sigmoid from aesara.tensor.math import sigmoid
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.random.utils import RandomStream from aesara.tensor.random.utils import RandomStream
from aesara.tensor.shape import specify_shape
from aesara.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors from aesara.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors
from tests import unittest_tools from tests import unittest_tools
from tests.graph.utils import MyVariable from tests.graph.utils import MyVariable
class TestOpFromGraph(unittest_tools.InferShapeTester): class TestOpFromGraph(unittest_tools.InferShapeTester):
def test_valid_input(self):
x, y, z = matrices("xyz")
with pytest.raises(TypeError):
OpFromGraph((x,), (x,))
with pytest.raises(TypeError):
OpFromGraph([1], [1])
with pytest.raises(TypeError):
OpFromGraph([x, as_tensor(1)], [x])
with pytest.raises(NotImplementedError):
OpFromGraph([x], [x], updates={})
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
) )
...@@ -39,9 +55,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -39,9 +55,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
xv = np.ones((2, 2), dtype=config.floatX) xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3 yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
# print function, function.__module__
# print fn.maker.fgraph.toposort()
fn(xv, yv, zv)
assert np.all(8.0 == fn(xv, yv, zv)) assert np.all(8.0 == fn(xv, yv, zv))
assert np.all(8.0 == fn(xv, yv, zv)) assert np.all(8.0 == fn(xv, yv, zv))
...@@ -412,7 +425,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -412,7 +425,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
# shape # shape
x = MyVariable("x") x = MyVariable("x")
y = matrix("y") y = matrix("y")
z = as_tensor([1, 2]) z = specify_shape(vector("z"), (2,))
op_graph = OpFromGraph([x, y, z], [x, y]) op_graph = OpFromGraph([x, y, z], [x, y])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论