提交 87c93405 authored 作者: James Bergstra's avatar James Bergstra

Added optional broadcastable argument to shared variable constructor for

tensors.
上级 2efcc604
......@@ -138,13 +138,14 @@ def generic_constructor(value, name=None, strict=False):
class TensorSharedVariable(SharedVariable, theano.tensor.basic._tensor_py_operators):
pass
@shared_constructor
def tensor_constructor(value, name=None, strict=False):
def tensor_constructor(value, name=None, strict=False, broadcastable=None):
"""SharedVariable Constructor for TensorType"""
if not isinstance(value, numpy.ndarray):
raise TypeError()
bcast = [b==1 for b in value.shape]
type = TensorType(value.dtype, broadcastable=bcast)
if broadcastable is None:
broadcastable = [b==1 for b in value.shape]
type = TensorType(value.dtype, broadcastable=broadcastable)
return TensorSharedVariable(type=type, value=value, name=name, strict=strict)
@shared_constructor
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论