提交 113211ed authored 作者: Frederic Bastien's avatar Frederic Bastien
...@@ -335,7 +335,10 @@ class Elemwise(Op): ...@@ -335,7 +335,10 @@ class Elemwise(Op):
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
self.ufunc = numpy.frompyfunc(self.scalar_op.impl, self.scalar_op.nin, self.scalar_op.nout) if self.scalar_op.nin > 0:
self.ufunc = numpy.frompyfunc(self.scalar_op.impl, self.scalar_op.nin, self.scalar_op.nout)
else:
self.ufunc = None
def make_node(self, *inputs): def make_node(self, *inputs):
""" """
...@@ -498,7 +501,13 @@ class Elemwise(Op): ...@@ -498,7 +501,13 @@ class Elemwise(Op):
# the first (faster) version leads to segfaults # the first (faster) version leads to segfaults
ufunc_args = inputs # + output_storage ufunc_args = inputs # + output_storage
ufunc = self.ufunc or numpy.frompyfunc(self.scalar_op.impl, len(inputs), self.scalar_op.nout) ufunc = self.ufunc or numpy.frompyfunc(self.scalar_op.impl, len(inputs), self.scalar_op.nout)
results = ufunc(*ufunc_args)
try:
results = ufunc(*ufunc_args)
except:
errormsg = 'Failed calling ufunc for op', self.scalar_op,\
'for params of shape', [arg.shape for arg in ufunc_args]
raise Exception, errormsg
if ufunc.nout == 1: results = [results] if ufunc.nout == 1: results = [results]
for result, storage in zip(results, output_storage): for result, storage in zip(results, output_storage):
if storage[0].shape: if storage[0].shape:
......
...@@ -407,7 +407,14 @@ class Canonizer(gof.LocalOptimizer): ...@@ -407,7 +407,14 @@ class Canonizer(gof.LocalOptimizer):
def get_num_denum(self, input): def get_num_denum(self, input):
if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]: if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
if input.owner and isinstance(input.owner.op, T.DimShuffle): if input.owner and isinstance(input.owner.op, T.DimShuffle):
return self.get_num_denum(input.owner.inputs[0]) dsn = input.owner
dsop = dsn.op
dsi0 = dsn.inputs[0]
compatible_order = ('x',) * (input.type.ndim - dsi0.type.ndim) + tuple(range(dsi0.type.ndim))
if dsop.new_order == compatible_order:
return self.get_num_denum(input.owner.inputs[0])
else:
return [input], []
else: else:
return [input], [] return [input], []
num = [] num = []
...@@ -507,6 +514,8 @@ class Canonizer(gof.LocalOptimizer): ...@@ -507,6 +514,8 @@ class Canonizer(gof.LocalOptimizer):
elif op == self.reciprocal: elif op == self.reciprocal:
reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0 reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0
assert len(node.outputs) == 1
orig_num, orig_denum = self.get_num_denum(node.outputs[0]) orig_num, orig_denum = self.get_num_denum(node.outputs[0])
num, denum = list(orig_num), list(orig_denum) num, denum = list(orig_num), list(orig_denum)
num, denum = self.simplify(num, denum) num, denum = self.simplify(num, denum)
...@@ -522,6 +531,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -522,6 +531,7 @@ class Canonizer(gof.LocalOptimizer):
#new = T.fill(out, new) #new = T.fill(out, new)
elem_op = T.Elemwise(scalar.Identity(scalar.specific_out(getattr(scalar, out.type.dtype)))) elem_op = T.Elemwise(scalar.Identity(scalar.specific_out(getattr(scalar, out.type.dtype))))
new = T.fill(out, elem_op(new)) new = T.fill(out, elem_op(new))
if new.broadcastable != out.broadcastable: if new.broadcastable != out.broadcastable:
#this case is tricky... we need to provide exactly the same kind of broadcastable #this case is tricky... we need to provide exactly the same kind of broadcastable
#pattern, but only if legal... #pattern, but only if legal...
...@@ -541,7 +551,10 @@ class Canonizer(gof.LocalOptimizer): ...@@ -541,7 +551,10 @@ class Canonizer(gof.LocalOptimizer):
new = dimshuffle_op(new) new = dimshuffle_op(new)
# if our if's above worked, this should be true. OTW investigate. # if our if's above worked, this should be true. OTW investigate.
assert new.type == out.type if new.type != out.type:
print >> sys.stderr, 'CANONIZE FAILED: new out = ', new, out
assert new.type == out.type
return [new] return [new]
def __str__(self): def __str__(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论