Unverified 提交 86bc1d29 authored 作者: Harshvir Sandhu's avatar Harshvir Sandhu 提交者: GitHub

Add name kwarg to Op.__call__ (#693)

上级 14651fb5
...@@ -246,7 +246,9 @@ class Op(MetaObject): ...@@ -246,7 +246,9 @@ class Op(MetaObject):
) )
return Apply(self, inputs, [o() for o in self.otypes]) return Apply(self, inputs, [o() for o in self.otypes])
def __call__(self, *inputs: Any, **kwargs) -> Variable | list[Variable]: def __call__(
self, *inputs: Any, name=None, return_list=False, **kwargs
) -> Variable | list[Variable]:
r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs. r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
This method is just a wrapper around :meth:`Op.make_node`. This method is just a wrapper around :meth:`Op.make_node`.
...@@ -288,8 +290,15 @@ class Op(MetaObject): ...@@ -288,8 +290,15 @@ class Op(MetaObject):
the :attr:`Op.default_output` property. the :attr:`Op.default_output` property.
""" """
return_list = kwargs.pop("return_list", False)
node = self.make_node(*inputs, **kwargs) node = self.make_node(*inputs, **kwargs)
if name is not None:
if len(node.outputs) == 1:
node.outputs[0].name = name
elif self.default_output is not None:
node.outputs[self.default_output].name = name
else:
for i, n in enumerate(node.outputs):
n.name = f"{name}_{i}"
if config.compute_test_value != "off": if config.compute_test_value != "off":
compute_test_value(node) compute_test_value(node)
......
...@@ -232,3 +232,46 @@ def test_op_input_broadcastable(): ...@@ -232,3 +232,46 @@ def test_op_input_broadcastable():
x = pt.TensorType(dtype="float64", shape=(1,))("x") x = pt.TensorType(dtype="float64", shape=(1,))("x")
assert SomeOp()(x).type == pt.dvector assert SomeOp()(x).type == pt.dvector
@pytest.mark.parametrize("multi_output", [True, False])
def test_call_name(multi_output):
def dummy_variable(name):
return Variable(MyType(thingy=None), None, None, name=name)
x = dummy_variable("x")
class TestCallOp(Op):
def __init__(self, default_output, multi_output):
super().__init__()
self.default_output = default_output
self.multi_output = multi_output
def make_node(self, input):
inputs = [input]
if self.multi_output:
outputs = [input.type(), input.type()]
else:
outputs = [input.type()]
return Apply(self, inputs, outputs)
def perform(self, node, inputs, outputs):
raise NotImplementedError()
if multi_output:
multi_op = TestCallOp(default_output=None, multi_output=multi_output)
res = multi_op(x, name="test_name")
for i, r in enumerate(res):
assert r.name == f"test_name_{i}"
multi_op = TestCallOp(default_output=1, multi_output=multi_output)
result = multi_op(x, name="test_name")
assert result.owner.outputs[0].name is None
assert result.name == "test_name"
else:
single_op = TestCallOp(default_output=None, multi_output=multi_output)
res_single = single_op(x, name="test_name")
assert res_single.name == "test_name"
res_nameless = single_op(x)
assert res_nameless.name is None
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论