提交 79ccac56 authored 作者: Frederic Bastien's avatar Frederic Bastien

Small code refactor to remove duplicate code. This makes sure the new cast is…

Small code refactor to remove duplicate code. This makes sure the new cast is also done in both branches.
上级 7173f901
...@@ -5710,23 +5710,24 @@ def local_opt_alloc(node): ...@@ -5710,23 +5710,24 @@ def local_opt_alloc(node):
if node_inps.owner and isinstance(node_inps.owner.op, T.Alloc): if node_inps.owner and isinstance(node_inps.owner.op, T.Alloc):
input = node_inps.owner.inputs[0] input = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:] shapes = node_inps.owner.inputs[1:]
if (node.op.axis is None or try:
node.op.axis == tuple(range(input.ndim))): val = get_scalar_constant_value(input,
try: only_process_constants=True)
val = get_scalar_constant_value(input, assert val.size == 1
only_process_constants=True) val = val.reshape(1)[0]
assert val.size == 1 # check which type of op
# check which type of op size = T.mul(*shapes)
size = T.mul(*shapes) if input.dtype == "float32":
if input.dtype == "float32": # shapes are ints and normally int64.
# shapes are ints and normally int64. # We don't want to have a float64 upcast here
# We don't want to have a float64 upcast here # if input is a float32.
# if input is a float32. size = size.astype(input.dtype)
size = size.astype(input.dtype) if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))):
if isinstance(node.op, T.Sum): if isinstance(node.op, T.Sum):
val = val.reshape(1)[0] * size val = val * size
else: else:
val = val.reshape(1)[0] ** size val = val ** size
# Sum can change the input dtype (upcast or bool # Sum can change the input dtype (upcast or bool
# -> float32) by default or by user request. # -> float32) by default or by user request.
# We can ignore the acc_dtype, as there is only 1 # We can ignore the acc_dtype, as there is only 1
...@@ -5736,29 +5737,21 @@ def local_opt_alloc(node): ...@@ -5736,29 +5737,21 @@ def local_opt_alloc(node):
# dtype. # dtype.
val = val.astype(node.outputs[0].dtype) val = val.astype(node.outputs[0].dtype)
return [val] return [val]
to_prod = [shapes[i] for i in xrange(len(shapes))
except NotScalarConstantError: if i in node.op.axis]
pass if to_prod:
else: size = T.mul(*to_prod)
try: if isinstance(node.op, T.Sum):
val = get_scalar_constant_value(input, val *= size
only_process_constants=True) else:
assert val.size == 1 val = val ** size
val = val.reshape(1)[0] # See comments above.
to_prod = [shapes[i] for i in xrange(len(shapes)) val = val.astype(node.outputs[0].dtype)
if i in node.op.axis] return [T.alloc(val,
if to_prod: *[shapes[i] for i in xrange(len(shapes))
size = T.mul(*to_prod) if i not in node.op.axis])]
if isinstance(node.op, T.Sum): except NotScalarConstantError:
val *= size pass
else:
val = val ** size
val = val.astype(node.outputs[0].dtype)
return [T.alloc(val,
*[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])]
except NotScalarConstantError:
pass
@register_specialize @register_specialize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论