提交 5e51a8f2 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

move the extract_constant outside the canonical_form function so that I can

re-use it for other optimizations.
上级 553c304b
...@@ -2607,6 +2607,25 @@ def get_idx_list(inputs, idx_list): ...@@ -2607,6 +2607,25 @@ def get_idx_list(inputs, idx_list):
return cdata return cdata
def extract_constant(x):
'''
This function is basically a call to tensor.get_constant_value. The
main difference is the behaviour in case of failure. While
get_constant_value raises an TypeError, this function returns x,
as a tensor ( by removing the last scalar_from_tensor ) if needed
or None if that is the value of x.
'''
try:
x = get_constant_value(x)
except:
pass
if isinstance(x, scal.ScalarVariable):
if x.owner and isinstance(x.owner.op, ScalarFromTensor):
x = x.owner.inputs[0]
else:
x = tensor.tensor_from_scalar(x)
return x
def get_canonical_form_slice(theslice, length): def get_canonical_form_slice(theslice, length):
''' '''
...@@ -2618,24 +2637,6 @@ def get_canonical_form_slice(theslice, length): ...@@ -2618,24 +2637,6 @@ def get_canonical_form_slice(theslice, length):
resulting set of numbers needs to be reversed or not. resulting set of numbers needs to be reversed or not.
''' '''
def extract_constant(x):
'''
This function is basically a call to tensor.get_constant_value. The
main difference is the behaviour in case of failure. While
get_constant_value raises an TypeError, this function returns x,
as a tensor ( by removing the last scalar_from_tensor ) if needed
or None if that is the value of x.
'''
try:
x = get_constant_value(x)
except:
pass
if isinstance(x, scal.ScalarVariable):
if x.owner and isinstance(x.owner.op, ScalarFromTensor):
x = x.owner.inputs[0]
else:
x = tensor.tensor_from_scalar(x)
return x
if isinstance(theslice,slice): if isinstance(theslice,slice):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论