提交 312e640c authored 作者: Frederic Bastien's avatar Frederic Bastien

return right time (list vs var)

上级 1c31c496
...@@ -2032,7 +2032,7 @@ def local_useless_elemwise(node): ...@@ -2032,7 +2032,7 @@ def local_useless_elemwise(node):
# Copy stack trace from input to constant output # Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
return ret return [ret]
elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2: elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]: if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be false # it is the same var in the graph. That will always be false
...@@ -2040,7 +2040,7 @@ def local_useless_elemwise(node): ...@@ -2040,7 +2040,7 @@ def local_useless_elemwise(node):
# Copy stack trace from input to constant output # Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
return ret return [ret]
elif node.op.scalar_op == theano.scalar.mul and len(node.inputs) == 1: elif node.op.scalar_op == theano.scalar.mul and len(node.inputs) == 1:
# No need to copy over any stack trace # No need to copy over any stack trace
...@@ -2060,8 +2060,8 @@ def local_useless_elemwise(node): ...@@ -2060,8 +2060,8 @@ def local_useless_elemwise(node):
const_val = T.extract_constant(node.inputs[0], only_process_constants=True) const_val = T.extract_constant(node.inputs[0], only_process_constants=True)
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return T.zeros_like(node.inputs[1], dtype=dtype, return [T.zeros_like(node.inputs[1], dtype=dtype,
opt=True) opt=True)]
else: else:
return [node.inputs[1]] return [node.inputs[1]]
...@@ -2069,8 +2069,8 @@ def local_useless_elemwise(node): ...@@ -2069,8 +2069,8 @@ def local_useless_elemwise(node):
const_val = T.extract_constant(node.inputs[1], only_process_constants=True) const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return T.zeros_like(node.inputs[0], dtype=dtype, return [T.zeros_like(node.inputs[0], dtype=dtype,
opt=True) opt=True)]
else: else:
return [node.inputs[0]] return [node.inputs[0]]
...@@ -2083,7 +2083,8 @@ def local_useless_elemwise(node): ...@@ -2083,7 +2083,8 @@ def local_useless_elemwise(node):
if const_val == 0: if const_val == 0:
return [node.inputs[1]] return [node.inputs[1]]
else: else:
return T.ones_like(node.inputs[1], dtype=dtype, opt=True) return [T.ones_like(node.inputs[1], dtype=dtype,
opt=True)]
if isinstance(node.inputs[1], T.TensorConstant): if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1], only_process_constants=True) const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
...@@ -2091,12 +2092,13 @@ def local_useless_elemwise(node): ...@@ -2091,12 +2092,13 @@ def local_useless_elemwise(node):
if const_val == 0: if const_val == 0:
return [node.inputs[0]] return [node.inputs[0]]
else: else:
return T.ones_like(node.inputs[0], dtype=dtype, opt=True) return [T.ones_like(node.inputs[0], dtype=dtype,
opt=True)]
elif (isinstance(node.op.scalar_op, scalar.XOR) and elif (isinstance(node.op.scalar_op, scalar.XOR) and
len(node.inputs) == 2): len(node.inputs) == 2):
if node.inputs[0] is node.inputs[1]: if node.inputs[0] is node.inputs[1]:
return T.zeros_like(node.inputs[0], dtype=dtype, opt=True) return [T.zeros_like(node.inputs[0], dtype=dtype, opt=True)]
@register_specialize @register_specialize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论