提交 481ad92d authored 作者: Hengjean's avatar Hengjean

Added Opt and tests.

上级 5617dc91
from type import TypedListType
from basic import *
import opt
from theano import gof
from theano import compile
from theano.gof import TopoOptimizer
from theano.typed_list.basic import (Reverse,
Append, Extend, Insert)
@gof.local_optimizer([Reverse], inplace=True)
def local_inplace_reverse(node):
if isinstance(node.op, Reverse) and not node.op.inplace:
new_op = node.op.__class__(
inplace=True)
new_node = new_op(*node.inputs)
return [new_node]
return False
compile.optdb.register('local_inplace_reverse',
TopoOptimizer(local_inplace_reverse,
failure_callback=TopoOptimizer.warn_inplace), 60,
'fast_run', 'inplace') # DEBUG
@gof.local_optimizer([Append], inplace=True)
def local_inplace_append(node):
if isinstance(node.op, Append) and not node.op.inplace:
new_op = node.op.__class__(
inplace=True)
new_node = new_op(*node.inputs)
return [new_node]
return False
compile.optdb.register('local_inplace_append',
TopoOptimizer(local_inplace_append,
failure_callback=TopoOptimizer.warn_inplace), 60,
'fast_run', 'inplace') # DEBUG
@gof.local_optimizer([Extend], inplace=True)
def local_inplace_extend(node):
if isinstance(node.op, Extend) and not node.op.inplace:
new_op = node.op.__class__(
inplace=True)
new_node = new_op(*node.inputs)
return [new_node]
return False
compile.optdb.register('local_inplace_extend',
TopoOptimizer(local_inplace_extend,
failure_callback=TopoOptimizer.warn_inplace), 60,
'fast_run', 'inplace') # DEBUG
@gof.local_optimizer([Insert], inplace=True)
def local_inplace_insert(node):
if isinstance(node.op, Insert) and not node.op.inplace:
new_op = node.op.__class__(
inplace=True)
new_node = new_op(*node.inputs)
return [new_node]
return False
compile.optdb.register('local_inplace_insert',
TopoOptimizer(local_inplace_insert,
failure_callback=TopoOptimizer.warn_inplace), 60,
'fast_run', 'inplace') # DEBUG
import unittest
import numpy
import theano
import theano.typed_list
from theano import tensor as T
from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType
from theano.typed_list.basic import (GetItem, Insert,
Append, Extend, Remove, Reverse,
Index, Count)
from theano import In
#took from tensors/tests/test_basic.py
def rand_ranged_matrix(minimum, maximum, shape):
return numpy.asarray(numpy.random.rand(*shape) * (maximum - minimum)
+ minimum, dtype=theano.config.floatX)
class test_inplace(unittest.TestCase):
def test_reverse_inplace(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
z = Reverse()(mySymbolicMatricesList)
f = theano.function([In(mySymbolicMatricesList, borrow=True,
mutable=True)], z, accept_inplace=True)
self.assertTrue(f.maker.fgraph.toposort()[0].op.inplace)
def test_append_inplace(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
mySymbolicMatrix = T.matrix()
z = Append()(mySymbolicMatricesList, mySymbolicMatrix)
f = theano.function([In(mySymbolicMatricesList, borrow=True,
mutable=True), In(mySymbolicMatrix, borrow=True,
mutable=True)], z, accept_inplace=True)
self.assertTrue(f.maker.fgraph.toposort()[0].op.inplace)
def test_extend_inplace(self):
mySymbolicMatricesList1 = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
mySymbolicMatricesList2 = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
z = Extend()(mySymbolicMatricesList1, mySymbolicMatricesList2)
f = theano.function([In(mySymbolicMatricesList1, borrow=True,
mutable=True), mySymbolicMatricesList2],
z)
self.assertTrue(f.maker.fgraph.toposort()[0].op.inplace)
def test_insert_inplace(self):
mySymbolicMatricesList = TypedListType(T.TensorType(
theano.config.floatX, (False, False)))()
mySymbolicIndex = T.scalar()
mySymbolicMatrix = T.matrix()
z = Insert()(mySymbolicMatricesList, mySymbolicIndex, mySymbolicMatrix)
f = theano.function([In(mySymbolicMatricesList, borrow=True,
mutable=True), mySymbolicIndex, mySymbolicMatrix],
z, accept_inplace=True)
self.assertTrue(f.maker.fgraph.toposort()[0].op.inplace)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论