提交 2269b2ec authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Speedup supports c_code

Not using `__call__` avoids the test_value computation
上级 0b2cf52f
...@@ -1332,32 +1332,26 @@ class ScalarOp(COp): ...@@ -1332,32 +1332,26 @@ class ScalarOp(COp):
the given Elemwise inputs, outputs. the given Elemwise inputs, outputs.
""" """
try:
tmp_s_input = [] tmp_s_input = []
# To keep the same aliasing between inputs # To keep the same aliasing between inputs
mapping = dict() mapping = {}
for ii in inputs: for ii in inputs:
if ii in mapping: if ii in mapping:
tmp_s_input.append(mapping[ii]) tmp_s_input.append(mapping[ii])
else: else:
tmp = get_scalar_type(ii.dtype).make_variable() tmp = mapping[ii] = get_scalar_type(ii.dtype).make_variable()
tmp_s_input.append(tmp) tmp_s_input.append(tmp)
mapping[ii] = tmp_s_input[-1]
with config.change_flags(compute_test_value="ignore"): try:
s_op = self(*tmp_s_input, return_list=True)
# if the scalar_op don't have a c implementation,
# we skip its fusion to allow the fusion of the
# other ops.
self.c_code( self.c_code(
s_op[0].owner, self.make_node(*tmp_s_input),
"test_presence_of_c_code", "test_presence_of_c_code",
# FIXME: Shouldn't this be a unique name per unique variable?
["x" for x in inputs], ["x" for x in inputs],
["z" for z in outputs], ["z" for z in outputs],
{"fail": "%(fail)s"}, {"fail": "%(fail)s"},
) )
except (MethodNotDefined, NotImplementedError): except (NotImplementedError, MethodNotDefined):
return False return False
return True return True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论