提交 963ca200 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix dtype problem when the pruned branch of switch was upcasting the other

上级 5005974e
......@@ -1455,14 +1455,22 @@ def local_remove_switch_const_cond(node):
if cond is constant and cond == 0: right
if cond is constant and cond != 0: left
"""
if ( isinstance(node.op, T.Elemwise) and
if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, scalar.basic.Switch)):
cond = T.extract_constant(node.inputs[0])
if type(cond) is numpy.ndarray and cond.ndim == 0:
if cond == 0:
return [node.inputs[2]]
out = node.inputs[2]
else:
return [node.inputs[1]]
out = node.inputs[1]
if out.ndim != node.outputs[0].ndim:
#TODO: broadcast?
return False
if out.dtype != node.outputs[0].dtype:
out = T.cast(out, node.outputs[0].dtype)
return [out]
return False
return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论