提交 24718ae9 authored 作者: James Bergstra's avatar James Bergstra

fixed error in canonizer

上级 a46caf01
...@@ -706,28 +706,13 @@ class Canonizer(gof.LocalOptimizer): ...@@ -706,28 +706,13 @@ class Canonizer(gof.LocalOptimizer):
return False return False
new = self.merge_num_denum(num, denum) new = self.merge_num_denum(num, denum)
if new.dtype != out.dtype: if new.type.dtype != out.type.dtype:
#new = T.fill(out, new) #new = T.fill(out, new)
elem_op = T.Elemwise(scalar.Identity(scalar.specific_out(getattr(scalar, out.type.dtype)))) elem_op = T.Elemwise(scalar.Identity(scalar.specific_out(getattr(scalar, out.type.dtype))))
new = T.fill(out, elem_op(new)) new = elem_op(new)
if new.broadcastable != out.broadcastable: if new.type.broadcastable != out.type.broadcastable:
#this case is tricky... we need to provide exactly the same kind of broadcastable new = T.fill(out, new)
#pattern, but only if legal...
dlen = len(new.broadcastable) - len(out.broadcastable)
if dlen > 0:
#try to take the leading ranks of new.broadcastable, which should be broadcastable
# ranks
#if this means skipping over nonbroadcastable ranks, then DimShuffle will fail
dimshuffle_op = T.DimShuffle(new.broadcastable,
range(dlen, len(new.broadcastable)))
new = dimshuffle_op(new)
elif dlen < 0:
#we have to boost up a scalar or something
dimshuffle_op = T.DimShuffle(new.broadcastable,
['x' for x in range(-dlen)] + range(0, len(new.broadcastable)))
new = dimshuffle_op(new)
# if our if's above worked, this should be true. OTW investigate. # if our if's above worked, this should be true. OTW investigate.
if new.type != out.type: if new.type != out.type:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论