Commit 3668e8ba authored by James Bergstra

fixed bug in canonize related to output type-matching. added fill_chain function. fixed local_mul_specialize too
parent 0696f748
@@ -252,7 +252,7 @@ class TensorType(Type):
     def __hash__(self):
         """Hash equal for same kinds of TensorType"""
-        return hash(self.dtype) ^ hash(self.broadcastable)
+        return hash(type(self)) ^ hash(self.dtype) ^ hash(self.broadcastable)
     ndim = property(lambda self: len(self.broadcastable), doc = "number of dimensions")
     """Number of dimensions
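The `__hash__` change folds `hash(type(self))` into the hash so that instances of different `Type` subclasses no longer collide when they happen to share a dtype and broadcastable pattern. A minimal standalone sketch of the contract being preserved (the class below is a stripped-down stand-in for illustration, not Theano's actual `TensorType`):

    class TensorType(object):
        """Stripped-down stand-in illustrating the hash/eq contract."""
        def __init__(self, dtype, broadcastable):
            self.dtype = dtype
            self.broadcastable = tuple(broadcastable)

        def __eq__(self, other):
            # equality is type-aware, so the hash should be too
            return (type(self) == type(other)
                    and self.dtype == other.dtype
                    and self.broadcastable == other.broadcastable)

        def __hash__(self):
            # mixing in hash(type(self)) keeps subclass instances, which
            # compare unequal to base instances, from always sharing a bucket
            return hash(type(self)) ^ hash(self.dtype) ^ hash(self.broadcastable)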
@@ -34,6 +34,11 @@ def in2out(*local_opts, **kwargs):
                          failure_callback=TopoOptimizer.warn_inplace,
                          **kwargs)

+def _fill_chain(new_out, orig_inputs):
+    for i in orig_inputs:
+        new_out = T.fill(i, new_out)
+    return [new_out]
+
 @gof.optimizer
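What `_fill_chain` buys: when an optimization collapses a multi-input node down to something smaller, the replacement can lose the broadcastable pattern of the original output. Chaining `T.fill` over every original input broadcasts the replacement back up to that pattern. A NumPy sketch of the behaviour, where `np.zeros_like(i) + v` plays the role of `T.fill(i, v)` (illustrative only, not Theano code):

    import numpy as np

    def _fill_chain(new_out, orig_inputs):
        # successively broadcast new_out against each original input
        for i in orig_inputs:
            new_out = np.zeros_like(i) + new_out   # analogue of T.fill(i, new_out)
        return [new_out]

    # mul(x, 0.0) simplifies to the constant 0.0, but the node's output had
    # x's shape; the fill chain restores it.
    x = np.ones((3, 4))
    zero = np.asarray(0.0)
    assert _fill_chain(zero, [x, zero])[0].shape == (3, 4)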
@@ -692,9 +697,6 @@ class Canonizer(gof.LocalOptimizer):
         elif op == self.reciprocal:
             reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0

-        # just in case
-        assert len(node.outputs) == 1
-
         # Here we make the canonical version of the graph around this node
         # See the documentation of get_num_denum and simplify
         orig_num, orig_denum = self.get_num_denum(node.outputs[0])
@@ -715,17 +717,16 @@ class Canonizer(gof.LocalOptimizer):
                 elem_op = T.Elemwise(scalar.Identity(scalar.specific_out(getattr(scalar, out.type.dtype))))
                 new = elem_op(new)

-            if new.type != out.type:
-                for x in orig_num + orig_denum:
-                    if x.type == out.type:
-                        new = T.fill(x, new)
-                        break
+            assert (new.type == out.type) == (not (new.type != out.type))
+            if not (new.type == out.type):
+                new = _fill_chain(new, node.inputs)[0]

-            if new.type != out.type:
+            if new.type == out.type:
+                return [new]
+            else:
                 print >> sys.stderr, 'CANONIZE FAILED: new, out = ', new, ',', out, 'types', new.type, ',', out.type
                 return False
-            else:
-                return [new]

     def __str__(self):
         return getattr(self, 'name', 'Canonizer(%s, %s, %s)' % (self.main, self.inverse, self.reciprocal))
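The switch from `new.type != out.type` to `not (new.type == out.type)`, together with the new assert that the two operators agree, points at the output type-matching bug named in the commit message: under Python 2, a class that defines `__eq__` without `__ne__` gets a default `!=` that ignores `__eq__` entirely. A sketch of that pitfall (hypothetical class, assumed cause, not Theano's actual code):

    class BrokenType(object):
        def __init__(self, dtype):
            self.dtype = dtype
        def __eq__(self, other):
            return type(self) == type(other) and self.dtype == other.dtype
        # no __ne__: on Python 2, `!=` falls back to comparing identities

    a = BrokenType('float64')
    b = BrokenType('float64')
    print(a == b)          # True
    print(a != b)          # also True on Python 2: == and != disagree
    print(not (a == b))    # False: this spelling always tracks __eq__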
@@ -817,6 +818,8 @@ register_specialize(local_pow_specialize)
 @gof.local_optimizer([T.mul])
 def local_mul_specialize(node):
+    def fill_chain(v):
+        return _fill_chain(v, node.inputs)
     #here, we are past the point of canonicalization, so we don't want to put in un-necessary fills.
     if node.op == T.mul:
         #the idea here is that we have pow(x, y)
@@ -829,20 +832,20 @@ def local_mul_specialize(node):
             elif N.all(y == -1.0):
                 neg ^= True #toggles
             elif N.all(y == 0.0):
-                return [input]
+                return fill_chain(input)
             else:
                 new_inputs.append(input)
         if len(new_inputs) < len(node.inputs):
             if len(new_inputs) == 0:
                 newval = -y.flatten()[0] if neg else y.flatten()[0]
-                return [T.TensorConstant(T.TensorType(dtype=node.outputs[0].type.dtype,
-                    broadcastable = [True] * node.outputs[0].ndim), N.asarray(newval))]
+                return fill_chain(T.TensorConstant(T.TensorType(dtype=node.outputs[0].type.dtype,
+                    broadcastable = [True] * node.outputs[0].ndim), N.asarray(newval)))
             if len(new_inputs) == 1:
-                return [-new_inputs[0]] if neg else new_inputs
+                return fill_chain(-new_inputs[0] if neg else new_inputs[0])
             else:
-                return [-T.mul(*new_inputs)] if neg else \
-                       [T.mul(*new_inputs)]
+                return fill_chain(-T.mul(*new_inputs) if neg else \
+                        T.mul(*new_inputs))
     else:
         return False
 register_specialize(local_mul_specialize)
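The rewrites above drop 1.0 factors, fold -1.0 factors into a sign toggle, and short-circuit the whole product on a 0.0 factor, now routing every return through `fill_chain` so the result keeps the shape of the original `mul`. A pure-NumPy sketch of that constant-folding logic (illustrative only; the real optimizer builds symbolic graphs, and the zero branch stands in for what `fill_chain` does):

    import numpy as np
    from functools import reduce
    import operator

    def mul_specialize(inputs):
        neg = False
        remaining = []
        for x in inputs:
            if np.all(x == 1.0):
                continue                      # multiplying by 1 is a no-op
            elif np.all(x == -1.0):
                neg = not neg                 # fold the sign, drop the factor
            elif np.all(x == 0.0):
                # whole product is 0; keep the broadcast shape (fill_chain's job)
                return np.zeros(np.broadcast(*inputs).shape)
            else:
                remaining.append(x)
        out = reduce(operator.mul, remaining) if remaining else np.asarray(1.0)
        return -out if neg else out

    r = mul_specialize([np.ones(3), np.array([2.0, 3.0, 4.0]), np.asarray(-1.0)])
    assert np.allclose(r, [-2.0, -3.0, -4.0])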
@@ -985,7 +988,7 @@ def local_greedy_distributor(node):
     rval = local_mul_canonizer.merge_num_denum(new_num, new_denum)

-    if rval.type != out.type:
+    if not (rval.type == out.type):
         #WHY DOES THIS HAPPEN?
         return False