提交 90c45509 作者: Pascal Lamblin

Fix local_mul_canonizer to prevent a cycle during Canonization

- When extracting num and denum constants, wrap them in T.Constant objects, so that they keep their dtype and cannot be downcast
- Upcast all elements in num and denum to get the type of the neutral element, if it is not specified
- Cosmetic changes
上级 743f01e7
......@@ -1312,10 +1312,9 @@ class Canonizer(gof.LocalOptimizer):
ln = [self.calculate([], [], aslist = False)]
if not ld:
if ln == 1:
if isinstance(num[0], gof.Variable):
return num[0]
else:
return T.as_tensor_variable(num[0])
# num[0] should always be a variable
assert isinstance(num[0], gof.Variable)
return num[0]
else:
return self.main(*num)
return self.inverse(self.merge_num_denum(num, []),
......@@ -1406,9 +1405,12 @@ class Canonizer(gof.LocalOptimizer):
# we can't allow ct == []
# TODO: why is this branch needed when merge_num_denum does it for us?
ct = [self.calculate(numct, denumct, aslist = False, out_type=out_type)]
# TODO: why are we not wrapping ct in a Constant right now?
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct and N.all(ct == self.get_constant(orig_num[0])):
# Wrapping ct in a Constant with the right dtype
ct = [T.constant(c, dtype=out_type.dtype) for c in ct]
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct and\
N.all([c.data for c in ct] == self.get_constant(orig_num[0])):
# this is an important trick :( if it so happens that:
# * there's exactly one constant on the numerator and none on the denominator
# * it's not the neutral element (ct is an empty list in that case)
......@@ -1424,7 +1426,7 @@ class Canonizer(gof.LocalOptimizer):
op = node.op
if op not in [self.main, self.inverse, self.reciprocal]:
return False
inputs = node.inputs
out = node.outputs[0]
assert len(node.outputs) == 1
......@@ -1440,7 +1442,7 @@ class Canonizer(gof.LocalOptimizer):
if c=='output': continue
if _bypass_dimshuffle(c).op in [self.main, self.inverse, self.reciprocal]:
return False
# Here we make the canonical version of the graph around this node
# See the documentation of get_num_denum and simplify
orig_num, orig_denum = self.get_num_denum(node.outputs[0])
......@@ -1480,22 +1482,17 @@ def mul_calculate(num, denum, aslist=False, out_type=None):
if not num and not denum:
# Smallest 1 possible.
if aslist:
return []
return []
else:
return N.int8(1)
return N.int8(1)
#return [] if aslist else N.int8(1)
# Make sure we do not accidentally upcast data types.
if out_type is None:
# TODO: remove this error-causing heuristic
if num:
first = num[0]
else:
first = denum[0]
#first = num[0] if num else denum[0]
one = N.asarray(first).dtype.type(1)
out_dtype = scalar.upcast(*[v.dtype for v in (num+denum)])
else:
one = theano._asarray(1, dtype=out_type.dtype)
out_dtype = out_type.dtype
one = theano._asarray(1, dtype=out_dtype)
v = reduce(N.multiply, num, one) / reduce(N.multiply, denum, one)
if aslist:
if N.all(v == 1):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论