提交 19580d30 authored 作者: James Bergstra's avatar James Bergstra

default broadcasting pattern for shared tensors is all False

上级 02a737df
......@@ -151,12 +151,22 @@ class TensorSharedVariable(SharedVariable, theano.tensor.basic._tensor_py_operat
pass
@shared_constructor
def tensor_constructor(value, name=None, strict=False, broadcastable=None):
"""SharedVariable Constructor for TensorType"""
"""SharedVariable Constructor for TensorType
:note: Regarding the inference of the broadcastable pattern...
The default is to assume that the value might be resized in any dimension, so the default
broadcastable is ``(False,)*len(value.shape)``. The optional `broadcastable` argument will
override this default.
"""
if not isinstance(value, numpy.ndarray):
raise TypeError()
# if no broadcastable is given, then the default is to assume that the value might be
# resized in any dimension in the future.
#
if broadcastable is None:
broadcastable = [b==1 for b in value.shape]
broadcastable = (False,)*len(value.shape)
type = TensorType(value.dtype, broadcastable=broadcastable)
return TensorSharedVariable(type=type, value=value, name=name, strict=strict)
......
......@@ -26,7 +26,7 @@ class Test_SharedVariable(unittest.TestCase):
b = shared(numpy.random.rand(4,5))
assert b.type == TensorType('float64', broadcastable=[False,False])
b = shared(numpy.random.rand(5,1,2))
assert b.type == TensorType('float64', broadcastable=[False,True,False])
assert b.type == TensorType('float64', broadcastable=[False,False,False])
assert shared([]).type == generic
def badfunc():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论