提交 7c7c4b8f authored 作者: Frederic's avatar Frederic

Don't put useless stuff in the graph

上级 02be7a2c
...@@ -73,8 +73,21 @@ def alpha_merge(cls, alpha_in, beta_in, nd): ...@@ -73,8 +73,21 @@ def alpha_merge(cls, alpha_in, beta_in, nd):
if lr is None or targ is None: if lr is None or targ is None:
return None return None
inputs = list(targ.inputs) inputs = list(targ.inputs)
inputs[alpha_in] = lr * targ.inputs[alpha_in] try:
inputs[beta_in] = lr * targ.inputs[beta_in] raise NotScalarConstantError()
c = get_scalar_constant_value(lr)
if c == 0:
inputs[alpha_in] = lr
inputs[beta_in] = lr
elif c == 1:
inputs[alpha_in] = targ.inputs[alpha_in]
inputs[beta_in] = targ.inputs[beta_in]
else:
inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in]
except NotScalarConstantError:
inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in]
return maker(targ, *inputs) return maker(targ, *inputs)
return opt return opt
return wrapper return wrapper
......
...@@ -75,8 +75,20 @@ def alpha_merge(cls, alpha_in, beta_in, nd): ...@@ -75,8 +75,20 @@ def alpha_merge(cls, alpha_in, beta_in, nd):
if lr is None or targ is None: if lr is None or targ is None:
return None return None
inputs = list(targ.inputs) inputs = list(targ.inputs)
inputs[alpha_in] = lr * targ.inputs[alpha_in] try:
inputs[beta_in] = lr * targ.inputs[beta_in] c = get_scalar_constant_value(lr)
if c == 0:
inputs[alpha_in] = lr
inputs[beta_in] = lr
elif c == 1:
inputs[alpha_in] = targ.inputs[alpha_in]
inputs[beta_in] = targ.inputs[beta_in]
else:
inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in]
except NotScalarConstantError:
inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in]
return maker(targ, *inputs) return maker(targ, *inputs)
return opt return opt
return wrapper return wrapper
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论