提交 9f5955d0 authored 作者: James Bergstra's avatar James Bergstra

raw_random - correct infer_ndim_bcast in one case

上级 1b67301f
......@@ -254,7 +254,7 @@ def _infer_ndim_bcast(ndim, shape, *args):
"""
# Find the minimum value of ndim required by the *args
if len(args) > 0:
if args:
args_ndim = max(arg.ndim for arg in args)
else:
args_ndim = 0
......@@ -324,11 +324,13 @@ def _infer_ndim_bcast(ndim, shape, *args):
elif shape is None:
# The number of drawn samples will be determined automatically,
# but we need to know ndim
v_shape = tensor.constant([], dtype='int64')
if ndim is None:
ndim = args_ndim
bcast = [False]*ndim #TODO: retrieve broadcasting patterns of arguments
if not args:
raise TypeError(('_infer_ndim_bcast cannot infer shape without'
' either shape or args'))
template = reduce(lambda a,b:a+b, args)
v_shape = template.shape
bcast = template.broadcastable
ndim = template.ndim
else:
v_shape = tensor.as_tensor_variable(shape)
if ndim is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论