提交 85db9b9d authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1608 from lamblin/fix_rng_empty_shape

Fix rng empty shape
...@@ -10,6 +10,7 @@ import warnings ...@@ -10,6 +10,7 @@ import warnings
import numpy import numpy
from theano import Op, Apply, shared, config, Variable from theano import Op, Apply, shared, config, Variable
from theano import tensor
from theano.tensor import (raw_random, TensorType, as_tensor_variable, from theano.tensor import (raw_random, TensorType, as_tensor_variable,
get_vector_length, cast, opt, scal) get_vector_length, cast, opt, scal)
from theano.tensor import sqrt, log, sin, cos, join, prod from theano.tensor import sqrt, log, sin, cos, join, prod
...@@ -890,7 +891,8 @@ class MRG_RandomStreams(object): ...@@ -890,7 +891,8 @@ class MRG_RandomStreams(object):
constant = False constant = False
if isinstance(size, tuple) and all([isinstance(i, (numpy.integer, int)) for i in size]): if isinstance(size, tuple) and all([isinstance(i, (numpy.integer, int)) for i in size]):
constant = True constant = True
n_samples = numpy.prod(size) # Force dtype because it defaults to float when size is empty
n_samples = numpy.prod(size, dtype='int64')
if n_samples % 2 == 1: if n_samples % 2 == 1:
n_samples += 1 n_samples += 1
...@@ -928,8 +930,10 @@ class MRG_RandomStreams(object): ...@@ -928,8 +930,10 @@ class MRG_RandomStreams(object):
else: else:
final_samples = normal_samples[:prod(size)] final_samples = normal_samples[:prod(size)]
if size: if not size:
final_samples = final_samples.reshape(size) # Force the dtype to be int64, otherwise reshape complains
size = tensor.constant(size, dtype='int64')
final_samples = final_samples.reshape(size)
final_samples = avg + std * final_samples final_samples = avg + std * final_samples
......
...@@ -351,7 +351,10 @@ def _infer_ndim_bcast(ndim, shape, *args): ...@@ -351,7 +351,10 @@ def _infer_ndim_bcast(ndim, shape, *args):
ValueError('negative shape', s) ValueError('negative shape', s)
# post-condition: shape may still contain both symbolic and # post-condition: shape may still contain both symbolic and
# non-symbolic things # non-symbolic things
v_shape = tensor.stack(*pre_v_shape) if len(pre_v_shape) == 0:
v_shape = tensor.constant([], dtype='int32')
else:
v_shape = tensor.stack(*pre_v_shape)
elif shape is None: elif shape is None:
# The number of drawn samples will be determined automatically, # The number of drawn samples will be determined automatically,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论