提交 bd654a8b authored 作者: Dustin Webb's avatar Dustin Webb

Incorporated reviewer comments.

上级 f9fc5dfd
......@@ -1534,10 +1534,10 @@ def local_remove_useless_assert(node):
def local_alloc_elemwise(node):
"""
elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION))
-> elemwise(x, y.TensorType(no broadcast flag))
-> elemwise(x, y.TensorType(BROADCAST CONDITION))
elemwise(dimshuffle(alloc(x, shp)),... ,y.TensorType(BROADCAST CONDITION))
-> elemwise(x.dimshuffle(...), y.TensorType(no broadcast flag))
-> elemwise(x.dimshuffle(...), y.TensorType(BROADCAST CONDITION))
BROADCAST CONDITION: the condition is that the one input that are
not to be optimized to have the same broadcast pattern as the
......@@ -1596,20 +1596,6 @@ def local_alloc_elemwise(node):
assert_op_idx = 0 # The first one is as good as any to use.
else:
# When there is only one input then we can optimize if the
# broadcast patterns of the input and output match.
i = node.inputs[0]
if i.type.broadcastable == node.outputs[0].type.broadcastable:
new_i = []
if isinstance(i.owner.op, T.Alloc):
new_i.append(i.owner.inputs[0])
elif dimshuffled_alloc(i):
new_i.append(i.owner.inputs[0].owner.inputs[0])
assert(len(new_i) > 0)
return node.op(*new_i,
return_list=True)
# Otherwise nothing can be done.
return False
......@@ -1626,7 +1612,7 @@ def local_alloc_elemwise(node):
assert i.type.ndim == cmp_op.ndim
if (theano.config.experimental.local_alloc_elemwise_assert
and node.fgraph.shape_feature.same_shape(i, cmp_op)):
and not node.fgraph.shape_feature.same_shape(i, cmp_op)):
assert_op = assert_(assert_op,
*[T.eq(i.shape[idx], cmp_op.shape[idx])\
for idx in xrange(i.type.ndim) \
......@@ -1637,7 +1623,7 @@ def local_alloc_elemwise(node):
elif i.owner and dimshuffled_alloc(i):
assert i.type.ndim == cmp_op.type.ndim
if (theano.config.experimental.local_alloc_elemwise_assert
and node.fgraph.shape_feature.same_shape(i, cmp_op)):
and not node.fgraph.shape_feature.same_shape(i, cmp_op)):
assert_op = assert_(assert_op,
*[T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim)
......@@ -1655,10 +1641,14 @@ def local_alloc_elemwise(node):
# an alloc all but one is optimized.
theano.configparser.AddConfigVar('experimental.local_alloc_elemwise',
"If True enable the experimental optimization local_alloc_elemwise",
"DEPRECATED: If True, enable the experimental"
" optimization local_alloc_elemwise."
" Generates error if not True. Use"
" optimizer_excluding=local_alloc_elemwise"
" to dsiable.",
theano.configparser.BoolParam(
False,
is_valid=lambda x: return not x
True,
is_valid=lambda x: x
),
in_c_key=False)
#This version if faster but not as safe.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论