提交 e6629d8d authored 作者: James Bergstra's avatar James Bergstra

Optimizations: Added local_fill_useless and reverted a safeguard on the

Canonicalizer's get_num_denum function... this might break tests, but it was essential for getting rid of numerical instabilities in RBM free energy grad.
上级 60a89d0b
...@@ -44,34 +44,50 @@ def _fill_chain(new_out, orig_inputs): ...@@ -44,34 +44,50 @@ def _fill_chain(new_out, orig_inputs):
new_out = T.fill(i, new_out) new_out = T.fill(i, new_out)
return [new_out] return [new_out]
def get_constant_value(v, fill=False): def get_constant_value(v):
"""return the constant value underlying variable `v` """return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, this function digs through them. If v is the output of dimshuffles, fills, this function digs through them.
If `v` is not some view of constant data, then raise a TypeError. If `v` is not some view of constant data, then raise a TypeError.
if fill is True, then it returns (v, [...]) where the second term is a list of variables
that were used in the fill expressions
:note: There may be another function similar to this one in the code, but I'm not sure where it :note: There may be another function similar to this one in the code, but I'm not sure where it
is. is.
""" """
if isinstance(v, gof.Constant): if isinstance(v, gof.Constant):
if fill: #TODO: consider checking for arrays of the form e.g. [1,1,1,1] where
return v.data, [] # it is not a constant, but in some cases it *could* be replaced with one.
return v.data # Note that this would have an effect on the broadcasting of inputs and so on
try:
complex(v.data) #works for all numeric scalars
return v.data
except:
raise TypeError(v)
if v.owner and isinstance(v.owner.op, T.DimShuffle): if v.owner and isinstance(v.owner.op, T.DimShuffle):
return get_constant_value(v.owner.inputs[0], fill=fill) return get_constant_value(v.owner.inputs[0])
if fill: if v.owner and v.owner.op == T.fill:
if v.owner and v.owner.op == T.fill: shape, val = v.owner.inputs
shape, val = v.owner.inputs # fill(a,b) fills the shape of 'a' filled with 'b'
# fill(a,b) fills the shape of 'a' filled with 'b' rval, rshapes = get_constant_value(val)
rval, rshapes = get_constant_value(val, fill=fill) return rval, rshapes + [shape]
return rval, rshapes + [shape]
raise TypeError(v) raise TypeError(v)
def scalarconsts_rest(inputs):
    """Partition a list of variables into scalar constants and the rest.

    :param inputs: iterable of variables to classify.

    :returns: a triple ``(consts, origconsts, nonconsts)`` where
        ``consts`` holds the underlying constant scalar values,
        ``origconsts`` holds the variables they were extracted from
        (in matching order), and ``nonconsts`` holds the variables for
        which no constant value could be recovered.
    """
    consts = []
    origconsts = []
    nonconsts = []
    for i in inputs:
        try:
            v = get_constant_value(i)
        except TypeError:
            # get_constant_value raises TypeError when `i` is not a view
            # of constant scalar data (see its definition above); only
            # that exception marks `i` as non-constant.  The original
            # bare `except:` also swallowed KeyboardInterrupt/SystemExit.
            nonconsts.append(i)
        else:
            consts.append(v)
            origconsts.append(i)
    return consts, origconsts, nonconsts
@gof.optimizer @gof.optimizer
def insert_inplace_optimizer(env): def insert_inplace_optimizer(env):
""" """
...@@ -287,7 +303,24 @@ def local_fill_lift(node): ...@@ -287,7 +303,24 @@ def local_fill_lift(node):
return False return False
register_canonicalize(local_fill_lift, 'fill_lift') register_canonicalize(local_fill_lift, 'fill_lift')
register_specialize(local_fill_lift, 'fill_lift')
@register_specialize
@register_canonicalize
@gof.local_optimizer([T.fill])
def local_fill_useless(node):
    """fill(y, x) -> x

    Replacing the fill with its value argument is legal whenever the
    fill's output has exactly the same type as x: the shape argument y
    then adds no broadcasting information, so it contributes nothing.
    """
    if node.op != T.fill:
        return
    shape_arg, value_arg = node.inputs
    result, = node.outputs
    if result.type != value_arg.type:
        # the shape argument is being used to broadcast; keep the fill
        return
    return [value_arg]
