Commit ee3e49e4, authored by Frédéric Bastien

Merge pull request #2083 from daemonmaker/todo_astensorvariable

Completed TODO in as_tensor_variable which said to strip off leading broadcastable dimensions
......@@ -164,10 +164,15 @@ def as_tensor_variable(x, name=None, ndim=None):
return x
else:
if (x.type.ndim > ndim):
# TODO: strip off leading broadcastable dimensions
raise ValueError(
'TensorType could not be cast to have %i dimensions' %
ndim, x.type)
# strip off leading broadcastable dimensions
first_non_broadcastable = [idx for idx in range(x.ndim)
if x.broadcastable[idx] == False][0]
x = x.dimshuffle(range(x.ndim)[first_non_broadcastable:])
if x.ndim > ndim:
raise ValueError(
'TensorType could not be cast to have %i dimensions' % ndim, x.type
)
return x
elif (x.type.ndim < ndim):
return shape_padleft(x, n_ones=(ndim - x.type.ndim))
else:
......
......@@ -1919,6 +1919,19 @@ Allocb4GradTester = makeBroadcastTester(
)
def test_as_tensor_variable():
    """Exercise the ``ndim`` handling of ``as_tensor_variable``.

    A variable whose leading dimension is broadcastable must be
    reducible to ``ndim=1`` (the broadcastable dimension is stripped),
    while a genuine 2-d matrix must raise ``ValueError`` when asked
    for ``ndim=1``.
    """
    # (True, False) broadcast pattern: the leading broadcastable
    # dimension should be stripped, yielding a 1-d variable.
    row_like = tensor.TensorType(config.floatX, (True, False))('x')
    reduced = as_tensor_variable(row_like, ndim=1)
    assert reduced.ndim == 1

    # A real matrix has no broadcastable dimension to strip, so the
    # cast down to one dimension must fail with ValueError.
    mat = tensor.matrix('x', dtype=config.floatX)
    failed_as_expected = False
    try:
        as_tensor_variable(mat, ndim=1)
    except ValueError:
        failed_as_expected = True
    assert failed_as_expected
class TestAlloc(unittest.TestCase):
dtype = config.floatX
mode = mode_opt
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment