提交 4cfabc4a authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op ScalarFromTensor

上级 6a05466b
......@@ -1911,6 +1911,9 @@ class ScalarFromTensor(Op):
out, = out_
out[0] = s.flatten()[0]
def infer_shape(self, node, in_shapes):
return [()]
def grad(self, inp, grads):
s, = inp
dt, = grads
......
......@@ -34,7 +34,8 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements)
tile, patternbroadcast, Eye, Shape, Default, Dot, PermuteRowElements,
ScalarFromTensor, TensorFromScalar)
from theano.tests import unittest_tools as utt
from theano.printing import debugprint
......@@ -6139,6 +6140,13 @@ class TestInferShape(utt.InferShapeTester):
[PermuteRowElements()(adtens, aivec, abool)],
[adtens_val, aivec_val], PermuteRowElements)
# ScalarFromTensor
aiscal = iscalar()
aconst = constant(45)
self._compile_and_check([aiscal],
[TensorFromScalar()(ScalarFromTensor()(aiscal))],
[45], ScalarFromTensor,
excluding=["local_tensor_scalar_tensor"])
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论