提交 f0cd8ba1 authored 作者: Vikram's avatar Vikram

Optimization to remove upcast in local_cast_cast

上级 4b246506
......@@ -2204,7 +2204,7 @@ def local_cast_cast(node):
when those contrain:
dtype1 == dtype2
TODO: the base dtype is the same (int, uint, float, complex)
OR the base dtype is the same (int, uint, float, complex)
and the first cast cause an upcast.
"""
......@@ -2216,10 +2216,51 @@ def local_cast_cast(node):
not isinstance(x.owner.op, T.Elemwise) or
not isinstance(x.owner.op.scalar_op, scalar.Cast)):
return
if node.op.scalar_op.o_type == x.owner.op.scalar_op.o_type:
type1 = x.owner.op.scalar_op.o_type
type2 = node.op.scalar_op.o_type
base = x.owner.inputs[0]
if type1 == type2:
# We don't need to copy over any stack traces here
return [x]
if(upcast(base.dtype, type1.dtype)):
# Checking for further redundancy. Eg: int8 -> int32 -> int8
if(type2.dtype == base.dtype):
return x.owner.inputs
else:
# Apply the second cast only
v = node.op(base)
return [v]
def upcast(type1, type2):
'''
Given two data types (as strings), check if converting to
type2 from type1 constitutes an upcast.
'''
upcast_pairs = (
('int8','int16'),('int8','int32'),('int8','int64'),('int16','int32'),('int16','int64'),('int32','int64'),
('uint8','uint16'),('uint8','uint32'),('uint8','uint64'),('uint16','uint32'),('uint16','uint64'),('uint32','uint64'),
('float16','float32'),('float16','float32'),('float16','float64'),('float32','float64'),
('complex64','complex128'),
('uint8','int16'),('uint8','int32'),('uint8','int64'),('uint16','int32'),('uint16','int64'),('uint32','int64'),
('int8','float16'),('int8','float32'),('int8','float64'),('int16','float32'),('int16','float64'),('int32','float64'),
('uint8','float16'),('uint8','float32'),('uint8','float64'),('uint16','float32'),('uint16','float64'),('uint32','float64'),
('int8','complex64'),('int16','complex64'),
('uint8','complex64'),('uint16','complex64'),
('float32','complex64'),
('int8','complex128'),('int16','complex128'),('int32','complex128'),
('uint8','complex128'),('uint16','complex128'),('uint32','complex128'),
('float32','complex128'),('float64','complex128')
)
for pair in upcast_pairs:
if(type1 == pair[0] and type2 == pair[1]):
return True
return False
@register_canonicalize
@register_specialize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论