提交 439e2eaf authored 作者: Frederic Bastien's avatar Frederic Bastien

make use of the new maximum and minimum op to make smaller ARange.infer_shape graph.

上级 d12cf404
...@@ -3175,9 +3175,10 @@ class ARange(Op): ...@@ -3175,9 +3175,10 @@ class ARange(Op):
if is_constant_value(start, 0): if is_constant_value(start, 0):
return [(cast(stop, 'int64'),)] return [(cast(stop, 'int64'),)]
else: else:
return [(theano.tensor.max([cast(stop-start, 'int64'),0]),)] return [(maximum(cast(stop-start, 'int64'),0),)]
else: else:
return [(theano.tensor.max([cast(ceil(cast((stop-start),'float64')/step),'int64'),0]),)] return [(maximum(cast(ceil(cast((stop-start),'float64')
/step),'int64'),0),)]
def perform(self, node, (start, stop, step), (out,)): def perform(self, node, (start, stop, step), (out,)):
start = start.item() start = start.item()
......
...@@ -2075,8 +2075,9 @@ class TestARange(unittest.TestCase): ...@@ -2075,8 +2075,9 @@ class TestARange(unittest.TestCase):
start, stop, step = iscalars('start', 'stop', 'step') start, stop, step = iscalars('start', 'stop', 'step')
out = arange(start, stop, step) out = arange(start, stop, step)
f = function([start, stop, step], out.shape, mode=compile.mode.get_default_mode().excluding('fusion')) f = function([start, stop, step], out.shape, mode=compile.mode.get_default_mode().excluding('fusion'))
assert len(f.maker.env.toposort())==12 assert len(f.maker.env.toposort())==7
#ungly graph... [DimShuffle{x}(step), DimShuffle{x}(start), DimShuffle{x}(stop), Elemwise{Sub{output_types_preference=transfer_type{0}}}[(0, 0)](DimShuffle{x}.0, DimShuffle{x}.0), Elemwise{Cast{float64}}(Elemwise{Sub{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{float64}}.0, DimShuffle{x}.0), Rebroadcast{0}(Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)](Rebroadcast{0}.0), Elemwise{Cast{int64}}(Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)].0), <theano.tensor.basic.Join object at 0x1fb1d10>(0, Elemwise{Cast{int64}}.0, [0]), MaxAndArgmax(<theano.tensor.basic.Join object at 0x1fb1d10>.0, [0]), MakeVector(max)] #7 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{float64}}(Elemwise{sub,no_inplace}.0), Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{float64}}.0, step), Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Cast{int64}}(Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
assert out.dtype == start.type.dtype assert out.dtype == start.type.dtype
assert numpy.all(f(0,5,1) == len(numpy.arange(0,5,1))) assert numpy.all(f(0,5,1) == len(numpy.arange(0,5,1)))
assert numpy.all(f(2,11,4) == len(numpy.arange(2,11,4))) assert numpy.all(f(2,11,4) == len(numpy.arange(2,11,4)))
...@@ -2087,9 +2088,8 @@ class TestARange(unittest.TestCase): ...@@ -2087,9 +2088,8 @@ class TestARange(unittest.TestCase):
out = arange(start, stop, 1) out = arange(start, stop, 1)
f = function([start, stop], out.shape, mode=compile.mode.get_default_mode().excluding('fusion')) f = function([start, stop], out.shape, mode=compile.mode.get_default_mode().excluding('fusion'))
assert len(f.maker.env.toposort())==8 assert len(f.maker.env.toposort())==4
#ungly graph... [DimShuffle{x}(start), DimShuffle{x}(stop), Elemwise{Sub{output_types_preference=transfer_type{0}}}[(0, 0)](DimShuffle{x}.0, DimShuffle{x}.0), Rebroadcast{0}(Elemwise{Sub{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Cast{int64}}(Rebroadcast{0}.0), <theano.tensor.basic.Join object at 0x1fb1d10>(0, Elemwise{Cast{int64}}.0, [0]), MaxAndArgmax(<theano.tensor.basic.Join object at 0x1fb1d10>.0, [0]), MakeVector(max)] #4 [Elemwise{sub,no_inplace}(stop, start), Elemwise{Cast{int64}}(Elemwise{sub,no_inplace}.0), Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{int64}}.0, 0), MakeVector(Elemwise{Maximum{output_types_preference=transfer_type{0}}}[(0, 0)].0)]
assert out.dtype == start.type.dtype assert out.dtype == start.type.dtype
assert numpy.all(f(0,5) == len(numpy.arange(0,5))) assert numpy.all(f(0,5) == len(numpy.arange(0,5)))
assert numpy.all(f(2,11) == len(numpy.arange(2,11))) assert numpy.all(f(2,11) == len(numpy.arange(2,11)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论