提交 825bf41f authored 作者: Frederic's avatar Frederic

Make a cache for TensorConstant.

This lowers memory usage and speeds up optimization by requiring fewer merge comparisons.
上级 d9bad63c
...@@ -400,10 +400,26 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): ...@@ -400,10 +400,26 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
raise TypeError("Could not convert %s to TensorType" % x, type(x)) raise TypeError("Could not convert %s to TensorType" % x, type(x))
# Module-level cache of frequently used TensorConstants, keyed by the
# constant's signature.  Sharing one object per value means the Merge
# optimization has fewer duplicated constants to compare in big graphs,
# and lowers memory usage.
constant_cache = {}


def constant(x, name=None, ndim=None, dtype=None):
    """Return a TensorConstant wrapping `x`.

    Small scalar constants are cached and shared: repeated calls with an
    equal value return the very same TensorConstant instance.  Only
    scalars are cached, and only when the value is in [-10, 10] and the
    dtype is [u]int, or float with an integral value, so the cache stays
    bounded.  Every broadcastable pattern of such scalars is cached.

    :param x: value to wrap as a constant.
    :param name: optional name for the constant.
    :param ndim: optional required number of dimensions.
    :param dtype: optional required dtype.
    """
    ret = constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim,
                            dtype=dtype)
    sig = ret.signature()
    # Single-lookup EAFP instead of `in` followed by `.get`: return the
    # shared instance if one was already built for this signature.
    try:
        return constant_cache[sig]
    except KeyError:
        pass
    # Cache-worthiness test; kept identical in effect to the original:
    # scalar, value in [-10, 10], and an integer-like dtype or a float
    # holding an integral value.
    if (ret.data.size == 1 and
            -10 <= ret.data <= 10 and
            (ret.dtype in int_dtypes or ret.dtype in uint_dtypes or
             (ret.dtype in float_dtypes and int(ret.data) == ret.data))):
        constant_cache[sig] = ret
    return ret
def _obj_is_wrappable_as_tensor(x): def _obj_is_wrappable_as_tensor(x):
try: try:
......
...@@ -749,10 +749,8 @@ class Elemwise(Op): ...@@ -749,10 +749,8 @@ class Elemwise(Op):
# the gradient contains a constant, translate it as # the gradient contains a constant, translate it as
# an equivalent TensorType of size 1 and proper number of # an equivalent TensorType of size 1 and proper number of
# dimensions # dimensions
res = TensorConstant(TensorType(dtype=r.type.dtype, res = theano.tensor.constant(numpy.asarray(r.data), dtype=r.type.dtype)
broadcastable=()), return DimShuffle((), ['x'] * nd, inplace=False)(res)
numpy.asarray(r.data)) # .reshape(b)
return DimShuffle((), ['x'] * nd, inplace=True)(res)
new_r = Elemwise(node.op, {})( new_r = Elemwise(node.op, {})(
*[transform(ipt) for ipt in node.inputs]) *[transform(ipt) for ipt in node.inputs])
return new_r return new_r
......
...@@ -3711,17 +3711,16 @@ def local_add_specialize(node): ...@@ -3711,17 +3711,16 @@ def local_add_specialize(node):
continue continue
new_inputs.append(input) new_inputs.append(input)
if len(new_inputs) < len(node.inputs): if len(new_inputs) < len(node.inputs):
dtype = node.outputs[0].type.dtype dtype = node.outputs[0].type.dtype
if len(new_inputs) == 0: if len(new_inputs) == 0:
#we got rid of the entire expression! #we got rid of the entire expression!
ndim = node.outputs[0].type.ndim ndim = node.outputs[0].type.ndim
return fill_chain( #Reuse call to constant for cache()
T.TensorConstant( cst = T.constant(numpy.zeros((1,) * ndim, dtype=dtype))
T.TensorType( assert cst.type.broadcastable == [True] * ndim
dtype=dtype, return fill_chain(cst)
broadcastable=[True] * ndim),
numpy.zeros((1,) * ndim, dtype=dtype)))
if len(new_inputs) == 1: if len(new_inputs) == 1:
ret = fill_chain(new_inputs[0]) ret = fill_chain(new_inputs[0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论