提交 4cfabc4a authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op ScalarFromTensor

上级 6a05466b
...@@ -1911,6 +1911,9 @@ class ScalarFromTensor(Op): ...@@ -1911,6 +1911,9 @@ class ScalarFromTensor(Op):
out, = out_ out, = out_
out[0] = s.flatten()[0] out[0] = s.flatten()[0]
def infer_shape(self, node, in_shapes):
return [()]
def grad(self, inp, grads): def grad(self, inp, grads):
s, = inp s, = inp
dt, = grads dt, = grads
......
...@@ -34,7 +34,8 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -34,7 +34,8 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
get_constant_value, ivector, reshape, scalar_from_tensor, scal, get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements) tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
ScalarFromTensor, TensorFromScalar)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.printing import debugprint from theano.printing import debugprint
...@@ -6139,6 +6140,13 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6139,6 +6140,13 @@ class TestInferShape(utt.InferShapeTester):
[PermuteRowElements()(adtens, aivec, abool)], [PermuteRowElements()(adtens, aivec, abool)],
[adtens_val, aivec_val], PermuteRowElements) [adtens_val, aivec_val], PermuteRowElements)
# ScalarFromTensor
aiscal = iscalar()
aconst = constant(45)
self._compile_and_check([aiscal],
[TensorFromScalar()(ScalarFromTensor()(aiscal))],
[45], ScalarFromTensor,
excluding=["local_tensor_scalar_tensor"])
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论