提交 79ccac56 authored 作者: Frederic Bastien's avatar Frederic Bastien

Small code refactor to remove duplicate code. This makes sure the new cast is…

Small code refactor to remove duplicate code. This makes sure the new cast is also done in both branches.
上级 7173f901
...@@ -5710,12 +5710,11 @@ def local_opt_alloc(node): ...@@ -5710,12 +5710,11 @@ def local_opt_alloc(node):
if node_inps.owner and isinstance(node_inps.owner.op, T.Alloc): if node_inps.owner and isinstance(node_inps.owner.op, T.Alloc):
input = node_inps.owner.inputs[0] input = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:] shapes = node_inps.owner.inputs[1:]
if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))):
try: try:
val = get_scalar_constant_value(input, val = get_scalar_constant_value(input,
only_process_constants=True) only_process_constants=True)
assert val.size == 1 assert val.size == 1
val = val.reshape(1)[0]
# check which type of op # check which type of op
size = T.mul(*shapes) size = T.mul(*shapes)
if input.dtype == "float32": if input.dtype == "float32":
...@@ -5723,10 +5722,12 @@ def local_opt_alloc(node): ...@@ -5723,10 +5722,12 @@ def local_opt_alloc(node):
# We don't want to have a float64 upcast here # We don't want to have a float64 upcast here
# if input is a float32. # if input is a float32.
size = size.astype(input.dtype) size = size.astype(input.dtype)
if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))):
if isinstance(node.op, T.Sum): if isinstance(node.op, T.Sum):
val = val.reshape(1)[0] * size val = val * size
else: else:
val = val.reshape(1)[0] ** size val = val ** size
# Sum can change the input dtype (upcast or bool # Sum can change the input dtype (upcast or bool
# -> float32) by default or by user request. # -> float32) by default or by user request.
# We can ignore the acc_dtype, as there is only 1 # We can ignore the acc_dtype, as there is only 1
...@@ -5736,15 +5737,6 @@ def local_opt_alloc(node): ...@@ -5736,15 +5737,6 @@ def local_opt_alloc(node):
# dtype. # dtype.
val = val.astype(node.outputs[0].dtype) val = val.astype(node.outputs[0].dtype)
return [val] return [val]
except NotScalarConstantError:
pass
else:
try:
val = get_scalar_constant_value(input,
only_process_constants=True)
assert val.size == 1
val = val.reshape(1)[0]
to_prod = [shapes[i] for i in xrange(len(shapes)) to_prod = [shapes[i] for i in xrange(len(shapes))
if i in node.op.axis] if i in node.op.axis]
if to_prod: if to_prod:
...@@ -5753,6 +5745,7 @@ def local_opt_alloc(node): ...@@ -5753,6 +5745,7 @@ def local_opt_alloc(node):
val *= size val *= size
else: else:
val = val ** size val = val ** size
# See comments above.
val = val.astype(node.outputs[0].dtype) val = val.astype(node.outputs[0].dtype)
return [T.alloc(val, return [T.alloc(val,
*[shapes[i] for i in xrange(len(shapes)) *[shapes[i] for i in xrange(len(shapes))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论