提交 97ba90ce authored 作者: Frederic Bastien's avatar Frederic Bastien

modif canonicaly to lift toward input abs throught mul and true_div. This allow…

modif canonicaly to lift toward input abs throught mul and true_div. This allow the canonicalizer simplier check_for_x_over_absX to cover more case.
上级 5566207f
...@@ -2010,6 +2010,42 @@ def check_for_x_over_absX(numerators, denominators): ...@@ -2010,6 +2010,42 @@ def check_for_x_over_absX(numerators, denominators):
return numerators, denominators return numerators, denominators
local_mul_canonizer.add_simplifier(check_for_x_over_absX, 'X_over_absX') local_mul_canonizer.add_simplifier(check_for_x_over_absX, 'X_over_absX')
@register_canonicalize
@gof.local_optimizer([T.abs_])
def local_abs_lift(node):
"""
move the abs toward the input. This is needed for check_for_x_over_absX to apply in more case.
"""
if node.op == T.abs_ and node.inputs[0].owner:
assert node.nin == 1
if node.inputs[0].owner.op == T.mul:
return [T.mul(*[T.abs_(i) for i in node.inputs[0].owner.inputs])]
if node.inputs[0].owner.op == T.true_div:
i = node.inputs[0].owner.inputs
return [T.true_div(T.abs_(i[0]),T.abs_(i[1]))]
@register_specialize
@gof.local_optimizer([])
def local_abs_merge(node):
"""
merge abs generated by local_abs_lift when the canonizer don't need it anymore
"""
if node.op == T.mul and sum([i.owner.op == T.abs_ for i in node.inputs if i.owner])>1:
inputs = []
for i in node.inputs:
if i.owner and i.owner.op == T.abs_:
inputs.append(i.owner.inputs[0])
else:
const = get_constant_value(i)
if not (const>=0).all():
return False
inputs.append(i)
return [T.abs_(T.mul(*inputs))]
if node.op == T.true_div and sum([i.owner.op == T.abs_ for i in node.inputs if i.owner])==2:
return [T.abs_(T.true_div(node.inputs[0].owner.inputs[0],node.inputs[1].owner.inputs[0]))]
@register_stabilize @register_stabilize
@gof.local_optimizer([T.log]) @gof.local_optimizer([T.log])
def local_log1p(node): def local_log1p(node):
......
...@@ -484,9 +484,54 @@ class test_canonize(unittest.TestCase): ...@@ -484,9 +484,54 @@ class test_canonize(unittest.TestCase):
assert numpy.all(numpy.isfinite(out)) assert numpy.all(numpy.isfinite(out))
assert numpy.allclose(out,numpy.sign(val_inputs[0])) assert numpy.allclose(out,numpy.sign(val_inputs[0]))
assert(out_dtype==out.dtype) assert(out_dtype==out.dtype)
assert len(f.maker.env.toposort())==1
#test (2*x) / (3*abs(x)) -> sign(x)
for id,(g, sym_inputs, val_inputs, out_dtype) in enumerate([
((2*dx)/(3*abs(dx)),[dx],[0.5-dxv],'float64'),
((2*fx)/(3*abs(fx)),[fx],[0.5-fxv],'float32'),
((2*dx)/(3*abs(dx)),[dx],[0.0*dxv],'float64'),
((2*fx)/(3*abs(fx)),[fx],[0.0*fxv],'float32'),
((2*dv)/(3*abs(dv)),[dv],[0.5-dvv],'float64'),
((2*fv)/(3*abs(fv)),[fv],[0.5-fvv],'float32'),
]):
f = compile.function(list(sym_inputs), g,
mode=mode)
topo = f.maker.env.toposort()
out = f(*val_inputs)
assert numpy.all(numpy.isfinite(out))
assert numpy.allclose(out,numpy.sign(val_inputs[0])*2/3)
assert(out_dtype==out.dtype)
finally: finally:
mode._optimizer = old_optimizer mode._optimizer = old_optimizer
def test_abs_mul_div(self):
"""
test that if we have
4 * x / abs(2*x) it get simplifier during canonicalisation.
"""
x=T.dscalar()
a=T.abs_(x)
mode = theano.compile.mode.get_default_mode().excluding("local_elemwise_fusion")
f=theano.function([x],[(4*x)/abs(2*x)], mode = mode)
print f.maker.env.toposort()
print
f(.1)
f(0)
f(-1)
assert len(f.maker.env.toposort())==2
f=theano.function([x],[(4*x)/abs(2/x)], mode = mode)
print f.maker.env.toposort()
print
f(.1)
f(0)
f(-1)
assert len(f.maker.env.toposort())==2
assert f.maker.env.toposort()[0].op==T.abs_
def test_multiple_case_that_fail(self): def test_multiple_case_that_fail(self):
import theano.tensor, theano.compile import theano.tensor, theano.compile
...@@ -553,6 +598,30 @@ class test_canonize(unittest.TestCase): ...@@ -553,6 +598,30 @@ class test_canonize(unittest.TestCase):
""" """
raise SkipTest("Not implemented") raise SkipTest("Not implemented")
def test_local_merge_abs():
x,y,z = T.matrices('xyz')
x_val = numpy.random.rand(5,5)
y_val = numpy.random.rand(5,5)
z_val = numpy.random.rand(5,5)
mode = theano.config.mode
if mode == "FAST_COMPILE":
mode = "FAST_RUN"
mode = theano.compile.mode.get_mode(mode).excluding("local_elemwise_fusion")
f = theano.function([x,y,z],(abs(y*z*-2)), mode=mode)
f(x_val,y_val,z_val)
theano.printing.debugprint(f)
assert isinstance(f.maker.env.toposort()[1].op.scalar_op, scal.Abs)
assert len(f.maker.env.toposort())==2
f = theano.function([x,y,z],abs(x/y), mode=mode)
f(x_val,y_val,z_val)
theano.printing.debugprint(f)
assert isinstance(f.maker.env.toposort()[1].op.scalar_op, scal.Abs)
assert len(f.maker.env.toposort())==2
def test_mixeddiv(): def test_mixeddiv():
"""Test that int division is preserved""" """Test that int division is preserved"""
i = iscalar() i = iscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论