提交 f2c76070 authored 作者: Caglar's avatar Caglar

Added the fred's suggestions.

上级 e22c6ad6
......@@ -134,7 +134,8 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
nonconsts = []
for i in inputs:
try:
v = get_scalar_constant_value(i, only_process_constants=only_process_constants)
v = get_scalar_constant_value(i, elemwise=elemwise,
only_process_constants=only_process_constants)
consts.append(v)
origconsts.append(i)
except NotScalarConstantError:
......@@ -2319,7 +2320,7 @@ def local_upcast_elemwise_constant_inputs(node):
else:
try:
# works only for scalars
cval_i = get_scalar_constant_value(i, elemwise=False,
cval_i = get_scalar_constant_value(i,
only_process_constants=True)
if all(i.broadcastable):
new_inputs.append(T.shape_padleft(
......@@ -3733,7 +3734,7 @@ def local_useless_switch(node):
"""
if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, scalar.basic.Switch)):
cond = T.extract_constant(node.inputs[0], elemwise=False,
cond = T.extract_constant(node.inputs[0],
only_process_constants=True)
if ((type(cond) is numpy.ndarray and cond.ndim == 0) or
isinstance(cond, numpy.number)):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论