提交 0e24caaf authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op Default

上级 a50646d5
......@@ -3410,6 +3410,14 @@ class Default(gof.Op):
else:
out[0] = x
def infer_shape(self, node, in_shapes):
if node.inputs[0] is None:
out_shape = in_shapes[1]
else:
out_shape = in_shapes[0]
return [out_shape]
default = Default()
setdefault = default # legacy
......
......@@ -34,7 +34,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
get_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
opt, ComplexError, TensorDot, lvector, true_div, max, min, Split, roll,
tile, patternbroadcast, Eye, Shape)
tile, patternbroadcast, Eye, Shape, Default)
from theano.tests import unittest_tools as utt
from theano.printing import debugprint
......@@ -6083,6 +6083,12 @@ class TestInferShape(utt.InferShapeTester):
[Shape()(adtens)],
[adtens_val], (opt.MakeVector, Shape))
# Default
self._compile_and_check([admat, bdmat],
[Default()(admat, bdmat)],
[admat_val, bdmat_val], (Default))
if __name__ == '__main__':
t = TestInferShape('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论