提交 963ca200 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix dtype problem when the pruned branch of switch was upcasting the other

上级 5005974e
...@@ -1455,14 +1455,22 @@ def local_remove_switch_const_cond(node): ...@@ -1455,14 +1455,22 @@ def local_remove_switch_const_cond(node):
if cond is constant and cond == 0: right if cond is constant and cond == 0: right
if cond is constant and cond != 0: left if cond is constant and cond != 0: left
""" """
if ( isinstance(node.op, T.Elemwise) and if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, scalar.basic.Switch)): isinstance(node.op.scalar_op, scalar.basic.Switch)):
cond = T.extract_constant(node.inputs[0]) cond = T.extract_constant(node.inputs[0])
if type(cond) is numpy.ndarray and cond.ndim == 0: if type(cond) is numpy.ndarray and cond.ndim == 0:
if cond == 0: if cond == 0:
return [node.inputs[2]] out = node.inputs[2]
else: else:
return [node.inputs[1]] out = node.inputs[1]
if out.ndim != node.outputs[0].ndim:
#TODO: broadcast?
return False
if out.dtype != node.outputs[0].dtype:
out = T.cast(out, node.outputs[0].dtype)
return [out]
return False return False
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论