提交 64590a18 authored 作者: Frederic's avatar Frederic

Re-enable some call to get_scalar_constant_value, but disable the tracking of elemwise.

上级 64d621b4
......@@ -532,7 +532,7 @@ get_scalar_constant_value_elemwises = (
scal.LT, scal.GT, scal.LE, scal.GE,
scal.Sub, scal.Add, scal.Mod, scal.Mul,
scal.IntDiv, scal.TrueDiv)
def get_scalar_constant_value(orig_v):
def get_scalar_constant_value(orig_v, elemwise=True):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
......@@ -541,6 +541,9 @@ def get_scalar_constant_value(orig_v):
If `v` is not some view of constant scalar data, then raise a
NotScalarConstantError.
:param elemwise: If False, we won't try to go into elemwise.
So this call is faster.
:note: There may be another function similar to this one in the
code, but I'm not sure where it is.
"""
......@@ -590,7 +593,7 @@ def get_scalar_constant_value(orig_v):
ret = [[None]]
v.owner.op.perform(v.owner, const, ret)
return ret[0][0]
elif isinstance(v.owner.op, Elemwise):
elif elemwise and isinstance(v.owner.op, Elemwise):
if isinstance(v.owner.op.scalar_op, scal.Second):
# We don't need both input to be constant for second
shape, val = v.owner.inputs
......@@ -3079,7 +3082,7 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
##########################
def extract_constant(x):
def extract_constant(x, elemwise=True):
'''
This function is basically a call to tensor.get_scalar_constant_value. The
main difference is the behaviour in case of failure. While
......@@ -3089,7 +3092,7 @@ def extract_constant(x):
ScalarVariable, we convert it to a tensor with tensor_from_scalar.
'''
try:
x = get_scalar_constant_value(x)
x = get_scalar_constant_value(x, elemwise=elemwise)
except NotScalarConstantError:
pass
if (isinstance(x, scal.ScalarVariable) or
......
......@@ -1580,10 +1580,8 @@ def local_upcast_elemwise_constant_inputs(node):
new_inputs.append(i)
else:
try:
if not isinstance(i, Constant):
raise NotScalarConstantError()
# works only for scalars
cval_i = get_scalar_constant_value(i)
cval_i = get_scalar_constant_value(i, elemwise=False)
if all(i.broadcastable):
new_inputs.append(T.shape_padleft(
T.cast(cval_i, output_dtype),
......@@ -2328,9 +2326,8 @@ def local_remove_switch_const_cond(node):
if cond is constant and cond != 0: left
"""
if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, scalar.basic.Switch) and
isinstance(node.inputs[0], Constant)):
cond = T.extract_constant(node.inputs[0])
isinstance(node.op.scalar_op, scalar.basic.Switch)):
cond = T.extract_constant(node.inputs[0], elemwise=False)
if type(cond) is numpy.ndarray and cond.ndim == 0:
if cond == 0:
out = node.inputs[2]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论