提交 7cb50f8b authored 作者: Dustin Webb's avatar Dustin Webb

Completed TODO in tensor.py:as_tensor_variable which said to use the Apply…

Completed TODO in tensor.py:as_tensor_variable which said to use the Apply default output mechanism. Added tests to excited the various possible failure scenarios.
上级 c85d1953
...@@ -145,13 +145,13 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -145,13 +145,13 @@ def as_tensor_variable(x, name=None, ndim=None):
return x._as_TensorVariable() # TODO: pass name and ndim arguments return x._as_TensorVariable() # TODO: pass name and ndim arguments
if isinstance(x, gof.Apply): if isinstance(x, gof.Apply):
# TODO: use Apply's default output mechanism # use Apply's default output mechanism
if len(x.outputs) != 1: if (x.op.default_output is None) and (len(x.outputs) != 1):
raise ValueError( raise ValueError(
"It is ambiguous which output of a multi-output Op has" "It is ambiguous which output of a multi-output Op has"
" to be fetched.", x) " to be fetched.", x)
else:
x = x.outputs[0] x = x.default_output()
if isinstance(x, Variable): if isinstance(x, Variable):
if isinstance(x.type, scal.Scalar): if isinstance(x.type, scal.Scalar):
x = tensor_from_scalar(x) x = tensor_from_scalar(x)
......
...@@ -1919,6 +1919,48 @@ Allocb4GradTester = makeBroadcastTester( ...@@ -1919,6 +1919,48 @@ Allocb4GradTester = makeBroadcastTester(
) )
class ApplyDefaultTestOp(theano.Op):
def __init__(self, id):
self.default_output = id
def make_node(self, x):
x = theano.tensor.as_tensor_variable(x)
return theano.Apply(self, [x], [x.type()])
class TestApplyDefaultOutput(unittest.TestCase):
def setUp(self):
self.x = tensor.scalar('x')
def test_one_output(self):
good_apply_var = ApplyDefaultTestOp(0).make_node(self.x)
x = as_tensor_variable(good_apply_var)
def test_below_zero_output(self):
bad_apply_var = ApplyDefaultTestOp(-1).make_node(self.x)
try:
x = as_tensor_variable(bad_apply_var)
assert(False) # The above call should have failed
except AttributeError:
pass
def test_above_output_len(self):
bad_apply_var = ApplyDefaultTestOp(2).make_node(self.x)
try:
x = as_tensor_variable(bad_apply_var)
assert(False) # The above call should have failed
except AttributeError:
pass
def test_list(self):
bad_apply_var = ApplyDefaultTestOp([0, 1]).make_node(self.x)
try:
x = as_tensor_variable(bad_apply_var)
assert(False) # The above call should have failed
except AttributeError:
pass
class TestAlloc(unittest.TestCase): class TestAlloc(unittest.TestCase):
dtype = config.floatX dtype = config.floatX
mode = mode_opt mode = mode_opt
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论