提交 825bf41f authored 作者: Frederic's avatar Frederic

Make a cache for TensorConstant.

This lowers memory usage and speeds up optimization by requiring fewer constant merges.
上级 d9bad63c
......@@ -400,10 +400,26 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
raise TypeError("Could not convert %s to TensorType" % x, type(x))
# Cache of frequently-used TensorConstant instances, keyed by the constant's
# signature.  Sharing one instance for common constants means the merge
# optimization has fewer duplicates to fold, which lowers memory usage and
# speeds up optimization of big graphs.
constant_cache = {}


def constant(x, name=None, ndim=None, dtype=None):
    """Return a ``TensorConstant`` for ``x``, reusing cached instances.

    All parameters are forwarded unchanged to ``constant_or_value`` with
    ``rtype=TensorConstant``.

    Only small scalars are cached, so the cache stays bounded in practice:
    the value must be a single element in ``[-10, 10]`` whose dtype is an
    [u]int dtype, or a float dtype holding an exactly integral value.  All
    broadcast patterns of such scalars are cached, since the same scalar
    commonly appears with several patterns in one graph.
    """
    ret = constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim,
                            dtype=dtype)
    sig = ret.signature()
    if (sig not in constant_cache and ret.data.size == 1 and
            -10 <= ret.data <= 10 and
            (ret.dtype in int_dtypes or ret.dtype in uint_dtypes or
             (ret.dtype in float_dtypes and int(ret.data) == ret.data))):
        constant_cache[sig] = ret
    return constant_cache.get(sig, ret)
def _obj_is_wrappable_as_tensor(x):
try:
......
......@@ -749,10 +749,8 @@ class Elemwise(Op):
# the gradient contains a constant, translate it as
# an equivalent TensorType of size 1 and proper number of
# dimensions
res = TensorConstant(TensorType(dtype=r.type.dtype,
broadcastable=()),
numpy.asarray(r.data)) # .reshape(b)
return DimShuffle((), ['x'] * nd, inplace=True)(res)
res = theano.tensor.constant(numpy.asarray(r.data), dtype=r.type.dtype)
return DimShuffle((), ['x'] * nd, inplace=False)(res)
new_r = Elemwise(node.op, {})(
*[transform(ipt) for ipt in node.inputs])
return new_r
......
......@@ -3711,17 +3711,16 @@ def local_add_specialize(node):
continue
new_inputs.append(input)
if len(new_inputs) < len(node.inputs):
dtype = node.outputs[0].type.dtype
if len(new_inputs) == 0:
#we got rid of the entire expression!
ndim = node.outputs[0].type.ndim
return fill_chain(
T.TensorConstant(
T.TensorType(
dtype=dtype,
broadcastable=[True] * ndim),
numpy.zeros((1,) * ndim, dtype=dtype)))
#Reuse call to constant for cache()
cst = T.constant(numpy.zeros((1,) * ndim, dtype=dtype))
assert cst.type.broadcastable == [True] * ndim
return fill_chain(cst)
if len(new_inputs) == 1:
ret = fill_chain(new_inputs[0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论