提交 85e09e63 authored 作者: Dustin Webb's avatar Dustin Webb

Completed TODO in as_tensor_variable which said to strip off leading broadcastable dimensions.

上级 c85d1953
......@@ -164,10 +164,10 @@ def as_tensor_variable(x, name=None, ndim=None):
return x
else:
if (x.type.ndim > ndim):
# TODO: strip off leading broadcastable dimensions
raise ValueError(
'TensorType could not be cast to have %i dimensions' %
ndim, x.type)
# strip off leading broadcastable dimensions
first_non_broadcastable = [idx for idx in range(x.ndim)
if x.broadcastable[idx] == False][0]
return x.dimshuffle(range(x.ndim)[first_non_broadcastable:])
elif (x.type.ndim < ndim):
return shape_padleft(x, n_ones=(ndim - x.type.ndim))
else:
......
......@@ -1919,6 +1919,12 @@ Allocb4GradTester = makeBroadcastTester(
)
def test_as_tensor_variable():
x = tensor.TensorType(config.floatX, (True, False))()
x = as_tensor_variable(x, ndim=1)
assert(x.ndim == 1)
class TestAlloc(unittest.TestCase):
dtype = config.floatX
mode = mode_opt
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论