提交 e6629d8d authored 作者: James Bergstra's avatar James Bergstra

Optimizations: Added local_fill_useless and reverted a safeguard on the

Canonizer's get_num_denum function. This might break tests, but it was essential for getting rid of numerical instabilities in the RBM free-energy gradient.
上级 60a89d0b
......@@ -44,34 +44,50 @@ def _fill_chain(new_out, orig_inputs):
new_out = T.fill(i, new_out)
return [new_out]
def get_constant_value(v, fill=False):
    """Return the constant scalar (0-D) value underlying variable `v`.

    If `v` is the output of dimshuffles or fills, this function digs through
    them to reach the underlying constant.

    :param v: the variable to inspect.
    :param fill: if True, return ``(value, shapes)`` where ``shapes`` is the
        list of variables that were used as the shape argument of the fill
        expressions that were dug through.  If False (the default), return
        only the value.

    :raises TypeError: if `v` is not some view of constant numeric scalar
        data.

    :note: There may be another function similar to this one in the code,
        but I'm not sure where it is.
    """
    if isinstance(v, gof.Constant):
        # TODO: consider checking for arrays of the form e.g. [1,1,1,1] where
        # it is not a constant, but in some cases it *could* be replaced with
        # one.  Note that this would have an effect on the broadcasting of
        # inputs and so on.
        try:
            complex(v.data)  # works for all numeric scalars
        except Exception:
            # Non-numeric or non-scalar data is reported through the
            # documented TypeError contract.
            raise TypeError(v)
        if fill:
            return v.data, []
        return v.data
    if v.owner and isinstance(v.owner.op, T.DimShuffle):
        return get_constant_value(v.owner.inputs[0], fill=fill)
    if v.owner and v.owner.op == T.fill:
        # fill(a, b) has the shape of 'a' and is filled with 'b'
        shape, val = v.owner.inputs
        if fill:
            rval, rshapes = get_constant_value(val, fill=True)
            return rval, rshapes + [shape]
        # Without `fill` the recursive call returns a bare value, so it must
        # not be unpacked into a (value, shapes) pair.
        return get_constant_value(val)
    raise TypeError(v)
def scalarconsts_rest(inputs):
    """Partition a list of variables into scalar constants and the rest.

    :param inputs: iterable of variables to classify.

    :returns: a triple ``(consts, origconsts, nonconsts)`` where ``consts``
        holds the values returned by `get_constant_value`, ``origconsts``
        holds the corresponding input variables (same order as ``consts``),
        and ``nonconsts`` holds the inputs for which `get_constant_value`
        raised a TypeError.
    """
    consts = []
    origconsts = []
    nonconsts = []
    for i in inputs:
        try:
            v = get_constant_value(i)
        except TypeError:
            # get_constant_value signals "not a constant" with TypeError;
            # any other exception is a real error and is allowed to
            # propagate instead of being silently swallowed.
            nonconsts.append(i)
        else:
            consts.append(v)
            origconsts.append(i)
    return consts, origconsts, nonconsts
@gof.optimizer
def insert_inplace_optimizer(env):
"""
......@@ -287,7 +303,24 @@ def local_fill_lift(node):
return False
register_canonicalize(local_fill_lift, 'fill_lift')
register_specialize(local_fill_lift, 'fill_lift')
@register_specialize
@register_canonicalize
@gof.local_optimizer([T.fill])
def local_fill_useless(node):
    """Rewrite fill(y, x) -> x when the fill does not change x's type.

    If the output of the fill has the same type as the value being filled
    in, the shape-giving argument contributes nothing (it is not being used
    to broadcast), so the fill can be replaced by the value itself.

    Returns a one-element list holding the replacement variable when the
    rewrite applies, and None otherwise.
    """
    if not (node.op == T.fill):
        return None
    # Only the filled-in value matters; the shape argument is ignored.
    _shape_like, filler = node.inputs
    (result,) = node.outputs
    if not (result.type == filler.type):
        return None
    return [filler]
##################
# Subtensor opts #
......@@ -581,13 +614,18 @@ class Canonizer(gof.LocalOptimizer):
# the dtype of the 'input' argument. The leaf-Variables of the graph covered by the
# recursion may be of any Variable type.
if len(input.clients) > 1:
# this logic is too conservative, but doing it is better than not doing it.
#
# we don't want to canonize a subgraph that we will need to compute anyway for the other clients.
# This check is too conservative because if the other clients are also in the subgraph we are canonizing,
# then we should [probably?] recurse anyway.
return [input], []
if 0:
# UPDATE: This logic makes it impossible to recognize some important patterns
# (e.g. variants on the x/x)
# and it is screwing up the RBM free energy gradient.
#TODO: review this
if len(input.clients) > 1:
# this logic is too conservative, but doing it is better than not doing it.
#
# we don't want to canonize a subgraph that we will need to compute anyway for the other clients.
# This check is too conservative because if the other clients are also in the subgraph we are canonizing,
# then we should [probably?] recurse anyway.
return [input], []
if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
if input.owner and isinstance(input.owner.op, T.DimShuffle):
......@@ -835,8 +873,8 @@ class Canonizer(gof.LocalOptimizer):
# Here we make the canonical version of the graph around this node
# See the documentation of get_num_denum and simplify
orig_num, orig_denum = self.get_num_denum(node.outputs[0])
num, denum = list(orig_num), list(orig_denum)
num, denum = self.simplify(num, denum)
num, denum = self.simplify(list(orig_num), list(orig_denum))
def same(x, y):
return len(x) == len(y) and all(N.all(xe == ye) for xe, ye in zip(x, y))
......@@ -1150,23 +1188,16 @@ def local_log1p(node):
if node.op == T.log:
log_arg, = node.inputs
if log_arg.owner and log_arg.owner.op == T.add:
add_inputs = log_arg.owner.inputs
consts = [0]
fills = []
nonconsts = []
for add_in in add_inputs:
try:
v, f = get_constant_value(add_in, fill=True)
consts.append(v)
fills.extend(f)
except:
nonconsts.append(add_in)
if nonconsts:
if numpy.allclose(numpy.sum(consts), 1):
if len(nonconsts)==1:
return _fill_chain(T.log1p(nonconsts[0]), fills)
else:
return _fill_chain(T.log1p(T.add(*nonconsts)), fills)
scalars, scalar_inputs, nonconsts = \
scalarconsts_rest(log_arg.owner.inputs)
# scalar_inputs are potentially dimshuffled and fill'd scalars
if scalars and numpy.allclose(numpy.sum(scalars), 1):
if not nonconsts:
pass # leave for constant-merge
if len(nonconsts)==1:
return _fill_chain(T.log1p(nonconsts[0]), scalar_inputs)
else:
return _fill_chain(T.log1p(T.add(*nonconsts)), scalar_inputs)
def add_calculate(num, denum, aslist = False, out_type=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论