Commit fcbb3e9b authored by Frédéric Bastien, committed by GitHub

Merge pull request #5665 from vikramnitin9/master

Optimization to remove upcast in local_cast_cast
......@@ -2204,7 +2204,7 @@ def local_cast_cast(node):
    when those constrain:
    dtype1 == dtype2
    TODO: the base dtype is the same (int, uint, float, complex)
    OR the base dtype is the same (int, uint, float, complex)
          and the first cast causes an upcast.
    """
......@@ -2216,10 +2216,54 @@ def local_cast_cast(node):
            not isinstance(x.owner.op, T.Elemwise) or
            not isinstance(x.owner.op.scalar_op, scalar.Cast)):
        return
    if node.op.scalar_op.o_type == x.owner.op.scalar_op.o_type:
    type1 = x.owner.op.scalar_op.o_type
    type2 = node.op.scalar_op.o_type
    base = x.owner.inputs[0]
    if type1 == type2:
        # We don't need to copy over any stack traces here
        return [x]
    if(is_an_upcast(base.dtype, type1.dtype)):
        # Checking for further redundancy. Eg: int8 -> int32 -> int8
        if(type2.dtype == base.dtype):
            return x.owner.inputs
        else:
            # Apply the second cast only
            v = node.op(base)
            # Copy stack trace from the output of the original cast
            copy_stack_trace(node.outputs[0], v)
            return [v]
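# --- Illustrative sketch, not part of this diff -----------------------------
# A standalone summary of the decision rule implemented above, written over
# plain dtype strings.  `simplify_cast_chain` is a hypothetical name used
# only for illustration; it reuses the is_an_upcast helper defined below.
def simplify_cast_chain(base_dtype, dtype1, dtype2):
    """Describe what cast(cast(x, dtype1), dtype2) is rewritten to."""
    if dtype1 == dtype2:
        return "reuse the inner cast"           # e.g. float64 -> float64
    if is_an_upcast(base_dtype, dtype1):
        if dtype2 == base_dtype:
            return "return x unchanged"         # e.g. int8 -> int32 -> int8
        return "apply only cast(x, dtype2)"     # e.g. float32 -> float64 -> complex128
    return "leave the graph alone"              # a first downcast is never removed
# -----------------------------------------------------------------------------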
def is_an_upcast(type1, type2):
"""Given two data types (as strings), check if converting to
type2 from type1 constitutes an upcast.
Differs from theano.scalar.upcast
"""
category = {
# The first number in the pair is the dtype (bool, uint, int, float,
# complex). Conversion from higher to lower is never an upcast.
# The second number roughly indicates the precision. Again, conversion
# from higher to lower is never an upcast.
'bool': (0, 0),
'uint8': (1, 1), 'uint16': (1, 2), 'uint32': (1, 3), 'uint64': (1, 4),
'int8': (2, 1), 'int16': (2, 2), 'int32': (2, 3), 'int64': (2, 4),
'float16': (3, 1.5), 'float32': (3, 2.5), 'float64': (3, 3.5),
'complex64': (4, 3), 'complex128': (4, 4)
}
cat1 = category[type1]
cat2 = category[type2]
if(cat2[0] >= cat1[0] and cat2[1] > cat1[1]):
return True
else:
return False
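# Illustrative checks, not part of this diff, showing how the category table
# above drives is_an_upcast (first tuple element: dtype kind, second: precision):
assert is_an_upcast('int8', 'int32')           # (2, 1) -> (2, 3): same kind, more precision
assert is_an_upcast('uint8', 'float32')        # (1, 1) -> (3, 2.5): wider kind and precision
assert not is_an_upcast('float64', 'float32')  # precision drops (3.5 -> 2.5)
assert not is_an_upcast('int64', 'float16')    # float16 cannot hold every int64 (1.5 < 4)
assert not is_an_upcast('float32', 'int64')    # kind goes from float (3) back to int (2)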
@register_canonicalize
@register_specialize
......
......@@ -4656,7 +4656,7 @@ class T_cast_cast(unittest.TestCase):
        mode = theano.compile.get_default_mode()
        self.mode = mode.including('local_cast_cast')
    def test(self):
    def test_consecutive(self):
        x = T.fmatrix()
        o = T.Elemwise(scal.Cast(scal.Scalar("float64")))(x.astype("float64"))
        f = theano.function([x], o, mode=self.mode)
......@@ -4664,7 +4664,7 @@ class T_cast_cast(unittest.TestCase):
        f(dx)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert isinstance(topo[0].op, T.Elemwise)
        assert isinstance(topo[0].op.scalar_op, scal.basic.Cast)
        x = T.dmatrix()
        o = T.Elemwise(scal.Cast(scal.Scalar("float32")))(x.astype("float32"))
......@@ -4673,7 +4673,38 @@ class T_cast_cast(unittest.TestCase):
        f(dx)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert isinstance(topo[0].op, T.Elemwise)
        assert isinstance(topo[0].op.scalar_op, scal.basic.Cast)
    def test_upcast(self):
        # Upcast followed by any other cast
        x = T.fmatrix()
        o = T.Elemwise(scal.Cast(scal.Scalar("complex128")))(x.astype("complex64"))
        f = theano.function([x], o, mode=self.mode)
        dx = numpy.random.rand(5, 4).astype("float32")
        f(dx)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert isinstance(topo[0].op.scalar_op, scal.basic.Cast)
        # Upcast followed by a downcast back to the base type
        x = T.fmatrix()
        o = T.Elemwise(scal.Cast(scal.Scalar("float32")))(x.astype("float64"))
        f = theano.function([x], o, mode=self.mode)
        dx = numpy.random.rand(5, 4).astype('float32')
        f(dx)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert isinstance(topo[0].op, DeepCopyOp)
        # Downcast followed by an upcast back to the base type
        # Optimization shouldn't be applied
        x = T.dmatrix()
        o = T.Elemwise(scal.Cast(scal.Scalar("float64")))(x.astype("float32"))
        f = theano.function([x], o, mode=self.mode)
        dx = numpy.random.rand(5, 4)
        f(dx)
        topo = f.maker.fgraph.toposort()
        assert (len(topo) == 1 and isinstance(topo[0].op.scalar_op, scal.basic.Composite)) or (len(topo) > 1)
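# --- Illustrative aside, not part of this diff -------------------------------
# Why the downcast-then-upcast chain in the last case must not be collapsed to
# the identity: the intermediate float32 already discards precision.
import numpy

x = numpy.float64(1.0000000001)
roundtrip = numpy.float64(numpy.float32(x))
assert roundtrip != x   # precision was lost, so both casts must be kept
# ------------------------------------------------------------------------------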
class T_func_inverse(unittest.TestCase):
......