提交 9f5955d0 authored 作者: James Bergstra's avatar James Bergstra

raw_random - correct infer_ndim_bcast in one case

上级 1b67301f
...@@ -254,7 +254,7 @@ def _infer_ndim_bcast(ndim, shape, *args): ...@@ -254,7 +254,7 @@ def _infer_ndim_bcast(ndim, shape, *args):
""" """
# Find the minimum value of ndim required by the *args # Find the minimum value of ndim required by the *args
if len(args) > 0: if args:
args_ndim = max(arg.ndim for arg in args) args_ndim = max(arg.ndim for arg in args)
else: else:
args_ndim = 0 args_ndim = 0
...@@ -324,11 +324,13 @@ def _infer_ndim_bcast(ndim, shape, *args): ...@@ -324,11 +324,13 @@ def _infer_ndim_bcast(ndim, shape, *args):
elif shape is None: elif shape is None:
# The number of drawn samples will be determined automatically, # The number of drawn samples will be determined automatically,
# but we need to know ndim # but we need to know ndim
v_shape = tensor.constant([], dtype='int64') if not args:
if ndim is None: raise TypeError(('_infer_ndim_bcast cannot infer shape without'
ndim = args_ndim ' either shape or args'))
bcast = [False]*ndim #TODO: retrieve broadcasting patterns of arguments template = reduce(lambda a,b:a+b, args)
v_shape = template.shape
bcast = template.broadcastable
ndim = template.ndim
else: else:
v_shape = tensor.as_tensor_variable(shape) v_shape = tensor.as_tensor_variable(shape)
if ndim is None: if ndim is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论