提交 303cc80d authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

compatible dimes are now fixed

上级 ae53db81
...@@ -4162,18 +4162,22 @@ def local_sum_prod_div_dimshuffle(node): ...@@ -4162,18 +4162,22 @@ def local_sum_prod_div_dimshuffle(node):
isinstance(numerator.owner.op, T.DimShuffle)): isinstance(numerator.owner.op, T.DimShuffle)):
# Check compatibility # Check compatibility
new_order = numerator.owner.op.new_order new_order = numerator.owner.op.new_order
compatible_dims = True
for ax in axis: for ax in axis:
if ax < len(new_order) and new_order[ax] == 'x': if len(new_order) <= ax or new_order[ax] != 'x':
_logger.warn('WARNING: Your current code is fine, but' compatible_dims = False
' Theano versions between '
'rev. 3bd9b789f5e8 (2010-06-16) and'
' cfc6322e5ad4 (2010-08-03) would '
'have given an incorrect result. '
'To disable this warning, set the Theano'
' flag warn.sum_div_dimshuffle_bug to'
' False.')
break break
if compatible_dims:
_logger.warn('WARNING: Your current code is fine, but'
' Theano versions between '
'rev. 3bd9b789f5e8 (2010-06-16) and'
' cfc6322e5ad4 (2010-08-03) would '
'have given an incorrect result. '
'To disable this warning, set the Theano'
' flag warn.sum_div_dimshuffle_bug to'
' False.')
if denominator.owner and isinstance(denominator.owner.op, if denominator.owner and isinstance(denominator.owner.op,
T.DimShuffle): T.DimShuffle):
dimshuffle_input = denominator.owner.inputs[0] dimshuffle_input = denominator.owner.inputs[0]
...@@ -4237,10 +4241,11 @@ def local_sum_prod_div_dimshuffle(node): ...@@ -4237,10 +4241,11 @@ def local_sum_prod_div_dimshuffle(node):
elif isinstance(node.op, T.elemwise.Prod): elif isinstance(node.op, T.elemwise.Prod):
op_on_compatible_dims = T.prod( op_on_compatible_dims = T.prod(
numerator, axis=compatible_dims) numerator, axis=compatible_dims)
dtype = numerator.dtype
div_op = T.true_div( div_op = T.true_div(
op_on_compatible_dims, op_on_compatible_dims,
(optimized_dimshuffle ** (optimized_dimshuffle **
T.prod([numerator.shape[ax] T.prod([numerator.shape[ax].astype(dtype)
for ax in compatible_dims]))) for ax in compatible_dims])))
op_on_incompatible_dims = T.prod( op_on_incompatible_dims = T.prod(
div_op, div_op,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论