提交 3daf9b02 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Cosmetic changes

上级 ff542349
...@@ -257,46 +257,41 @@ def register_specialize_device(lopt, *tags, **kwargs): ...@@ -257,46 +257,41 @@ def register_specialize_device(lopt, *tags, **kwargs):
@register_stabilize @register_stabilize
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def local_0_dot_x(node): def local_0_dot_x(node):
if not isinstance(node.op, T.Dot):
return False
if isinstance(node.op, T.Dot): x = node.inputs[0]
x = node.inputs[0] y = node.inputs[1]
y = node.inputs[1] replace = False
replace = False try:
try: if get_constant_value(x) == 0:
if get_constant_value(x) == 0:
replace = True
except TypeError:
pass
try:
if get_constant_value(y) == 0:
replace = True
except TypeError:
pass
# TODO: Integrate that into get_constant_value somehow
if isinstance(x, T.TensorConstant) and (x.tag.unique_value == 0):
replace = True replace = True
if isinstance(y, T.TensorConstant) and (y.tag.unique_value == 0): except TypeError:
pass
try:
if get_constant_value(y) == 0:
replace = True replace = True
except TypeError:
pass
if replace: # TODO: Integrate that into get_constant_value somehow
if x.ndim ==2 and y.ndim == 2: if isinstance(x, T.TensorConstant) and (x.tag.unique_value == 0):
return [T.alloc( T.constant(0,dtype = node.outputs[0].type.dtype), replace = True
x.shape[0], y.shape[1])] if isinstance(y, T.TensorConstant) and (y.tag.unique_value == 0):
elif x.ndim==1 and y.ndim == 2: replace = True
return [T.alloc( T.constant(0,dtype = node.outputs[0].type.dtype),
y.shape[1])] if replace:
elif x.ndim==2 and y.ndim ==1: constant_zero = T.constant(0, dtype=node.outputs[0].type.dtype)
return [T.alloc( T.constant(0,dtype = node.outputs[0].type.dtype), if x.ndim == 2 and y.ndim == 2:
x.shape[0])] return [T.alloc(constant_zero, x.shape[0], y.shape[1])]
elif x.ndim==1 and y.ndim==1: elif x.ndim == 1 and y.ndim == 2:
return [T.constant(0,dtype = node.outputs[0].type.dtype)] return [T.alloc(constant_zero, y.shape[1])]
else: elif x.ndim == 2 and y.ndim == 1:
return False return [T.alloc(constant_zero, x.shape[0])]
elif x.ndim == 1 and y.ndim == 1:
return [constant_zero]
else:
return False
###################### ######################
# DimShuffle lifters # # DimShuffle lifters #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论