提交 6496b0c0 作者: James Bergstra

rewrite local_mul_specialize to use alloc instead of fill_chain

上级 b38ffe57
...@@ -1345,9 +1345,9 @@ register_specialize(local_pow_specialize) ...@@ -1345,9 +1345,9 @@ register_specialize(local_pow_specialize)
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_specialize(node): def local_mul_specialize(node):
def fill_chain(v):
return _fill_chain(v, node.inputs)
#here, we are past the point of canonicalization, so we don't want to put in un-necessary fills. #here, we are past the point of canonicalization, so we don't want to put in un-necessary fills.
#
# at this point [post canonicalize], mul() may have many inputs.
if node.op == T.mul: if node.op == T.mul:
#the idea here is that we have pow(x, y) #the idea here is that we have pow(x, y)
neg = False neg = False
...@@ -1365,6 +1365,7 @@ def local_mul_specialize(node): ...@@ -1365,6 +1365,7 @@ def local_mul_specialize(node):
elif N.all(y == -1.0): elif N.all(y == -1.0):
neg ^= True #toggles neg ^= True #toggles
elif N.all(y == 0.0): elif N.all(y == 0.0):
# if we find any zero, we just return right away
return [T.alloc(numpy.asarray(0, dtype=node.outputs[0].dtype), return [T.alloc(numpy.asarray(0, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])] *node.env.shape_feature.shape_of[node.outputs[0]])]
else: else:
...@@ -1374,26 +1375,29 @@ def local_mul_specialize(node): ...@@ -1374,26 +1375,29 @@ def local_mul_specialize(node):
if new_inputs: if new_inputs:
if len(new_inputs) == 1: if len(new_inputs) == 1:
if neg: if neg:
msg = -new_inputs[0] rval = -new_inputs[0]
else: else:
msg = new_inputs[0] rval = new_inputs[0]
return fill_chain(msg)
else: else:
if neg: if neg:
msg = -T.mul(*new_inputs) rval = -T.mul(*new_inputs)
else: else:
msg = T.mul(*new_inputs) rval = T.mul(*new_inputs)
return [T.alloc(T.cast(msg, node.outputs[0].dtype), return [T.alloc(T.cast(rval, node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])] *node.env.shape_feature.shape_of[node.outputs[0]])]
else: else:
# there are no variable inputs to mul
# N.B. this could have been constant-folded...
if neg: if neg:
# return output's worth of -1 # return output's worth of -1
return [T.alloc(numpy.asarray(-1, dtype=node.outputs[0].dtype), return [T.alloc(
numpy.asarray(-1, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])] *node.env.shape_feature.shape_of[node.outputs[0]])]
else: else:
# return output's worth of 1 # return output's worth of 1
return [T.alloc(numpy.asarray(1, dtype=node.outputs[0].dtype), return [T.alloc(
numpy.asarray(1, dtype=node.outputs[0].dtype),
*node.env.shape_feature.shape_of[node.outputs[0]])] *node.env.shape_feature.shape_of[node.outputs[0]])]
register_specialize(local_mul_specialize) register_specialize(local_mul_specialize)
......
...@@ -1013,6 +1013,51 @@ class test_shapeoptimizer(unittest.TestCase): ...@@ -1013,6 +1013,51 @@ class test_shapeoptimizer(unittest.TestCase):
print f.maker.env.toposort() print f.maker.env.toposort()
assert [] == f.maker.env.toposort() assert [] == f.maker.env.toposort()
def test_local_mul_specialize():
    """Check the optimized graph of a few multiplications by constants.

    Each case compiles a function of two vectors and compares the ops found
    in the compiled graph's toposort against an expected list.  The cases
    multiply by 1, 0 and -1, alone and combined with a negated second vector.
    """
    # FAST_COMPILE is swapped for FAST_RUN; presumably the specialize
    # optimizations are not applied under FAST_COMPILE -- TODO confirm.
    run_mode = theano.config.mode
    if run_mode == 'FAST_COMPILE':
        run_mode = 'FAST_RUN'

    v = T.vector()
    m = T.vector()

    def compiled_ops(expr, dump_graph=False):
        # Compile `expr`, print the ops of the resulting graph in toposort
        # order, optionally print the whole graph, and return the op list.
        fn = function([v, m], expr, mode=run_mode)
        ops = [apply_node.op for apply_node in fn.maker.env.toposort()]
        # Single-argument parenthesized form: prints the same text as the
        # bare Python 2 print statement.
        print(ops)
        if dump_graph:
            theano.printing.debugprint(fn)
        return ops

    # Multiplying by 1 leaves an empty graph.
    assert compiled_ops(v * 1) == []

    # Multiplying by 0 becomes an alloc sized from the input's shape.
    assert compiled_ops(v * 0) == [Shape_i(0), T.alloc]

    # Multiplying by -1 becomes a single negation.
    assert compiled_ops(v * (-1)) == [T.neg]

    # The same three constants, combined with a negated second input.
    assert compiled_ops(v * 1 * (-m), dump_graph=True) == [
        T.mul, inplace.neg_inplace]
    assert compiled_ops(v * 0 * (-m), dump_graph=True) == [
        Shape_i(0), T.alloc]
    assert compiled_ops(v * (-1) * (-m), dump_graph=True) == [T.mul]
if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() # unittest.main()
test_fusion().tes_memory_leak() test_fusion().tes_memory_leak()
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论