提交 303cc80d authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

compatible dimes are now fixed

上级 ae53db81
......@@ -4162,18 +4162,22 @@ def local_sum_prod_div_dimshuffle(node):
isinstance(numerator.owner.op, T.DimShuffle)):
# Check compatibility
new_order = numerator.owner.op.new_order
compatible_dims = True
for ax in axis:
if ax < len(new_order) and new_order[ax] == 'x':
_logger.warn('WARNING: Your current code is fine, but'
' Theano versions between '
'rev. 3bd9b789f5e8 (2010-06-16) and'
' cfc6322e5ad4 (2010-08-03) would '
'have given an incorrect result. '
'To disable this warning, set the Theano'
' flag warn.sum_div_dimshuffle_bug to'
' False.')
if len(new_order) <= ax or new_order[ax] != 'x':
compatible_dims = False
break
if compatible_dims:
_logger.warn('WARNING: Your current code is fine, but'
' Theano versions between '
'rev. 3bd9b789f5e8 (2010-06-16) and'
' cfc6322e5ad4 (2010-08-03) would '
'have given an incorrect result. '
'To disable this warning, set the Theano'
' flag warn.sum_div_dimshuffle_bug to'
' False.')
if denominator.owner and isinstance(denominator.owner.op,
T.DimShuffle):
dimshuffle_input = denominator.owner.inputs[0]
......@@ -4237,10 +4241,11 @@ def local_sum_prod_div_dimshuffle(node):
elif isinstance(node.op, T.elemwise.Prod):
op_on_compatible_dims = T.prod(
numerator, axis=compatible_dims)
dtype = numerator.dtype
div_op = T.true_div(
op_on_compatible_dims,
(optimized_dimshuffle **
T.prod([numerator.shape[ax]
T.prod([numerator.shape[ax].astype(dtype)
for ax in compatible_dims])))
op_on_incompatible_dims = T.prod(
div_op,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论