提交 7a81d896 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Do not include Constants in automatically determined FunctionGraph inputs

上级 32562a29
...@@ -114,7 +114,7 @@ class FunctionGraph(MetaObject): ...@@ -114,7 +114,7 @@ class FunctionGraph(MetaObject):
raise ValueError("No outputs specified") raise ValueError("No outputs specified")
if inputs is None: if inputs is None:
inputs = [i for i in graph_inputs(outputs)] inputs = [i for i in graph_inputs(outputs) if not isinstance(i, Constant)]
if clone: if clone:
memo = clone_get_equiv( memo = clone_get_equiv(
......
...@@ -5,7 +5,7 @@ import pytest ...@@ -5,7 +5,7 @@ import pytest
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph, MissingInputError from aesara.graph.fg import FunctionGraph, MissingInputError
from tests.graph.utils import MyVariable, MyVariable2, op1, op2, op3 from tests.graph.utils import MyConstant, MyVariable, MyVariable2, op1, op2, op3
class TestFunctionGraph: class TestFunctionGraph:
...@@ -62,7 +62,9 @@ class TestFunctionGraph: ...@@ -62,7 +62,9 @@ class TestFunctionGraph:
assert fg.get_clients(var3) == [(var4.owner, 0), ("output", 0)] assert fg.get_clients(var3) == [(var4.owner, 0), ("output", 0)]
assert fg.get_clients(var4) == [("output", 1)] assert fg.get_clients(var4) == [("output", 1)]
fg = FunctionGraph(outputs=[var3, var4], clone=False) varC = MyConstant("varC")
var5 = op1(var1, varC)
fg = FunctionGraph(outputs=[var3, var4, var5], clone=False)
assert fg.inputs == [var1, var2] assert fg.inputs == [var1, var2]
memo = {} memo = {}
......
import numpy as np import numpy as np
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.type import Type from aesara.graph.type import Type
...@@ -37,6 +37,10 @@ def MyVariable(name): ...@@ -37,6 +37,10 @@ def MyVariable(name):
return Variable(MyType(), None, None, name=name) return Variable(MyType(), None, None, name=name)
def MyConstant(name, data=None):
return Constant(MyType(), data, name=name)
def MyVariable2(name): def MyVariable2(name):
return Variable(MyType2(), None, None, name=name) return Variable(MyType2(), None, None, name=name)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论