提交 27795a97 authored 作者: Ying Zhang's avatar Ying Zhang

Add a 'name' argument to '.copy' method of Theano variables

上级 c042a9c4
...@@ -2815,8 +2815,11 @@ alloc = Alloc() ...@@ -2815,8 +2815,11 @@ alloc = Alloc()
pprint.assign(alloc, printing.FunctionPrinter('alloc')) pprint.assign(alloc, printing.FunctionPrinter('alloc'))
"""Create a duplicate of `a` (with duplicated storage)""" @constructor
tensor_copy = elemwise.Elemwise(scal.identity) def tensor_copy(input):
"""Create a duplicate of a tensor `input` with duplicated storage."""
return elemwise.Elemwise(scal.identity)(input)
pprint.assign(tensor_copy, printing.IgnorePrinter()) pprint.assign(tensor_copy, printing.IgnorePrinter())
......
import numpy as np import numpy as np
from numpy.testing import assert_equal, assert_string_equal
import theano import theano
import theano.tensor as tt import theano.tensor as tt
...@@ -19,3 +20,14 @@ def test_numpy_method(): ...@@ -19,3 +20,14 @@ def test_numpy_method():
f = theano.function([x], y) f = theano.function([x], y)
utt.assert_allclose(np.nan_to_num(f(data)), utt.assert_allclose(np.nan_to_num(f(data)),
np.nan_to_num(fct(data))) np.nan_to_num(fct(data)))
def test_copy():
x = tt.matrix('x')
data = np.random.rand(5, 5)
y = x.copy(name='y')
f = theano.function([x], y)
assert_equal(f(data), data)
assert_string_equal(y.name, 'y')
...@@ -520,8 +520,11 @@ class _tensor_py_operators: ...@@ -520,8 +520,11 @@ class _tensor_py_operators:
return theano.tensor.subtensor.take(self, indices, axis, mode) return theano.tensor.subtensor.take(self, indices, axis, mode)
# COPYING # COPYING
def copy(self): def copy(self, name=None):
return theano.tensor.basic.tensor_copy(self) """Copy a variable and set a name to the copy."""
copied_variable = theano.tensor.basic.tensor_copy(self)
copied_variable.name = name
return copied_variable
def __iter__(self): def __iter__(self):
try: try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论