提交 d03c8431 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Stop some unnecessary Constant cloning

上级 dd6da53d
......@@ -603,15 +603,8 @@ class Constant(Variable):
return f"{type(self).__name__}{{{name}}}"
def clone(self):
"""Create a shallow clone.
We clone this object, but we don't clone the data to lower memory
requirement. We suppose that the data will never change.
"""
cp = self.__class__(self.type, self.data, self.name)
cp.tag = copy(self.tag)
return cp
"""Return `self`, because there's no reason to clone a constant."""
return self
def __set_owner(self, value):
"""Prevent the :prop:`owner` property from being set.
......
......@@ -61,7 +61,7 @@ def safe_new(
nwx.tag = copy.copy(x.tag)
return nwx
else:
return x.clone()
return x
# Note, `as_tensor_variable` will convert the `Scalar` into a
# `TensorScalar` that will require a `ScalarFromTensor` `Op`, making the
# push-out optimization fail
......@@ -697,14 +697,8 @@ def reconstruct_graph(inputs, outputs, tag=None):
if tag is None:
tag = ""
nw_inputs = [safe_new(x, tag) for x in inputs]
givens = OrderedDict()
for nw_x, x in zip(nw_inputs, inputs):
givens[x] = nw_x
allinputs = list(graph_inputs(outputs))
for inp in allinputs:
if isinstance(inp, Constant):
givens[inp] = inp.clone()
givens = {x: nw_x for nw_x, x in zip(nw_inputs, inputs)}
nw_outputs = clone_replace(outputs, replace=givens)
return (nw_inputs, nw_outputs)
......
......@@ -200,13 +200,13 @@ class TestClone(X):
c1 = at.constant(1.5)
i, o = clone([c1], [c1])
assert i[0] is not c1 and o[0] is not c1
assert i[0] is c1 and o[0] is c1
i, o = clone([c1], [c1], False)
assert i[0] is c1 and o[0] is c1
i, o = clone([c1], [c1], True, False)
assert i[0] is not c1 and o[0] is not c1
assert i[0] is c1 and o[0] is c1
i, o = clone([c1], [c1], False, True)
assert i[0] is c1 and o[0] is c1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论