################## ##################
# Subtensor opts # # Subtensor opts #
...@@ -581,13 +614,18 @@ class Canonizer(gof.LocalOptimizer): ...@@ -581,13 +614,18 @@ class Canonizer(gof.LocalOptimizer):
# the dtype of the 'input' argument. The leaf-Variables of the graph covered by the # the dtype of the 'input' argument. The leaf-Variables of the graph covered by the
# recursion may be of any Variable type. # recursion may be of any Variable type.
if len(input.clients) > 1: if 0:
# this logic is too conservative, but doing it is better than not doing it. # UPDATE: This logic makes it impossible to recognize some important patterns
# # (e.g. variants on the x/x)
# we don't want to canonize a subgraph that we will need to compute anyway for the other clients. # and it is screwing up the RBM free energy gradient.
# This check is too conservative because if the other clients are also in the subgraph we are canonizing, #TODO: review this
# then we should [probably?] recurse anyway. if len(input.clients) > 1:
return [input], [] # this logic is too conservative, but doing it is better than not doing it.
#
# we don't want to canonize a subgraph that we will need to compute anyway for the other clients.
# This check is too conservative because if the other clients are also in the subgraph we are canonizing,
# then we should [probably?] recurse anyway.
return [input], []
if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]: if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
if input.owner and isinstance(input.owner.op, T.DimShuffle): if input.owner and isinstance(input.owner.op, T.DimShuffle):
...@@ -835,8 +873,8 @@ class Canonizer(gof.LocalOptimizer): ...@@ -835,8 +873,8 @@ class Canonizer(gof.LocalOptimizer):
# Here we make the canonical version of the graph around this node # Here we make the canonical version of the graph around this node
# See the documentation of get_num_denum and simplify # See the documentation of get_num_denum and simplify
orig_num, orig_denum = self.get_num_denum(node.outputs[0]) orig_num, orig_denum = self.get_num_denum(node.outputs[0])
num, denum = list(orig_num), list(orig_denum) num, denum = self.simplify(list(orig_num), list(orig_denum))
num, denum = self.simplify(num, denum)
def same(x, y): def same(x, y):
return len(x) == len(y) and all(N.all(xe == ye) for xe, ye in zip(x, y)) return len(x) == len(y) and all(N.all(xe == ye) for xe, ye in zip(x, y))
...@@ -1150,23 +1188,16 @@ def local_log1p(node): ...@@ -1150,23 +1188,16 @@ def local_log1p(node):
if node.op == T.log: if node.op == T.log:
log_arg, = node.inputs log_arg, = node.inputs
if log_arg.owner and log_arg.owner.op == T.add: if log_arg.owner and log_arg.owner.op == T.add:
add_inputs = log_arg.owner.inputs scalars, scalar_inputs, nonconsts = \
consts = [0] scalarconsts_rest(log_arg.owner.inputs)
fills = [] # scalar_inputs are potentially dimshuffled and fill'd scalars
nonconsts = [] if scalars and numpy.allclose(numpy.sum(scalars), 1):
for add_in in add_inputs: if not nonconsts:
try: pass # leave for constant-merge
v, f = get_constant_value(add_in, fill=True) if len(nonconsts)==1:
consts.append(v) return _fill_chain(T.log1p(nonconsts[0]), scalar_inputs)
fills.extend(f) else:
except: return _fill_chain(T.log1p(T.add(*nonconsts)), scalar_inputs)
nonconsts.append(add_in)
if nonconsts:
if numpy.allclose(numpy.sum(consts), 1):
if len(nonconsts)==1:
return _fill_chain(T.log1p(nonconsts[0]), fills)
else:
return _fill_chain(T.log1p(T.add(*nonconsts)), fills)
def add_calculate(num, denum, aslist = False, out_type=None): def add_calculate(num, denum, aslist = False, out_type=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论