提交 0e24caaf authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op Default

上级 a50646d5
...@@ -3410,6 +3410,14 @@ class Default(gof.Op): ...@@ -3410,6 +3410,14 @@ class Default(gof.Op):
else: else:
out[0] = x out[0] = x
def infer_shape(self, node, in_shapes):
if node.inputs[0] is None:
out_shape = in_shapes[1]
else:
out_shape = in_shapes[0]
return [out_shape]
default = Default() default = Default()
setdefault = default # legacy setdefault = default # legacy
......
...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
get_constant_value, ivector, reshape, scalar_from_tensor, scal, get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad, iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll, opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye, Shape) tile, patternbroadcast, Eye, Shape, Default)
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.printing import debugprint from theano.printing import debugprint
...@@ -6083,6 +6083,12 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6083,6 +6083,12 @@ class TestInferShape(utt.InferShapeTester):
[Shape()(adtens)], [Shape()(adtens)],
[adtens_val], (opt.MakeVector, Shape)) [adtens_val], (opt.MakeVector, Shape))
# Default
self._compile_and_check([admat, bdmat],
[Default()(admat, bdmat)],
[admat_val, bdmat_val], (Default))
if __name__ == '__main__': if __name__ == '__main__':
t = TestInferShape('setUp') t = TestInferShape('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论