提交 23145a75 authored 作者: Hengjean's avatar Hengjean

Added remove opt

上级 228915ea
......@@ -2,12 +2,12 @@ from theano import gof
from theano import compile
from theano.gof import TopoOptimizer
from theano.typed_list.basic import (Reverse,
Append, Extend, Insert)
Append, Extend, Insert, Remove)
@gof.local_optimizer([Append, Extend, Insert, Reverse], inplace=True)
@gof.local_optimizer([Append, Extend, Insert, Reverse, Remove], inplace=True)
def typed_list_inplace_opt(node):
if isinstance(node.op, (Append, Extend, Insert, Reverse)) \
if isinstance(node.op, (Append, Extend, Insert, Reverse, Remove)) \
and not node.op.inplace:
new_op = node.op.__class__(
......
......@@ -30,7 +30,7 @@ class test_inplace(unittest.TestCase):
f = theano.function([In(mySymbolicMatricesList, borrow=True,
mutable=True)], z, accept_inplace=True)
self.assertTrue(f.maker.fgraph.toposort()[0].op.inplace)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
......@@ -47,8 +47,8 @@ class test_inplace(unittest.TestCase):
mutable=True), In(mySymbolicMatrix, borrow=True,
mutable=True)], z, accept_inplace=True)
self.assertTrue(f.maker.fgraph.toposort()[0].op.inplace)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
......@@ -93,3 +93,20 @@ class test_inplace(unittest.TestCase):
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(1,
dtype=theano.config.floatX), y), [x, y]))
def test_remove_inplace(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
mySymbolicMatrix = T.matrix()
z = Remove()(mySymbolicMatricesList, mySymbolicMatrix)
f = theano.function([In(mySymbolicMatricesList, borrow=True,
mutable=True), In(mySymbolicMatrix, borrow=True,
mutable=True)], z, accept_inplace=True)
self.assertTrue(f.maker.fgraph.toposort()[0].op.inplace)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x, y], y), [x,]))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论