提交 a2827cec authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op TensorFromScalar

上级 4cfabc4a
...@@ -1881,6 +1881,9 @@ class TensorFromScalar(Op): ...@@ -1881,6 +1881,9 @@ class TensorFromScalar(Op):
out, = out_ out, = out_
out[0] = numpy.asarray(s) out[0] = numpy.asarray(s)
def infer_shape(self, node, in_shapes):
return [()]
def grad(self, inp, grads): def grad(self, inp, grads):
s, = inp s, = inp
dt, = grads dt, = grads
......
...@@ -6148,6 +6148,13 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6148,6 +6148,13 @@ class TestInferShape(utt.InferShapeTester):
[45], ScalarFromTensor, [45], ScalarFromTensor,
excluding=["local_tensor_scalar_tensor"]) excluding=["local_tensor_scalar_tensor"])
# TensorFromScalar:
aiscal = scal.float64()
self._compile_and_check([aiscal],
[TensorFromScalar()(aiscal)],
[4.], TensorFromScalar)
if __name__ == '__main__': if __name__ == '__main__':
t = TestInferShape('setUp') t = TestInferShape('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论