提交 4357224c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add optimization removing transfers between scalars and scalar tensors

上级 8cf01ede
......@@ -303,6 +303,34 @@ def local_dimshuffle_no_inplace_at_canonicalize(node):
return [T.DimShuffle(node.op.input_broadcastable, node.op.new_order, inplace=False)(node.inputs[0])]
######################
# Casting operations #
######################
@register_canonicalize
#@register_specialize
@gof.local_optimizer([T.TensorFromScalar])
def local_tensor_scalar_tensor(node):
'''tensor_from_scalar(scalar_from_tensor(x)) -> x'''
if isinstance(node.op, T.TensorFromScalar):
s = node.inputs[0]
if s.owner and isinstance(s.owner.op, T.ScalarFromTensor):
t = s.owner.inputs[0]
return [t]
@register_canonicalize
#@register_specialize
@gof.local_optimizer([T.ScalarFromTensor])
def local_scalar_tensor_scalar(node):
'''scalar_from_tensor(tensor_from_scalar(x)) -> x'''
if isinstance(node.op, T.ScalarFromTensor):
t = node.inputs[0]
if t.owner and isinstance(t.owner.op, T.TensorFromScalar):
s = t.owner.inputs[0]
return [s]
#####################################
# ShapeFeature, Shape optimizations
#####################################
......
......@@ -6,6 +6,7 @@ import unittest
import numpy
from nose.plugins.skip import SkipTest
from numpy.testing import dec
from numpy.testing.noseclasses import KnownFailureTest
import theano
......@@ -2324,6 +2325,52 @@ def test_local_add_specialize():
s = TT.add(TT.zeros_like(a))
assert local_add_specialize.transform(s.owner)
def test_local_tensor_scalar_tensor():
dtypes = ['int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64',
'float32', 'float64',
'complex64', 'complex128'
]
for dtype in dtypes:
t_type = TensorType(dtype=dtype, broadcastable=())
t = t_type()
s = TT.scalar_from_tensor(t)
t2 = TT.tensor_from_scalar(s)
f = function([t], t2, mode=mode_opt)
e = f.maker.env.toposort()
cast_nodes = [n for n in e
if isinstance(n.op, (TT.TensorFromScalar,
TT.ScalarFromTensor))]
assert len(cast_nodes) == 0
f(0)
@dec.knownfailureif(
isinstance(theano.compile.mode.get_default_mode(),
theano.compile.debugmode.DebugMode),
("This test fails in DEBUG_MODE, but the generated code is OK. "
"It is actually a problem of DEBUG_MODE, see #624."))
def test_local_scalar_tensor_scalar():
dtypes = ['int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64',
'float32', 'float64',
'complex64', 'complex128'
]
for dtype in dtypes:
s_type = theano.scalar.Scalar(dtype=dtype)
s = s_type()
t = TT.tensor_from_scalar(s)
s2 = TT.scalar_from_tensor(t)
f = function([s], s2, mode=mode_opt)
e = f.maker.env.toposort()
cast_nodes = [n for n in e
if isinstance(n.op, (TT.TensorFromScalar,
TT.ScalarFromTensor))]
assert len(cast_nodes) == 0
f(0)
if __name__ == '__main__':
# unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论