提交 113211ed authored 作者: Frederic Bastien's avatar Frederic Bastien
......@@ -335,7 +335,10 @@ class Elemwise(Op):
def __setstate__(self, d):
self.__dict__.update(d)
self.ufunc = numpy.frompyfunc(self.scalar_op.impl, self.scalar_op.nin, self.scalar_op.nout)
if self.scalar_op.nin > 0:
self.ufunc = numpy.frompyfunc(self.scalar_op.impl, self.scalar_op.nin, self.scalar_op.nout)
else:
self.ufunc = None
def make_node(self, *inputs):
"""
......@@ -498,7 +501,13 @@ class Elemwise(Op):
# the first (faster) version leads to segfaults
ufunc_args = inputs # + output_storage
ufunc = self.ufunc or numpy.frompyfunc(self.scalar_op.impl, len(inputs), self.scalar_op.nout)
results = ufunc(*ufunc_args)
try:
results = ufunc(*ufunc_args)
except:
errormsg = 'Failed calling ufunc for op', self.scalar_op,\
'for params of shape', [arg.shape for arg in ufunc_args]
raise Exception, errormsg
if ufunc.nout == 1: results = [results]
for result, storage in zip(results, output_storage):
if storage[0].shape:
......
......@@ -407,7 +407,14 @@ class Canonizer(gof.LocalOptimizer):
def get_num_denum(self, input):
if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
if input.owner and isinstance(input.owner.op, T.DimShuffle):
return self.get_num_denum(input.owner.inputs[0])
dsn = input.owner
dsop = dsn.op
dsi0 = dsn.inputs[0]
compatible_order = ('x',) * (input.type.ndim - dsi0.type.ndim) + tuple(range(dsi0.type.ndim))
if dsop.new_order == compatible_order:
return self.get_num_denum(input.owner.inputs[0])
else:
return [input], []
else:
return [input], []
num = []
......@@ -507,6 +514,8 @@ class Canonizer(gof.LocalOptimizer):
elif op == self.reciprocal:
reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0
assert len(node.outputs) == 1
orig_num, orig_denum = self.get_num_denum(node.outputs[0])
num, denum = list(orig_num), list(orig_denum)
num, denum = self.simplify(num, denum)
......@@ -522,6 +531,7 @@ class Canonizer(gof.LocalOptimizer):
#new = T.fill(out, new)
elem_op = T.Elemwise(scalar.Identity(scalar.specific_out(getattr(scalar, out.type.dtype))))
new = T.fill(out, elem_op(new))
if new.broadcastable != out.broadcastable:
#this case is tricky... we need to provide exactly the same kind of broadcastable
#pattern, but only if legal...
......@@ -541,7 +551,10 @@ class Canonizer(gof.LocalOptimizer):
new = dimshuffle_op(new)
# if our if's above worked, this should be true. OTW investigate.
assert new.type == out.type
if new.type != out.type:
print >> sys.stderr, 'CANONIZE FAILED: new out = ', new, out
assert new.type == out.type
return [new]
def __str__(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论