提交 7c7c4b8f authored 作者: Frederic's avatar Frederic

Don't put useless stuff in the graph

上级 02be7a2c
......@@ -73,8 +73,21 @@ def alpha_merge(cls, alpha_in, beta_in, nd):
if lr is None or targ is None:
return None
inputs = list(targ.inputs)
inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in]
try:
raise NotScalarConstantError()
c = get_scalar_constant_value(lr)
if c == 0:
inputs[alpha_in] = lr
inputs[beta_in] = lr
elif c == 1:
inputs[alpha_in] = targ.inputs[alpha_in]
inputs[beta_in] = targ.inputs[beta_in]
else:
inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in]
except NotScalarConstantError:
inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in]
return maker(targ, *inputs)
return opt
return wrapper
......
......@@ -75,8 +75,20 @@ def alpha_merge(cls, alpha_in, beta_in, nd):
if lr is None or targ is None:
return None
inputs = list(targ.inputs)
inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in]
try:
c = get_scalar_constant_value(lr)
if c == 0:
inputs[alpha_in] = lr
inputs[beta_in] = lr
elif c == 1:
inputs[alpha_in] = targ.inputs[alpha_in]
inputs[beta_in] = targ.inputs[beta_in]
else:
inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in]
except NotScalarConstantError:
inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in]
return maker(targ, *inputs)
return opt
return wrapper
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论