提交 7af51a52 authored 作者: abergeron's avatar abergeron

Merge pull request #1931 from nouiz/eq_computation

Crash fix equal_computations() again and scan opt speed up.
...@@ -20,6 +20,12 @@ The most frequent way to control the number of threads used is via the ...@@ -20,6 +20,12 @@ The most frequent way to control the number of threads used is via the
threads you want to use before starting the Python process. Some BLAS threads you want to use before starting the Python process. Some BLAS
implementations support other environment variables. implementations support other environment variables.
To test if you BLAS support OpenMP/Multiple cores, you can use the theano/misc/check_blas.py scripts from the command line like this::
OMP_NUM_THREAD=1 python theano/misc/check_blas.py -q
OMP_NUM_THREAD=2 python theano/misc/check_blas.py -q
Parallel element wise ops with OpenMP Parallel element wise ops with OpenMP
===================================== =====================================
...@@ -46,5 +52,13 @@ a slow one) for a vector of size ``openmp_elemwise_minsize`` with and ...@@ -46,5 +52,13 @@ a slow one) for a vector of size ``openmp_elemwise_minsize`` with and
without OpenMP and shows the time difference between the cases. without OpenMP and shows the time difference between the cases.
The only way to control the number of threads used is via the The only way to control the number of threads used is via the
``OMP_NUM_THREADS`` environment variable. Set it to the number of threads ``OMP_NUM_THREADS`` environment variable. Set it to the number of
you want to use before starting the Python process. threads you want to use before starting the Python process. You can
test this with this command::
$OMP_NUM_THREADS=2 python theano/misc/elemwise_openmp_speedup.py
#The output
Fast op time without openmp 0.000533s with openmp 0.000474s speedup 1.12
Slow op time without openmp 0.002987s with openmp 0.001553s speedup 1.92
...@@ -658,6 +658,7 @@ class MergeOptimizer(Optimizer): ...@@ -658,6 +658,7 @@ class MergeOptimizer(Optimizer):
print >> stream, blanc, " replace_time", replace_time print >> stream, blanc, " replace_time", replace_time
print >> stream, blanc, " validate_time", validate_time print >> stream, blanc, " validate_time", validate_time
print >> stream, blanc, " callback_time", callback_time print >> stream, blanc, " callback_time", callback_time
if callback_time > 1:
print >> stream, blanc, " callbacks_time" print >> stream, blanc, " callbacks_time"
for i in sorted(callbacks_time.iteritems(), key=lambda a: a[1]): for i in sorted(callbacks_time.iteritems(), key=lambda a: a[1]):
if i[1] > 0: if i[1] > 0:
......
...@@ -69,7 +69,9 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -69,7 +69,9 @@ def remove_constants_and_unused_inputs_scan(node):
op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]])) op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
st += op.n_sit_sot st += op.n_sit_sot
st += op.n_shared_outs st += op.n_shared_outs
op_ins, op_outs = scan_utils.reconstruct_graph(op.inputs, op.outputs)
op_ins = op.inputs
op_outs = op.outputs
# Corresponds to the initial states, which should stay untouched. # Corresponds to the initial states, which should stay untouched.
# We put those variables aside, and put them back at the end. # We put those variables aside, and put them back at the end.
...@@ -94,25 +96,26 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -94,25 +96,26 @@ def remove_constants_and_unused_inputs_scan(node):
all_ins = gof.graph.inputs(op_outs) all_ins = gof.graph.inputs(op_outs)
for idx in xrange(op.n_seqs): for idx in xrange(op.n_seqs):
if (isinstance(node.inputs[idx + 1], tensor.TensorConstant) and node_inp = node.inputs[idx + 1]
node.inputs[idx + 1].tag.unique_value is not None): if (isinstance(node_inp, tensor.TensorConstant) and
node_inp.tag.unique_value is not None):
try: try:
# This works if input is a constant that has all entries # This works if input is a constant that has all entries
# equal # equal
givens[op_ins[idx]] = node.inputs[idx + 1].clone()[0] givens[op_ins[idx]] = node_inp.clone()[0]
except TypeError: except TypeError:
pass pass
elif op_ins[idx] in all_ins: elif op_ins[idx] in all_ins:
# Check for identical other sequence # Check for identical other sequence
identical_seqs = [x for x in nw_outer identical_seqs = [x for x in nw_outer
if scan_utils.equal_computations( if scan_utils.equal_computations(
[x], [node.inputs[idx + 1]])] [x], [node_inp])]
if identical_seqs: if identical_seqs:
index = node.inputs.index(identical_seqs[0]) - 1 index = node.inputs.index(identical_seqs[0]) - 1
givens[op_ins[idx]] = op_ins[index] givens[op_ins[idx]] = op_ins[index]
else: else:
nw_inner += [op_ins[idx]] nw_inner += [op_ins[idx]]
nw_outer += [node.inputs[idx + 1]] nw_outer += [node_inp]
nw_n_seqs = len(nw_inner) nw_n_seqs = len(nw_inner)
# Add outputs stuff # Add outputs stuff
......
...@@ -391,6 +391,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -391,6 +391,7 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
or `ys`. or `ys`.
''' '''
assert len(xs) == len(ys)
if in_xs is None: if in_xs is None:
in_xs = [] in_xs = []
if in_ys is None: if in_ys is None:
...@@ -401,47 +402,46 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -401,47 +402,46 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
return False return False
if y.owner and not x.owner: if y.owner and not x.owner:
return False return False
if x.owner and y.owner: if x.owner: # Check above tell that y.owner eval to True too.
if x.owner.outputs.index(x) != y.owner.outputs.index(y): if x.owner.outputs.index(x) != y.owner.outputs.index(y):
return False return False
if x not in in_xs and x.type != y.type:
return False
if len(in_xs) != len(in_ys): if len(in_xs) != len(in_ys):
return False return False
for _x, _y in izip(in_xs, in_ys): for _x, _y in izip(in_xs, in_ys):
if _x.type != _y.type: if _x.type != _y.type:
return False return False
nds_x = gof.graph.io_toposort(in_xs, xs)
nds_y = gof.graph.io_toposort(in_ys, ys)
if len(nds_x) != len(nds_y):
return False
common = set(zip(in_xs, in_ys)) common = set(zip(in_xs, in_ys))
n_nodes = len(nds_x)
cont = True
idx = 0
for dx, dy in izip(xs, ys): for dx, dy in izip(xs, ys):
if not dx.owner or not dy.owner: # We checked above that both dx and dy have an owner or not
if dy.owner or dx.owner: if not dx.owner:
return False if (isinstance(dx, tensor.Constant) and
elif (isinstance(dx, tensor.Constant) and
isinstance(dy, tensor.Constant)): isinstance(dy, tensor.Constant)):
if not (numpy.all(dx.data == dy.data) and if not dx.equals(dy):
dx.type.dtype == dy.type.dtype and
dx.data.shape == dy.data.shape):
return False return False
else: else:
pass pass
elif (dx, dy) not in common and dx != dy: elif (dx, dy) not in common and dx != dy:
return False return False
while cont and idx < n_nodes: nds_x = gof.graph.io_toposort(in_xs, xs)
nds_y = gof.graph.io_toposort(in_ys, ys)
if len(nds_x) != len(nds_y):
return False
n_nodes = len(nds_x)
idx = 0
while idx < n_nodes:
nd_x = nds_x[idx] nd_x = nds_x[idx]
nd_y = nds_y[idx] nd_y = nds_y[idx]
if nd_x.op != nd_y.op: if nd_x.op != nd_y.op:
cont = False return False
elif len(nd_x.inputs) != len(nd_y.inputs): elif len(nd_x.inputs) != len(nd_y.inputs):
cont = False return False
elif len(nd_x.outputs) != len(nd_y.outputs): elif len(nd_x.outputs) != len(nd_y.outputs):
cont = False return False
else: else:
for dx, dy in izip(nd_x.inputs, nd_y.inputs): for dx, dy in izip(nd_x.inputs, nd_y.inputs):
if (dx, dy) not in common: if (dx, dy) not in common:
...@@ -453,14 +453,13 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None): ...@@ -453,14 +453,13 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
else: else:
pass pass
else: else:
cont = False return False
if cont:
for dx, dy in izip(nd_x.outputs, nd_y.outputs): for dx, dy in izip(nd_x.outputs, nd_y.outputs):
common.add((dx, dy)) common.add((dx, dy))
idx += 1 idx += 1
return cont return True
def infer_shape(outs, inputs, input_shapes): def infer_shape(outs, inputs, input_shapes):
......
import theano
from theano.scan_module.scan_utils import equal_computations
from theano.tensor.type_other import NoneConst
def test_equal_compuations():
# This was a bug report by a Theano user.
c = NoneConst
assert equal_computations([c], [c])
m = theano.tensor.matrix()
max_argmax1 = theano.tensor.max_and_argmax(m)
max_argmax2 = theano.tensor.max_and_argmax(m)
assert equal_computations(max_argmax1, max_argmax2)
...@@ -1581,6 +1581,7 @@ class GemmOptimizer(Optimizer): ...@@ -1581,6 +1581,7 @@ class GemmOptimizer(Optimizer):
print >> stream, blanc, " time_toposort", prof[9] print >> stream, blanc, " time_toposort", prof[9]
print >> stream, blanc, " validate_time", prof[10] print >> stream, blanc, " validate_time", prof[10]
print >> stream, blanc, " callback_time", prof[11] print >> stream, blanc, " callback_time", prof[11]
if prof[11] > 1:
print >> stream, blanc, " callbacks_time" print >> stream, blanc, " callbacks_time"
for i in sorted(prof[12].iteritems(), key=lambda a: a[1]): for i in sorted(prof[12].iteritems(), key=lambda a: a[1]):
if i[1] > 0: if i[1] > 0:
......
...@@ -2526,6 +2526,10 @@ def local_reshape_lift(node): ...@@ -2526,6 +2526,10 @@ def local_reshape_lift(node):
len(node.inputs[0].owner.inputs) == 1): len(node.inputs[0].owner.inputs) == 1):
r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
e = node.inputs[0].owner.op(r) e = node.inputs[0].owner.op(r)
# In rare case the original broadcast was (False, True), but
# the new one is (False, False). So don't crash in that case.
if e.type != node.outputs[0].type:
e = T.patternbroadcast(e, node.outputs[0].broadcastable)
return [e] return [e]
...@@ -4937,6 +4941,7 @@ class FusionOptimizer(Optimizer): ...@@ -4937,6 +4941,7 @@ class FusionOptimizer(Optimizer):
print >> stream, blanc, " nb_inconsistency_replace", prof[3] print >> stream, blanc, " nb_inconsistency_replace", prof[3]
print >> stream, blanc, " validate_time", prof[4] print >> stream, blanc, " validate_time", prof[4]
print >> stream, blanc, " callback_time", prof[5] print >> stream, blanc, " callback_time", prof[5]
if prof[5] > 1:
print >> stream, blanc, " callbacks_time" print >> stream, blanc, " callbacks_time"
for i in sorted(prof[6].iteritems(), key=lambda a: a[1]): for i in sorted(prof[6].iteritems(), key=lambda a: a[1]):
if i[1] > 0: if i[1] > 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论