提交 23145a75 authored 作者: Hengjean's avatar Hengjean

Added remove opt

上级 228915ea
...@@ -2,12 +2,12 @@ from theano import gof ...@@ -2,12 +2,12 @@ from theano import gof
from theano import compile from theano import compile
from theano.gof import TopoOptimizer from theano.gof import TopoOptimizer
from theano.typed_list.basic import (Reverse, from theano.typed_list.basic import (Reverse,
Append, Extend, Insert) Append, Extend, Insert, Remove)
@gof.local_optimizer([Append, Extend, Insert, Reverse], inplace=True) @gof.local_optimizer([Append, Extend, Insert, Reverse, Remove], inplace=True)
def typed_list_inplace_opt(node): def typed_list_inplace_opt(node):
if isinstance(node.op, (Append, Extend, Insert, Reverse)) \ if isinstance(node.op, (Append, Extend, Insert, Reverse, Remove)) \
and not node.op.inplace: and not node.op.inplace:
new_op = node.op.__class__( new_op = node.op.__class__(
......
...@@ -93,3 +93,20 @@ class test_inplace(unittest.TestCase): ...@@ -93,3 +93,20 @@ class test_inplace(unittest.TestCase):
self.assertTrue(numpy.array_equal(f([x], numpy.asarray(1, self.assertTrue(numpy.array_equal(f([x], numpy.asarray(1,
dtype=theano.config.floatX), y), [x, y])) dtype=theano.config.floatX), y), [x, y]))
def test_remove_inplace(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
mySymbolicMatrix = T.matrix()
z = Remove()(mySymbolicMatricesList, mySymbolicMatrix)
f = theano.function([In(mySymbolicMatricesList, borrow=True,
mutable=True), In(mySymbolicMatrix, borrow=True,
mutable=True)], z, accept_inplace=True)
self.assertTrue(f.maker.fgraph.toposort()[0].op.inplace)
x = rand_ranged_matrix(-1000, 1000, [100, 101])
y = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(numpy.array_equal(f([x, y], y), [x,]))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论