提交 3668e8ba authored 作者: James Bergstra's avatar James Bergstra

fixed bug in canonize related to output type-matching. added fill_chain…

fixed bug in canonize related to output type-matching. added fill_chain function. fixed local_mul_specialize too
上级 0696f748
...@@ -252,7 +252,7 @@ class TensorType(Type): ...@@ -252,7 +252,7 @@ class TensorType(Type):
def __hash__(self): def __hash__(self):
"""Hash equal for same kinds of TensorType""" """Hash equal for same kinds of TensorType"""
return hash(self.dtype) ^ hash(self.broadcastable) return hash(type(self)) ^ hash(self.dtype) ^ hash(self.broadcastable)
ndim = property(lambda self: len(self.broadcastable), doc = "number of dimensions") ndim = property(lambda self: len(self.broadcastable), doc = "number of dimensions")
"""Number of dimensions """Number of dimensions
......
...@@ -34,6 +34,11 @@ def in2out(*local_opts, **kwargs): ...@@ -34,6 +34,11 @@ def in2out(*local_opts, **kwargs):
failure_callback=TopoOptimizer.warn_inplace, failure_callback=TopoOptimizer.warn_inplace,
**kwargs) **kwargs)
def _fill_chain(new_out, orig_inputs):
for i in orig_inputs:
new_out = T.fill(i, new_out)
return [new_out]
@gof.optimizer @gof.optimizer
...@@ -692,9 +697,6 @@ class Canonizer(gof.LocalOptimizer): ...@@ -692,9 +697,6 @@ class Canonizer(gof.LocalOptimizer):
elif op == self.reciprocal: elif op == self.reciprocal:
reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0 reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0
# just in case
assert len(node.outputs) == 1
# Here we make the canonical version of the graph around this node # Here we make the canonical version of the graph around this node
# See the documentation of get_num_denum and simplify # See the documentation of get_num_denum and simplify
orig_num, orig_denum = self.get_num_denum(node.outputs[0]) orig_num, orig_denum = self.get_num_denum(node.outputs[0])
...@@ -715,17 +717,16 @@ class Canonizer(gof.LocalOptimizer): ...@@ -715,17 +717,16 @@ class Canonizer(gof.LocalOptimizer):
elem_op = T.Elemwise(scalar.Identity(scalar.specific_out(getattr(scalar, out.type.dtype)))) elem_op = T.Elemwise(scalar.Identity(scalar.specific_out(getattr(scalar, out.type.dtype))))
new = elem_op(new) new = elem_op(new)
if new.type != out.type: assert (new.type == out.type) == (not (new.type != out.type))
for x in orig_num + orig_denum:
if x.type == out.type: if not (new.type == out.type):
new = T.fill(x, new) new = _fill_chain(new, node.inputs)[0]
break
if new.type != out.type: if new.type == out.type:
return [new]
else:
print >> sys.stderr, 'CANONIZE FAILED: new, out = ', new, ',', out, 'types', new.type, ',', out.type print >> sys.stderr, 'CANONIZE FAILED: new, out = ', new, ',', out, 'types', new.type, ',', out.type
return False return False
else:
return [new]
def __str__(self): def __str__(self):
return getattr(self, 'name', 'Canonizer(%s, %s, %s)' % (self.main, self.inverse, self.reciprocal)) return getattr(self, 'name', 'Canonizer(%s, %s, %s)' % (self.main, self.inverse, self.reciprocal))
...@@ -817,6 +818,8 @@ register_specialize(local_pow_specialize) ...@@ -817,6 +818,8 @@ register_specialize(local_pow_specialize)
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_specialize(node): def local_mul_specialize(node):
def fill_chain(v):
return _fill_chain(v, node.inputs)
#here, we are past the point of canonicalization, so we don't want to put in un-necessary fills. #here, we are past the point of canonicalization, so we don't want to put in un-necessary fills.
if node.op == T.mul: if node.op == T.mul:
#the idea here is that we have pow(x, y) #the idea here is that we have pow(x, y)
...@@ -829,20 +832,20 @@ def local_mul_specialize(node): ...@@ -829,20 +832,20 @@ def local_mul_specialize(node):
elif N.all(y == -1.0): elif N.all(y == -1.0):
neg ^= True #toggles neg ^= True #toggles
elif N.all(y == 0.0): elif N.all(y == 0.0):
return [input] return fill_chain(input)
else: else:
new_inputs.append(input) new_inputs.append(input)
if len(new_inputs) < len(node.inputs): if len(new_inputs) < len(node.inputs):
if len(new_inputs) == 0: if len(new_inputs) == 0:
newval = -y.flatten()[0] if neg else y.flatten()[0] newval = -y.flatten()[0] if neg else y.flatten()[0]
return [T.TensorConstant(T.TensorType(dtype=node.outputs[0].type.dtype, return fill_chain(T.TensorConstant(T.TensorType(dtype=node.outputs[0].type.dtype,
broadcastable = [True] * node.outputs[0].ndim), N.asarray(newval))] broadcastable = [True] * node.outputs[0].ndim), N.asarray(newval)))
if len(new_inputs) == 1: if len(new_inputs) == 1:
return [-new_inputs[0]] if neg else new_inputs return fill_chain(-new_inputs[0] if neg else new_inputs[0])
else: else:
return [-T.mul(*new_inputs)] if neg else \ return fill_chain(-T.mul(*new_inputs) if neg else \
[T.mul(*new_inputs)] T.mul(*new_inputs))
else: else:
return False return False
register_specialize(local_mul_specialize) register_specialize(local_mul_specialize)
...@@ -985,7 +988,7 @@ def local_greedy_distributor(node): ...@@ -985,7 +988,7 @@ def local_greedy_distributor(node):
rval = local_mul_canonizer.merge_num_denum(new_num, new_denum) rval = local_mul_canonizer.merge_num_denum(new_num, new_denum)
if rval.type != out.type: if not (rval.type == out.type):
#WHY DOES THIS HAPPEN? #WHY DOES THIS HAPPEN?
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论