Commit d03c8431 authored by Brandon T. Willard, committed by Brandon T. Willard

Stop some unnecessary Constant cloning

Parent: dd6da53d
...@@ -603,15 +603,8 @@ class Constant(Variable): ...@@ -603,15 +603,8 @@ class Constant(Variable):
return f"{type(self).__name__}{{{name}}}" return f"{type(self).__name__}{{{name}}}"
def clone(self): def clone(self):
"""Create a shallow clone. """Return `self`, because there's no reason to clone a constant."""
return self
We clone this object, but we don't clone the data to lower memory
requirement. We suppose that the data will never change.
"""
cp = self.__class__(self.type, self.data, self.name)
cp.tag = copy(self.tag)
return cp
def __set_owner(self, value): def __set_owner(self, value):
"""Prevent the :prop:`owner` property from being set. """Prevent the :prop:`owner` property from being set.
......
...@@ -61,7 +61,7 @@ def safe_new( ...@@ -61,7 +61,7 @@ def safe_new(
nwx.tag = copy.copy(x.tag) nwx.tag = copy.copy(x.tag)
return nwx return nwx
else: else:
return x.clone() return x
# Note, `as_tensor_variable` will convert the `Scalar` into a # Note, `as_tensor_variable` will convert the `Scalar` into a
# `TensorScalar` that will require a `ScalarFromTensor` `Op`, making the # `TensorScalar` that will require a `ScalarFromTensor` `Op`, making the
# push-out optimization fail # push-out optimization fail
...@@ -697,14 +697,8 @@ def reconstruct_graph(inputs, outputs, tag=None): ...@@ -697,14 +697,8 @@ def reconstruct_graph(inputs, outputs, tag=None):
if tag is None: if tag is None:
tag = "" tag = ""
nw_inputs = [safe_new(x, tag) for x in inputs] nw_inputs = [safe_new(x, tag) for x in inputs]
givens = OrderedDict()
for nw_x, x in zip(nw_inputs, inputs):
givens[x] = nw_x
allinputs = list(graph_inputs(outputs))
for inp in allinputs:
if isinstance(inp, Constant):
givens[inp] = inp.clone()
givens = {x: nw_x for nw_x, x in zip(nw_inputs, inputs)}
nw_outputs = clone_replace(outputs, replace=givens) nw_outputs = clone_replace(outputs, replace=givens)
return (nw_inputs, nw_outputs) return (nw_inputs, nw_outputs)
......
...@@ -200,13 +200,13 @@ class TestClone(X): ...@@ -200,13 +200,13 @@ class TestClone(X):
c1 = at.constant(1.5) c1 = at.constant(1.5)
i, o = clone([c1], [c1]) i, o = clone([c1], [c1])
assert i[0] is not c1 and o[0] is not c1 assert i[0] is c1 and o[0] is c1
i, o = clone([c1], [c1], False) i, o = clone([c1], [c1], False)
assert i[0] is c1 and o[0] is c1 assert i[0] is c1 and o[0] is c1
i, o = clone([c1], [c1], True, False) i, o = clone([c1], [c1], True, False)
assert i[0] is not c1 and o[0] is not c1 assert i[0] is c1 and o[0] is c1
i, o = clone([c1], [c1], False, True) i, o = clone([c1], [c1], False, True)
assert i[0] is c1 and o[0] is c1 assert i[0] is c1 and o[0] is c1
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment