提交 439e2eaf authored 作者: Frederic Bastien's avatar Frederic Bastien

make use of the new maximum and minimum op to make smaller ARange.infer_shape graph.

上级 d12cf404
......@@ -3175,9 +3175,10 @@ class ARange(Op):
if is_constant_value(start, 0):
return [(cast(stop, 'int64'),)]
else:
return [(theano.tensor.max([cast(stop-start, 'int64'),0]),)]
return [(maximum(cast(stop-start, 'int64'),0),)]
else:
return [(theano.tensor.max([cast(ceil(cast((stop-start),'float64')/step),'int64'),0]),)]
return [(maximum(cast(ceil(cast((stop-start),'float64')
/step),'int64'),0),)]
def perform(self, node, (start, stop, step), (out,)):
start = start.item()
......
......@@ -2075,8 +2075,9 @@ class TestARange(unittest.TestCase):
start, stop, step = iscalars('start', 'stop', 'step')
out = arange(start, stop, step)
f = function([start, stop, step], out.shape, mode=compile.mode.get_default_mode().excluding('fusion'))
assert len(f.maker.env.toposort())==12
#ungly graph... [DimShuffle{x}(step), DimShuffle{x}(start), DimShuffle{x}(stop), Elemwise{Sub{output_types_preference=transfer_type{0}}}[(0, 0)](DimShuffle{x}.0, DimShuffle{x}.0), Elemwise{Cast{float64}}(Elemwise{Sub{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{float64}}.0, DimShuffle{x}.0), Rebroadcast{0}(Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)](Rebroadcast{0}.0), Elemwise{Cast{int64}}(Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)].0), <theano.tensor.basic.Join object at 0x1fb1d10>(0, Elemwise{Cast{int64}}.0, [0]), MaxAndArgmax(<theano.tensor.basic.Join object at 0x1fb1d10>.0, [0]), MakeVector(max)]
assert len(f.maker.env.toposort())==7
#7 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{float64}}(Elemwise{sub,no_inplace}.0), Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{float64}}.0, step), Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Cast{int64}}(Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
assert out.dtype == start.type.dtype
assert numpy.all(f(0,5,1) == len(numpy.arange(0,5,1)))
assert numpy.all(f(2,11,4) == len(numpy.arange(2,11,4)))
......@@ -2087,9 +2088,8 @@ class TestARange(unittest.TestCase):
out = arange(start, stop, 1)
f = function([start, stop], out.shape, mode=compile.mode.get_default_mode().excluding('fusion'))
assert len(f.maker.env.toposort())==8
#ungly graph... [DimShuffle{x}(start), DimShuffle{x}(stop), Elemwise{Sub{output_types_preference=transfer_type{0}}}[(0, 0)](DimShuffle{x}.0, DimShuffle{x}.0), Rebroadcast{0}(Elemwise{Sub{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Cast{int64}}(Rebroadcast{0}.0), <theano.tensor.basic.Join object at 0x1fb1d10>(0, Elemwise{Cast{int64}}.0, [0]), MaxAndArgmax(<theano.tensor.basic.Join object at 0x1fb1d10>.0, [0]), MakeVector(max)]
assert len(f.maker.env.toposort())==4
#4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
assert out.dtype == start.type.dtype
assert numpy.all(f(0,5) == len(numpy.arange(0,5)))
assert numpy.all(f(2,11) == len(numpy.arange(2,11)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论