提交 dc68f404 authored 作者: James Bergstra's avatar James Bergstra

correction to constant canonicalization

上级 dfc52a04
...@@ -636,8 +636,12 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -636,8 +636,12 @@ def local_upcast_elemwise_constant_inputs(node):
else: else:
try: try:
cval_i = get_constant_value(i) # works only for scalars I think cval_i = get_constant_value(i) # works only for scalars I think
new_inputs.append(T.cast(cval_i, output_dtype)) if 0==sum((not b for b in i.broadcastable)): # I mean all() but this might work in python2.4
except: new_inputs.append(T.cast(cval_i, output_dtype))
else:
new_inputs.append(T.alloc(T.cast(cval_i, output_dtype),
*[Shape_i(d)(i) for d in xrange(i.ndim)]))
except TypeError:
if isinstance(i, T.TensorConstant): #for the case of a non-scalar if isinstance(i, T.TensorConstant): #for the case of a non-scalar
new_inputs.append(T.cast(i, output_dtype)) new_inputs.append(T.cast(i, output_dtype))
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论