提交 e5ebf260 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add kwargs to OpFromGraphs constructed in OpFromGraph.make_node

上级 75b7233e
...@@ -811,6 +811,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -811,6 +811,7 @@ class OpFromGraph(Op, HasInnerGraph):
rop_overrides=self.rop_overrides, rop_overrides=self.rop_overrides,
connection_pattern=self._connection_pattern, connection_pattern=self._connection_pattern,
name=self.name, name=self.name,
**self.kwargs,
) )
new_inputs = ( new_inputs = (
list(non_shared_inputs) + unshared_inputs + new_op.shared_inputs list(non_shared_inputs) + unshared_inputs + new_op.shared_inputs
......
...@@ -465,7 +465,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -465,7 +465,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
x = at.scalar("x") x = at.scalar("x")
y = shared(1.0, name="y") y = shared(1.0, name="y")
test_ofg = OpFromGraph([x], [x + y]) test_ofg = OpFromGraph([x], [x + y], on_unused_input="ignore")
assert test_ofg.inputs == [x] assert test_ofg.inputs == [x]
assert test_ofg.shared_inputs == [y] assert test_ofg.shared_inputs == [y]
...@@ -477,6 +477,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -477,6 +477,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
out_new = test_ofg.make_node(*(out.owner.inputs[:1] + [y_clone])).outputs[0] out_new = test_ofg.make_node(*(out.owner.inputs[:1] + [y_clone])).outputs[0]
assert "on_unused_input" in out_new.owner.op.kwargs
assert out_new.owner.op.inputs == [x] assert out_new.owner.op.inputs == [x]
assert out_new.owner.op.shared_inputs == [y_clone] assert out_new.owner.op.shared_inputs == [y_clone]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论