提交 acfa7490 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Michael Osthege

Allow default_output to be any valid Python index

上级 ae7f39e2
...@@ -193,11 +193,9 @@ class Apply(Node, Generic[OpType]): ...@@ -193,11 +193,9 @@ class Apply(Node, Generic[OpType]):
if len(self.outputs) == 1: if len(self.outputs) == 1:
return self.outputs[0] return self.outputs[0]
else: else:
raise ValueError(f"{self.op}.default_output should be an output index.") raise ValueError(
elif not isinstance(do, int): f"Multi-output Op {self.op} default_output not specified"
raise ValueError(f"{self.op}.default_output should be an int or long") )
elif do < 0 or do >= len(self.outputs):
raise ValueError(f"{self.op}.default_output is out of range.")
return self.outputs[do] return self.outputs[do]
def __str__(self): def __str__(self):
......
...@@ -80,7 +80,7 @@ def _as_tensor_Apply(x, name, ndim, **kwargs): ...@@ -80,7 +80,7 @@ def _as_tensor_Apply(x, name, ndim, **kwargs):
# use Apply's default output mechanism # use Apply's default output mechanism
if (x.op.default_output is None) and (len(x.outputs) != 1): if (x.op.default_output is None) and (len(x.outputs) != 1):
raise TypeError( raise TypeError(
"Multi-output Op encountered. " "Multi-output Op without default_output encountered. "
"Retry using only one of the outputs directly." "Retry using only one of the outputs directly."
) )
......
...@@ -500,12 +500,13 @@ class TestMakeVector(utt.InferShapeTester): ...@@ -500,12 +500,13 @@ class TestMakeVector(utt.InferShapeTester):
class ApplyDefaultTestOp(Op): class ApplyDefaultTestOp(Op):
def __init__(self, id): def __init__(self, id, n_outs=1):
self.default_output = id self.default_output = id
self.n_outs = n_outs
def make_node(self, x): def make_node(self, x):
x = at.as_tensor_variable(x) x = at.as_tensor_variable(x)
return Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type() for _ in range(self.n_outs)])
def perform(self, *args, **kwargs): def perform(self, *args, **kwargs):
raise NotImplementedError() raise NotImplementedError()
...@@ -556,16 +557,26 @@ class TestAsTensorVariable: ...@@ -556,16 +557,26 @@ class TestAsTensorVariable:
y = as_tensor_variable(aes.int8()) y = as_tensor_variable(aes.int8())
assert isinstance(y.owner.op, TensorFromScalar) assert isinstance(y.owner.op, TensorFromScalar)
def test_multi_outputs(self): def test_default_output(self):
good_apply_var = ApplyDefaultTestOp(0).make_node(self.x) good_apply_var = ApplyDefaultTestOp(0, n_outs=1).make_node(self.x)
as_tensor_variable(good_apply_var) as_tensor_variable(good_apply_var) is good_apply_var
bad_apply_var = ApplyDefaultTestOp(-1).make_node(self.x) good_apply_var = ApplyDefaultTestOp(-1, n_outs=1).make_node(self.x)
with pytest.raises(ValueError): as_tensor_variable(good_apply_var) is good_apply_var
bad_apply_var = ApplyDefaultTestOp(1, n_outs=1).make_node(self.x)
with pytest.raises(IndexError):
_ = as_tensor_variable(bad_apply_var) _ = as_tensor_variable(bad_apply_var)
bad_apply_var = ApplyDefaultTestOp(2).make_node(self.x) bad_apply_var = ApplyDefaultTestOp(2.0, n_outs=1).make_node(self.x)
with pytest.raises(ValueError): with pytest.raises(TypeError):
_ = as_tensor_variable(bad_apply_var)
good_apply_var = ApplyDefaultTestOp(1, n_outs=2).make_node(self.x)
as_tensor_variable(good_apply_var) is good_apply_var.outputs[1]
bad_apply_var = ApplyDefaultTestOp(None, n_outs=2).make_node(self.x)
with pytest.raises(TypeError, match="Multi-output Op without default_output"):
_ = as_tensor_variable(bad_apply_var) _ = as_tensor_variable(bad_apply_var)
def test_list(self): def test_list(self):
...@@ -578,7 +589,7 @@ class TestAsTensorVariable: ...@@ -578,7 +589,7 @@ class TestAsTensorVariable:
_ = as_tensor_variable(y) _ = as_tensor_variable(y)
bad_apply_var = ApplyDefaultTestOp([0, 1]).make_node(self.x) bad_apply_var = ApplyDefaultTestOp([0, 1]).make_node(self.x)
with pytest.raises(ValueError): with pytest.raises(TypeError):
as_tensor_variable(bad_apply_var) as_tensor_variable(bad_apply_var)
def test_ndim_strip_leading_broadcastable(self): def test_ndim_strip_leading_broadcastable(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论