提交 32105241 作者: Frédéric Bastien

Merge pull request #3432 from nouiz/crash

crash fix and cgemv optimization
......@@ -843,19 +843,13 @@ class Function(object):
# this is a new vm-provided function or c linker
# they need this because the exception manipulation
# done by raise_with_op is not implemented in C.
thunk = None
if hasattr(self.fn, 'thunks'):
# For the CVM
gof.link.raise_with_op(
self.fn.nodes[self.fn.position_of_error],
self.fn.thunks[self.fn.position_of_error],
storage_map=self.fn.storage_map)
else:
# For the C linker, we don't have access from Python
# to all the temporary values, so for now we just
# don't print the extra shapes/strides info.
gof.link.raise_with_op(
self.fn.nodes[self.fn.position_of_error],
storage_map=self.fn.storage_map)
thunk = self.fn.thunks[self.fn.position_of_error]
gof.link.raise_with_op(
node=self.fn.nodes[self.fn.position_of_error],
thunk=thunk,
storage_map=getattr(self.fn, 'storage_map', None))
else:
# old-style linkers raise their own exceptions
raise
......
......@@ -280,8 +280,8 @@ class SeqOptimizer(Optimizer, list):
print((" time %.3fs for %d/%d nodes"
" before/after optimization" % (
sum(prof), nb_node_before, nb_node_after)), file=stream)
print(blanc, " %.3fs for fgraph.validate()" % (validate_time), file=stream)
print(blanc, " %.3fs for callback" % (callback_time), file=stream)
print(blanc, " %.3fs for fgraph.validate()" % (validate_time), file=stream)
if level == 0:
print(blanc, " time - (name, class, index) - validate time", file=stream)
ll = []
......
......@@ -3105,7 +3105,9 @@ def std(input, axis=None, keepdims=False):
"""
return sqrt(var(input=input, axis=axis, keepdims=keepdims))
ret = sqrt(var(input=input, axis=axis, keepdims=keepdims))
ret.name = 'std'
return ret
class Default(gof.Op):
......
......@@ -799,7 +799,14 @@ def use_c_gemv(node):
@local_optimizer([CGemv(inplace=False)])
def make_c_gemv_destructive(node):
    """Replace a non-inplace ``CGemv`` node by its in-place counterpart.

    Parameters
    ----------
    node
        The apply node being considered by the optimizer.

    Returns
    -------
    list or None
        A one-element list holding the output of ``cgemv_inplace`` applied
        to the (possibly adjusted) inputs when ``node.op`` is a ``CGemv``
        that is not already in-place; ``None`` (implicitly) otherwise.
    """
    if isinstance(node.op, CGemv) and not node.op.inplace:
        inputs = list(node.inputs)
        dest = inputs[0]
        # When the destination comes from an AllocEmpty that has more than
        # one client, build a fresh AllocEmpty with the same dtype and
        # inputs for this node -- presumably so that the destructive gemv
        # does not write into a buffer other nodes also use.
        if (dest.owner and
                isinstance(dest.owner.op, T.AllocEmpty) and
                len(dest.clients) > 1):
            inputs[0] = T.AllocEmpty(dest.dtype)(*dest.owner.inputs)

        return [cgemv_inplace(*inputs)]
# ##### ####### #######
......
......@@ -256,6 +256,22 @@ class TestCGemv(TestCase, TestOptimizationMixin):
self.assertRaises(ValueError, f, A_val, ones_3, ones_6)
self.assertRaises(ValueError, f, A_val, ones_4, ones_6)
def test_multiple_inplace(self):
    """Compile two vector-matrix dots that share the same matrix.

    Checks that both outputs match ``numpy.dot`` and that the compiled
    graph contains exactly two ``AllocEmpty`` nodes.
    """
    mat = tensor.dmatrix('x')
    vec_a = tensor.dvector('y')
    vec_b = tensor.dvector('z')
    products = [tensor.dot(vec_a, mat), tensor.dot(vec_b, mat)]
    f = theano.function([mat, vec_a, vec_b], products, mode=mode_blas_opt)

    # Random inputs: the matrix first, then the two vectors.
    mat_val = numpy.random.rand(3, 3)
    a_val = numpy.random.rand(3)
    b_val = numpy.random.rand(3)

    res_a, res_b = f(mat_val, a_val, b_val)
    assert numpy.allclose(res_a, numpy.dot(a_val, mat_val))
    assert numpy.allclose(res_b, numpy.dot(b_val, mat_val))

    alloc_nodes = [node for node in f.maker.fgraph.apply_nodes
                   if isinstance(node.op, tensor.AllocEmpty)]
    assert len(alloc_nodes) == 2
class TestCGemvFloat32(TestCase, BaseGemv, TestOptimizationMixin):
mode = mode_blas_opt
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论