提交 594527d0 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

make get_constant_value go through "Alloc" nodes

上级 34f59af0
...@@ -63,7 +63,7 @@ def merge_broadcastables(broadcastables): ...@@ -63,7 +63,7 @@ def merge_broadcastables(broadcastables):
def get_constant_value(v): def get_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v` """return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, this function digs through them. If v is the output of dimshuffles, fills, allocs, this function digs through them.
If `v` is not some view of constant data, then raise a TypeError. If `v` is not some view of constant data, then raise a TypeError.
...@@ -936,37 +936,16 @@ class Canonizer(gof.LocalOptimizer): ...@@ -936,37 +936,16 @@ class Canonizer(gof.LocalOptimizer):
return self.inverse(self.merge_num_denum(num, []), return self.inverse(self.merge_num_denum(num, []),
self.merge_num_denum(denum, [])) self.merge_num_denum(denum, []))
@classmethod @staticmethod
def get_constant(cls, v): def get_constant(v):
""" """
Returns a numeric constant if v is a gof.Constant or, well, a Returns a numeric constant if v is a gof.Constant or, well, a
numeric constant. If v is a plain Variable, returns None. numeric constant. If v is a plain Variable, returns None.
""" """
if isinstance(v, N.generic): try:
return v # doesn't the not hasattr() condition below catch this? return get_constant_value(v)
if isinstance(v, gof.Constant): except TypeError:
return v.data return None
if not hasattr(v, 'owner'):
return v
# NOTE: the following code was buggy, but while I was fixing
# it I realized it is probably made useless by constant
# folding, so screw that. Commented-out code is the half-fixed
# version.
# if v.owner and isinstance(v.owner.op, DimShuffle):
# # see the comments in get_num_denum
# # TODO: this should apply the
# dsn = v.owner
# dsop = dsn.op
# dsi0 = dsn.inputs[0]
# compatible_order = ('x',) * (input.type.ndim - dsi0.type.ndim) + tuple(range(dsi0.type.ndim))
# if dsop.new_order == compatible_order:
# return cls.get_constant(v.owner.inputs[0])
return None
def simplify(self, num, denum): def simplify(self, num, denum):
""" """
...@@ -1415,28 +1394,17 @@ def local_add_specialize(node): ...@@ -1415,28 +1394,17 @@ def local_add_specialize(node):
def fill_chain(v): def fill_chain(v):
return _fill_chain(v, node.inputs) return _fill_chain(v, node.inputs)
def get_constant_through_fills_and_subtensors(v):
if v.owner is not None:
if v.owner.op == T.fill:
assert len(v.owner.inputs) == 2
return get_constant_through_fills_and_subtensors(v.owner.inputs[1])
if isinstance(v.owner.op, T.DimShuffle):
assert len(v.owner.inputs) == 1
return get_constant_through_fills_and_subtensors(v.owner.inputs[0])
elif hasattr(v, 'data'):
return v.data
else:
return v
#here, we are past the point of canonicalization, so we don't want to put in un-necessary fills. #here, we are past the point of canonicalization, so we don't want to put in un-necessary fills.
if node.op == T.add: if node.op == T.add:
new_inputs = [] new_inputs = []
for input in node.inputs: for input in node.inputs:
y = get_constant_through_fills_and_subtensors(input) try:
y = get_constant_value(input)
except TypeError:
y = input
if N.all(y == 0.0): if N.all(y == 0.0):
continue continue
else: new_inputs.append(input)
new_inputs.append(input)
if len(new_inputs) < len(node.inputs): if len(new_inputs) < len(node.inputs):
if len(new_inputs) == 0: if len(new_inputs) == 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论