提交 2b06fa16 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Thomas Wiecki

Use np.shape instead of var.shape in theano.tensor.basic.get_scalar_constant_value

上级 7736732d
...@@ -440,7 +440,7 @@ def get_scalar_constant_value( ...@@ -440,7 +440,7 @@ def get_scalar_constant_value(
i = v.owner.op.i i = v.owner.op.i
inp = v.owner.inputs[0] inp = v.owner.inputs[0]
if isinstance(inp, Constant): if isinstance(inp, Constant):
return np.asarray(inp.data.shape[i]) return np.asarray(np.shape(inp.data)[i])
# The shape of a broadcastable dimension is 1 # The shape of a broadcastable dimension is 1
if hasattr(inp.type, "broadcastable") and inp.type.broadcastable[i]: if hasattr(inp.type, "broadcastable") and inp.type.broadcastable[i]:
return np.asarray(1) return np.asarray(1)
...@@ -638,7 +638,7 @@ def get_scalar_constant_value( ...@@ -638,7 +638,7 @@ def get_scalar_constant_value(
return np.asarray(1) return np.asarray(1)
if isinstance(grandparent, Constant): if isinstance(grandparent, Constant):
return np.asarray(grandparent.data.shape[idx]) return np.asarray(np.shape(grandparent.data)[idx])
raise NotScalarConstantError(v) raise NotScalarConstantError(v)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论