提交 a3ea318e authored 作者: Frederic Bastien's avatar Frederic Bastien

fix syntax error of tensor.ARange.infer_shape found by Justin Brayer and fix…

fix syntax error of tensor.ARange.infer_shape found by Justin Brayer and fix behavior error found while making test for that case. Added test for this. The new graph are ugly... 12 ops with dimshuffles and rebroardcast...
上级 c6104ebc
......@@ -3122,27 +3122,20 @@ class ARange(Op):
def infer_shape(self, node, i_shapes):
start, stop, step = node.inputs
def is_constant_value(var, value):
if numpy.all(var == value):
return True
if isinstance(var, gof.Constant):
return numpy.all(var.data == value)
if var.owner:
if var.owner.op == T.alloc:
return is_constant_value(var.owner.inputs[0], value)
if isinstance(var.owner.op, DimShuffle):
return is_constant_value(var.owner.inputs[0], value)
if var.owner.op == T.fill:
shape, in_var = var.owner.inputs
return is_constant_value(in_var, value)
try:
v = get_constant_value(var)
return numpy.all(v == value)
except:
pass
return False
if is_constant_value(step, 1):
if is_constant_value(start, 0):
return [(cast(stop, 'int64'),)]
else:
return [(cast(stop-start, 'int64'),)]
return [(theano.tensor.max([cast(stop-start, 'int64'),0]),)]
else:
return [(cast(ceil(cast((stop-start),'float64')/step),'int64'),)]
return [(theano.tensor.max([cast(ceil(cast((stop-start),'float64')/step),'int64'),0]),)]
def perform(self, node, (start, stop, step), (out,)):
start = start.item()
......
......@@ -2028,6 +2028,45 @@ class TestARange(unittest.TestCase):
assert out2.owner.op is out3.owner.op
assert out3.owner.op is not out4.owner.op
def test_infer_shape(self):
start, stop, step = iscalars('start', 'stop', 'step')
out = arange(start, stop, step)
f = function([start, stop, step], out.shape, mode=compile.mode.get_default_mode().excluding('fusion'))
assert len(f.maker.env.toposort())==12
#ungly graph... [DimShuffle{x}(step), DimShuffle{x}(start), DimShuffle{x}(stop), Elemwise{Sub{output_types_preference=transfer_type{0}}}[(0, 0)](DimShuffle{x}.0, DimShuffle{x}.0), Elemwise{Cast{float64}}(Elemwise{Sub{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)](Elemwise{Cast{float64}}.0, DimShuffle{x}.0), Rebroadcast{0}(Elemwise{TrueDiv{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)](Rebroadcast{0}.0), Elemwise{Cast{int64}}(Elemwise{Ceil{output_types_preference=transfer_type{0}}}[(0, 0)].0), <theano.tensor.basic.Join object at 0x1fb1d10>(0, Elemwise{Cast{int64}}.0, [0]), MaxAndArgmax(<theano.tensor.basic.Join object at 0x1fb1d10>.0, [0]), MakeVector(max)]
assert out.dtype == start.type.dtype
assert numpy.all(f(0,5,1) == len(numpy.arange(0,5,1)))
assert numpy.all(f(2,11,4) == len(numpy.arange(2,11,4)))
assert numpy.all(f(-5,1,1) == len(numpy.arange(-5,1,1)))
assert numpy.all(f(10,2,-2) == len(numpy.arange(10,2,-2)))
assert numpy.all(f(10,2,2) == len(numpy.arange(10,2,2)))
assert numpy.all(f(0,0,1) == len(numpy.arange(0,0,1)))
out = arange(start, stop, 1)
f = function([start, stop], out.shape, mode=compile.mode.get_default_mode().excluding('fusion'))
assert len(f.maker.env.toposort())==8
#ungly graph... [DimShuffle{x}(start), DimShuffle{x}(stop), Elemwise{Sub{output_types_preference=transfer_type{0}}}[(0, 0)](DimShuffle{x}.0, DimShuffle{x}.0), Rebroadcast{0}(Elemwise{Sub{output_types_preference=transfer_type{0}}}[(0, 0)].0), Elemwise{Cast{int64}}(Rebroadcast{0}.0), <theano.tensor.basic.Join object at 0x1fb1d10>(0, Elemwise{Cast{int64}}.0, [0]), MaxAndArgmax(<theano.tensor.basic.Join object at 0x1fb1d10>.0, [0]), MakeVector(max)]
assert out.dtype == start.type.dtype
assert numpy.all(f(0,5) == len(numpy.arange(0,5)))
assert numpy.all(f(2,11) == len(numpy.arange(2,11)))
assert numpy.all(f(-5,1) == len(numpy.arange(-5,1)))
assert numpy.all(f(10,2) == len(numpy.arange(10,2)))
assert numpy.all(f(10,2) == len(numpy.arange(10,2)))
assert numpy.all(f(0,0) == len(numpy.arange(0,0)))
out = arange(0, stop, 1)
f = function([stop], out.shape, mode=compile.mode.get_default_mode().excluding('fusion'))
assert len(f.maker.env.toposort())==2
#[Elemwise{Cast{int64}}(stop), MakeVector(Elemwise{Cast{int64}}.0)]
assert out.dtype == start.type.dtype
assert numpy.all(f(5) == len(numpy.arange(0,5)))
assert numpy.all(f(11) == len(numpy.arange(0,11)))
assert numpy.all(f(1) == len(numpy.arange(0,1)))
assert numpy.all(f(2) == len(numpy.arange(0,2)))
assert numpy.all(f(2) == len(numpy.arange(0,2)))
assert numpy.all(f(0) == len(numpy.arange(0,0)))
class TestInversePermutation(unittest.TestCase):
def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论