提交 85e09e63 authored 作者: Dustin Webb's avatar Dustin Webb

Completed TODO in as_tensor_variable which said to strip off leading broadcastable dimensions.

上级 c85d1953
...@@ -164,10 +164,10 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -164,10 +164,10 @@ def as_tensor_variable(x, name=None, ndim=None):
return x return x
else: else:
if (x.type.ndim > ndim): if (x.type.ndim > ndim):
# TODO: strip off leading broadcastable dimensions # strip off leading broadcastable dimensions
raise ValueError( first_non_broadcastable = [idx for idx in range(x.ndim)
'TensorType could not be cast to have %i dimensions' % if x.broadcastable[idx] == False][0]
ndim, x.type) return x.dimshuffle(range(x.ndim)[first_non_broadcastable:])
elif (x.type.ndim < ndim): elif (x.type.ndim < ndim):
return shape_padleft(x, n_ones=(ndim - x.type.ndim)) return shape_padleft(x, n_ones=(ndim - x.type.ndim))
else: else:
......
...@@ -1919,6 +1919,12 @@ Allocb4GradTester = makeBroadcastTester( ...@@ -1919,6 +1919,12 @@ Allocb4GradTester = makeBroadcastTester(
) )
def test_as_tensor_variable():
x = tensor.TensorType(config.floatX, (True, False))()
x = as_tensor_variable(x, ndim=1)
assert(x.ndim == 1)
class TestAlloc(unittest.TestCase): class TestAlloc(unittest.TestCase):
dtype = config.floatX dtype = config.floatX
mode = mode_opt mode = mode_opt
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论