提交 ee3e49e4 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2083 from daemonmaker/todo_astensorvariable

Completed TODO in as_tensor_variable which said to strip off leading bro...
...@@ -164,10 +164,15 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -164,10 +164,15 @@ def as_tensor_variable(x, name=None, ndim=None):
return x return x
else: else:
if (x.type.ndim > ndim): if (x.type.ndim > ndim):
# TODO: strip off leading broadcastable dimensions # strip off leading broadcastable dimensions
raise ValueError( first_non_broadcastable = [idx for idx in range(x.ndim)
'TensorType could not be cast to have %i dimensions' % if x.broadcastable[idx] == False][0]
ndim, x.type) x = x.dimshuffle(range(x.ndim)[first_non_broadcastable:])
if x.ndim > ndim:
raise ValueError(
'TensorType could not be cast to have %i dimensions' % ndim, x.type
)
return x
elif (x.type.ndim < ndim): elif (x.type.ndim < ndim):
return shape_padleft(x, n_ones=(ndim - x.type.ndim)) return shape_padleft(x, n_ones=(ndim - x.type.ndim))
else: else:
......
...@@ -1919,6 +1919,19 @@ Allocb4GradTester = makeBroadcastTester( ...@@ -1919,6 +1919,19 @@ Allocb4GradTester = makeBroadcastTester(
) )
def test_as_tensor_variable():
x = tensor.TensorType(config.floatX, (True, False))('x')
x = as_tensor_variable(x, ndim=1)
assert(x.ndim == 1)
x = tensor.matrix('x', dtype=config.floatX)
try:
x = as_tensor_variable(x, ndim=1)
assert(False) # The call above should have failed
except ValueError:
pass
class TestAlloc(unittest.TestCase): class TestAlloc(unittest.TestCase):
dtype = config.floatX dtype = config.floatX
mode = mode_opt mode = mode_opt
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论