提交 c2ede260 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Simplify python implementation of ScalarFromTensor

上级 47a15c6a
...@@ -678,10 +678,9 @@ class ScalarFromTensor(COp): ...@@ -678,10 +678,9 @@ class ScalarFromTensor(COp):
self, [t], [ps.get_scalar_type(dtype=t.type.dtype).make_variable()] self, [t], [ps.get_scalar_type(dtype=t.type.dtype).make_variable()]
) )
def perform(self, node, inp, out_): def perform(self, node, inputs, output_storage):
(s,) = inp # not using .item() because that returns a Python scalar, not a numpy scalar
(out,) = out_ output_storage[0][0] = inputs[0][()]
out[0] = s.flatten()[0]
def infer_shape(self, fgraph, node, in_shapes): def infer_shape(self, fgraph, node, in_shapes):
return [()] return [()]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论