提交 a2827cec authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op TensorFromScalar

上级 4cfabc4a
......@@ -1881,6 +1881,9 @@ class TensorFromScalar(Op):
out, = out_
out[0] = numpy.asarray(s)
def infer_shape(self, node, in_shapes):
return [()]
def grad(self, inp, grads):
s, = inp
dt, = grads
......
......@@ -6148,6 +6148,13 @@ class TestInferShape(utt.InferShapeTester):
[45], ScalarFromTensor,
excluding=["local_tensor_scalar_tensor"])
# TensorFromScalar:
aiscal = scal.float64()
self._compile_and_check([aiscal],
[TensorFromScalar()(aiscal)],
[4.], TensorFromScalar)
if __name__ == '__main__':
t = TestInferShape('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论