提交 32105241 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3432 from nouiz/crash

crash fix and cgemv optimization
...@@ -843,19 +843,13 @@ class Function(object): ...@@ -843,19 +843,13 @@ class Function(object):
# this is a new vm-provided function or c linker # this is a new vm-provided function or c linker
# they need this because the exception manipulation # they need this because the exception manipulation
# done by raise_with_op is not implemented in C. # done by raise_with_op is not implemented in C.
thunk = None
if hasattr(self.fn, 'thunks'): if hasattr(self.fn, 'thunks'):
# For the CVM thunk = self.fn.thunks[self.fn.position_of_error]
gof.link.raise_with_op( gof.link.raise_with_op(
self.fn.nodes[self.fn.position_of_error], node=self.fn.nodes[self.fn.position_of_error],
self.fn.thunks[self.fn.position_of_error], thunk=thunk,
storage_map=self.fn.storage_map) storage_map=getattr(self.fn, 'storage_map', None))
else:
# For the c linker We don't have access from
# python to all the temps values So for now, we
# just don't print the extra shapes/strides info
gof.link.raise_with_op(
self.fn.nodes[self.fn.position_of_error],
storage_map=self.fn.storage_map)
else: else:
# old-style linkers raise their own exceptions # old-style linkers raise their own exceptions
raise raise
......
...@@ -280,8 +280,8 @@ class SeqOptimizer(Optimizer, list): ...@@ -280,8 +280,8 @@ class SeqOptimizer(Optimizer, list):
print((" time %.3fs for %d/%d nodes" print((" time %.3fs for %d/%d nodes"
" before/after optimization" % ( " before/after optimization" % (
sum(prof), nb_node_before, nb_node_after)), file=stream) sum(prof), nb_node_before, nb_node_after)), file=stream)
print(blanc, " %.3fs for fgraph.validate()" % (validate_time), file=stream)
print(blanc, " %.3fs for callback" % (callback_time), file=stream) print(blanc, " %.3fs for callback" % (callback_time), file=stream)
print(blanc, " %.3fs for fgraph.validate()" % (validate_time), file=stream)
if level == 0: if level == 0:
print(blanc, " time - (name, class, index) - validate time", file=stream) print(blanc, " time - (name, class, index) - validate time", file=stream)
ll = [] ll = []
......
...@@ -3105,7 +3105,9 @@ def std(input, axis=None, keepdims=False): ...@@ -3105,7 +3105,9 @@ def std(input, axis=None, keepdims=False):
""" """
return sqrt(var(input=input, axis=axis, keepdims=keepdims)) ret = sqrt(var(input=input, axis=axis, keepdims=keepdims))
ret.name = 'std'
return ret
class Default(gof.Op): class Default(gof.Op):
......
...@@ -799,7 +799,14 @@ def use_c_gemv(node): ...@@ -799,7 +799,14 @@ def use_c_gemv(node):
@local_optimizer([CGemv(inplace=False)]) @local_optimizer([CGemv(inplace=False)])
def make_c_gemv_destructive(node): def make_c_gemv_destructive(node):
if isinstance(node.op, CGemv) and not node.op.inplace: if isinstance(node.op, CGemv) and not node.op.inplace:
return [cgemv_inplace(*node.inputs)] inputs = list(node.inputs)
dest = inputs[0]
if (dest.owner and
isinstance(dest.owner.op, T.AllocEmpty) and
len(dest.clients) > 1):
inputs[0] = T.AllocEmpty(dest.dtype)(*dest.owner.inputs)
return [cgemv_inplace(*inputs)]
# ##### ####### ####### # ##### ####### #######
......
...@@ -256,6 +256,22 @@ class TestCGemv(TestCase, TestOptimizationMixin): ...@@ -256,6 +256,22 @@ class TestCGemv(TestCase, TestOptimizationMixin):
self.assertRaises(ValueError, f, A_val, ones_3, ones_6) self.assertRaises(ValueError, f, A_val, ones_3, ones_6)
self.assertRaises(ValueError, f, A_val, ones_4, ones_6) self.assertRaises(ValueError, f, A_val, ones_4, ones_6)
def test_multiple_inplace(self):
x = tensor.dmatrix('x')
y = tensor.dvector('y')
z = tensor.dvector('z')
f = theano.function([x, y, z],
[tensor.dot(y, x), tensor.dot(z,x)],
mode=mode_blas_opt)
vx = numpy.random.rand(3, 3)
vy = numpy.random.rand(3)
vz = numpy.random.rand(3)
out = f(vx, vy, vz)
assert numpy.allclose(out[0], numpy.dot(vy, vx))
assert numpy.allclose(out[1], numpy.dot(vz, vx))
assert len([n for n in f.maker.fgraph.apply_nodes
if isinstance(n.op, tensor.AllocEmpty)]) == 2
class TestCGemvFloat32(TestCase, BaseGemv, TestOptimizationMixin): class TestCGemvFloat32(TestCase, BaseGemv, TestOptimizationMixin):
mode = mode_blas_opt mode = mode_blas_opt
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论