提交 c2ede260 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Simplify python implementation of ScalarFromTensor

上级 47a15c6a
......@@ -678,10 +678,9 @@ class ScalarFromTensor(COp):
self, [t], [ps.get_scalar_type(dtype=t.type.dtype).make_variable()]
)
def perform(self, node, inp, out_):
(s,) = inp
(out,) = out_
out[0] = s.flatten()[0]
def perform(self, node, inputs, output_storage):
# not using .item() because that returns a Python scalar, not a numpy scalar
output_storage[0][0] = inputs[0][()]
def infer_shape(self, fgraph, node, in_shapes):
return [()]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论