提交 af68ea42 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

dot(zeros,x) -> zeros

new optimization for dot products.
上级 73598c96
...@@ -249,6 +249,57 @@ def register_specialize_device(lopt, *tags, **kwargs): ...@@ -249,6 +249,57 @@ def register_specialize_device(lopt, *tags, **kwargs):
return lopt return lopt
#####################
# Dot optimizations #
#####################
@register_canonicalize
@register_stabilize
@gof.local_optimizer([None])
def local_0_dot_x(node):
if isinstance(node.op, T.Dot):
x = node.inputs[0]
y = node.inputs[1]
replace = False
if x.owner and isinstance(x.owner.op, T.Alloc):
try:
val = get_constant_value(x.owner.inputs[0])
if numpy.all(val == 0):
replace = True
except TypeError:
pass
if y.owner and isinstance(y.owner.op, T.Alloc):
try:
val = get_constant_value(y.owner.inputs[0])
if numpy.all(val == 0):
replace = True
except TypeError:
pass
if isinstance(x, T.TensorConstant) and (x.tag.unique_value == 0):
replace = True
if isinstance(y, T.TensorConstant) and (y.tag.unique_value == 0):
replace = True
if replace:
if x.ndim ==2 and y.ndim == 2:
return [T.alloc( T.constant(0,dtype = node.outputs[0].type.dtype),
x.shape[0], y.shape[1])]
elif x.ndim==1 and y.ndim == 2:
return [T.alloc( T.constant(0,dtype = node.outputs[0].type.dtype),
y.shape[1])]
elif x.ndim==2 and y.ndim ==1:
return [T.alloc( T.constant(0,dtype = node.outputs[0].type.dtype),
x.shape[0])]
elif x.ndim==1 and y.ndim==1:
return [T.constant(0,dtype = node.outputs[0].type.dtype)]
else:
return False
else:
return False
###################### ######################
# DimShuffle lifters # # DimShuffle lifters #
###################### ######################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论