提交 71b85fd6 authored 作者: Frederic Bastien's avatar Frederic Bastien

Rework tensor.constant() to don't create TensorConstant that won't be used. This…

Rework tensor.constant() to don't create TensorConstant that won't be used. This will help investigate where new constant are used. Small speed up. Remove useless fct.
上级 d00947b5
...@@ -361,7 +361,7 @@ class TestAutoName: ...@@ -361,7 +361,7 @@ class TestAutoName:
r3 = tensor.constant(1.6) r3 = tensor.constant(1.6)
# The cache still create a new object that we don't return. # The cache still create a new object that we don't return.
# This is why we must increase by 2 and not 1. # This is why we must increase by 2 and not 1.
assert r3.auto_name == "auto_" + str(autoname_id + 2) assert r3.auto_name == "auto_" + str(autoname_id + 1)
def test_tensorvariable(self): def test_tensorvariable(self):
# Get counter value # Get counter value
......
...@@ -19,7 +19,7 @@ from theano.gof.type import Generic ...@@ -19,7 +19,7 @@ from theano.gof.type import Generic
from theano.tensor import elemwise from theano.tensor import elemwise
from theano.tensor.var import (AsTensorError, TensorVariable, from theano.tensor.var import (AsTensorError, TensorVariable,
TensorConstant, TensorConstant, TensorConstantSignature,
_tensor_py_operators) _tensor_py_operators)
from theano.tensor.type import TensorType, values_eq_approx_always_true from theano.tensor.type import TensorType, values_eq_approx_always_true
from theano.tensor.type_other import NoneConst from theano.tensor.type_other import NoneConst
...@@ -220,7 +220,7 @@ _as_tensor_variable = as_tensor_variable ...@@ -220,7 +220,7 @@ _as_tensor_variable = as_tensor_variable
as_tensor = as_tensor_variable as_tensor = as_tensor_variable
def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): def constant(x, name=None, ndim=None, dtype=None):
"""Return a symbolic `Constant` with value `x`. """Return a symbolic `Constant` with value `x`.
Raises Raises
...@@ -230,6 +230,16 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): ...@@ -230,6 +230,16 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
ValueError ValueError
`x` could not be expanded to have ndim dimensions. `x` could not be expanded to have ndim dimensions.
Note
----
We create a small cache of frequently used constant.
This speed up the Merge optimization for big graph.
We want to cache all scalar to don't merge as frequently constants.
But we don't want to cache too much stuff.
So we cache integer with dtype [u]int and float where the value is
between -10 and 10.
We cache all broadcast pattern for scalar.
""" """
x_ = scal.convert(x, dtype=dtype) x_ = scal.convert(x, dtype=dtype)
...@@ -245,41 +255,29 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): ...@@ -245,41 +255,29 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
assert len(bcastable) == ndim assert len(bcastable) == ndim
try: try:
if rtype is TensorConstant: type = TensorType(dtype=x_.dtype, broadcastable=bcastable)
x_ = x_.copy() if not constant.enable:
rval = rtype( return TensorConstant(type, x_, name=name)
TensorType(dtype=x_.dtype, broadcastable=bcastable),
x_, name=name)
return rval
except Exception:
raise TypeError("Could not convert %s to TensorType" % x, type(x))
sig = TensorConstantSignature((type, x_))
if sig in constant_cache:
return constant_cache[sig]
def constant(x, name=None, ndim=None, dtype=None): ret = TensorConstant(type, x_, name=name)
ret = constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim, if (x_.size == 1 and
dtype=dtype) (-10) <= x_ <= 10 and
(x_.dtype in int_dtypes or x_.dtype in uint_dtypes or
# We create a small cache of frequently used constant. (x_.dtype in float_dtypes and
# This speed up the Merge optimization for big graph.
# We want to cache all scalar to don't merge as frequently constants.
# But we don't want to cache too much stuff
# So we cache integer with dtype [u]int and float where the value is
# between -10 and 10
# We want to cache all broadcast pattern for scalar.
if not constant.enable:
return ret
sig = ret.signature()
if (sig not in constant_cache and ret.data.size == 1 and
(-10) <= ret.data <= 10 and
(ret.dtype in int_dtypes or ret.dtype in uint_dtypes or
(ret.dtype in float_dtypes and
# Limit the size of the cache. # Limit the size of the cache.
len(constant_cache) < 10000))): len(constant_cache) < 10000))):
constant_cache[sig] = ret constant_cache[sig] = ret
# This is needed to raise a good error to the user. # This is needed to raise a good error to the user.
ret.cached = True ret.cached = True
return ret
except Exception:
raise TypeError("Could not convert %s to TensorType" % x, type(x))
return constant_cache.get(sig, ret)
constant.enable = True constant.enable = True
constant_cache = {} constant_cache = {}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论