提交 594527d0 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

make get_constant_value go through "Alloc" nodes

上级 34f59af0
......@@ -63,7 +63,7 @@ def merge_broadcastables(broadcastables):
def get_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, this function digs through them.
If v is the output of dimshuffles, fills, allocs, this function digs through them.
If `v` is not some view of constant data, then raise a TypeError.
......@@ -936,36 +936,15 @@ class Canonizer(gof.LocalOptimizer):
return self.inverse(self.merge_num_denum(num, []),
self.merge_num_denum(denum, []))
@classmethod
def get_constant(cls, v):
@staticmethod
def get_constant(v):
"""
Returns a numeric constant if v is a gof.Constant or, well, a
numeric constant. If v is a plain Variable, returns None.
"""
if isinstance(v, N.generic):
return v # doesn't the not hasattr() condition below catch this?
if isinstance(v, gof.Constant):
return v.data
if not hasattr(v, 'owner'):
return v
# NOTE: the following code was buggy, but while I was fixing
# it I realized it is probably made useless by constant
# folding, so screw that. Commented-out code is the half-fixed
# version.
# if v.owner and isinstance(v.owner.op, DimShuffle):
# # see the comments in get_num_denum
# # TODO: this should apply the
# dsn = v.owner
# dsop = dsn.op
# dsi0 = dsn.inputs[0]
# compatible_order = ('x',) * (input.type.ndim - dsi0.type.ndim) + tuple(range(dsi0.type.ndim))
# if dsop.new_order == compatible_order:
# return cls.get_constant(v.owner.inputs[0])
try:
return get_constant_value(v)
except TypeError:
return None
def simplify(self, num, denum):
......@@ -1415,27 +1394,16 @@ def local_add_specialize(node):
def fill_chain(v):
return _fill_chain(v, node.inputs)
def get_constant_through_fills_and_subtensors(v):
if v.owner is not None:
if v.owner.op == T.fill:
assert len(v.owner.inputs) == 2
return get_constant_through_fills_and_subtensors(v.owner.inputs[1])
if isinstance(v.owner.op, T.DimShuffle):
assert len(v.owner.inputs) == 1
return get_constant_through_fills_and_subtensors(v.owner.inputs[0])
elif hasattr(v, 'data'):
return v.data
else:
return v
#here, we are past the point of canonicalization, so we don't want to put in un-necessary fills.
if node.op == T.add:
new_inputs = []
for input in node.inputs:
y = get_constant_through_fills_and_subtensors(input)
try:
y = get_constant_value(input)
except TypeError:
y = input
if N.all(y == 0.0):
continue
else:
new_inputs.append(input)
if len(new_inputs) < len(node.inputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论