提交 73598c96 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Add a flag to Constants if all entries equal a value.

This flag is going to be used by optimizations. For example dot(zeros, X) should be optimized out even if zeros is a constant.
上级 ba9ffef6
...@@ -1419,7 +1419,20 @@ class TensorConstant(_tensor_py_operators, Constant): ...@@ -1419,7 +1419,20 @@ class TensorConstant(_tensor_py_operators, Constant):
To create a TensorConstant, use the `constant` function in this module. To create a TensorConstant, use the `constant` function in this module.
""" """
def __init__(self, type, data, name = None):
Constant.__init__(self, type, data, name)
if (isinstance(data, numpy.ndarray) and
data.ndim > 0 and
len(numpy.unique(data)) == 1):
self.tag.unique_value = numpy.unique(data)[0]
else:
self.tag.unique_value = None
def __str__(self): def __str__(self):
if self.tag.unique_value is not None:
name = "%s of %s"%(str(self.data.shape),
str(self.tag.unique_value))
else:
name = "%s"%self.data name = "%s"%self.data
if len(name) > 20: if len(name) > 20:
name = name[:10]+".."+name[-10:] name = name[:10]+".."+name[-10:]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论