提交 44470332 authored 作者: Frederic's avatar Frederic

optimize consecutive cast

上级 e89232cb
...@@ -1449,6 +1449,29 @@ def local_alloc_unary(node): ...@@ -1449,6 +1449,29 @@ def local_alloc_unary(node):
return [T.alloc(T.cast(v, node.outputs[0].dtype), *shp)] return [T.alloc(T.cast(v, node.outputs[0].dtype), *shp)]
@register_canonicalize
@register_specialize
@gof.local_optimizer([T.Elemwise])
def local_cast_cast(node):
"""cast(cast(x, dtype1), dtype2)
when those contrain:
dtype1 == dtype2
TODO: the base dtype is the same (int, uint, float, complex)
and the first cast cause an upcast.
"""
if (not isinstance(node.op, T.Elemwise) or
not isinstance(node.op.scalar_op, scalar.Cast)):
return
x = node.inputs[0]
if (not x.owner or
not isinstance(x.owner.op, T.Elemwise) or
not isinstance(x.owner.op.scalar_op, scalar.Cast)):
return
if node.op.scalar_op.o_type == x.owner.op.scalar_op.o_type:
return [x]
class Assert(T.Op): class Assert(T.Op):
""" """
Implements assertion in a computational graph. Implements assertion in a computational graph.
......
...@@ -3599,6 +3599,31 @@ class T_useless_elemwise(unittest.TestCase): ...@@ -3599,6 +3599,31 @@ class T_useless_elemwise(unittest.TestCase):
assert topo[0].op == deep_copy_op assert topo[0].op == deep_copy_op
class T_cast_cast(unittest.TestCase):
def setUp(self):
mode = theano.compile.get_default_mode()
self.mode = mode.including('local_cast_cast')
def test(self):
x = T.fmatrix()
o = T.Elemwise(scal.Cast(scal.Scalar("float64")))(x.astype("float64"))
f = theano.function([x], o, mode=self.mode)
dx = numpy.random.rand(5, 4).astype("float32")
f(dx)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, T.Elemwise)
x = T.dmatrix()
o = T.Elemwise(scal.Cast(scal.Scalar("float32")))(x.astype("float32"))
f = theano.function([x], o, mode=self.mode)
dx = numpy.random.rand(5, 4)
f(dx)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, T.Elemwise)
def test_constant_folding(): def test_constant_folding():
""" Test that constant folding get registered at fast_compile """ Test that constant folding get registered at fast_compile
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论