提交 79ccac56 authored 作者: Frederic Bastien's avatar Frederic Bastien

Small code refactor to remove duplicate code. This make sure the new cast is…

Small code refactor to remove duplicate code. This make sure the new cast is also done in both branch.
上级 7173f901
......@@ -5710,23 +5710,24 @@ def local_opt_alloc(node):
if node_inps.owner and isinstance(node_inps.owner.op, T.Alloc):
input = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:]
if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))):
try:
val = get_scalar_constant_value(input,
only_process_constants=True)
assert val.size == 1
# check which type of op
size = T.mul(*shapes)
if input.dtype == "float32":
# shapes are ints and normally int64.
# We don't want to have a float64 upcast here
# if input is a float32.
size = size.astype(input.dtype)
try:
val = get_scalar_constant_value(input,
only_process_constants=True)
assert val.size == 1
val = val.reshape(1)[0]
# check which type of op
size = T.mul(*shapes)
if input.dtype == "float32":
# shapes are ints and normally int64.
# We don't want to have a float64 upcast here
# if input is a float32.
size = size.astype(input.dtype)
if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))):
if isinstance(node.op, T.Sum):
val = val.reshape(1)[0] * size
val = val * size
else:
val = val.reshape(1)[0] ** size
val = val ** size
# Sum can change the input dtype (upcast or bool
# -> float32) by default or by user request.
# We can ignore the acc_dtype, as there is only 1
......@@ -5736,29 +5737,21 @@ def local_opt_alloc(node):
# dtype.
val = val.astype(node.outputs[0].dtype)
return [val]
except NotScalarConstantError:
pass
else:
try:
val = get_scalar_constant_value(input,
only_process_constants=True)
assert val.size == 1
val = val.reshape(1)[0]
to_prod = [shapes[i] for i in xrange(len(shapes))
if i in node.op.axis]
if to_prod:
size = T.mul(*to_prod)
if isinstance(node.op, T.Sum):
val *= size
else:
val = val ** size
val = val.astype(node.outputs[0].dtype)
return [T.alloc(val,
*[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])]
except NotScalarConstantError:
pass
to_prod = [shapes[i] for i in xrange(len(shapes))
if i in node.op.axis]
if to_prod:
size = T.mul(*to_prod)
if isinstance(node.op, T.Sum):
val *= size
else:
val = val ** size
# See comments above.
val = val.astype(node.outputs[0].dtype)
return [T.alloc(val,
*[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])]
except NotScalarConstantError:
pass
@register_specialize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论