提交 deee4794 authored 作者: Frederic Bastien's avatar Frederic Bastien

Use a cast instead of an identify that modify the dtype. This allow more optimization to work.

上级 d308c69a
...@@ -4998,10 +4998,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -4998,10 +4998,7 @@ class Canonizer(gof.LocalOptimizer):
new = self.merge_num_denum(num, denum) new = self.merge_num_denum(num, denum)
if new.type.dtype != out.type.dtype: if new.type.dtype != out.type.dtype:
# new = T.fill(out, new) new = T.cast(new, out.type.dtype)
elem_op = T.Elemwise(scalar.Identity(scalar.specific_out(
getattr(scalar, out.type.dtype))))
new = elem_op(new)
assert (new.type == out.type) == (not (new.type != out.type)) assert (new.type == out.type) == (not (new.type != out.type))
......
...@@ -898,6 +898,25 @@ def test_const_type_in_mul_canonizer(): ...@@ -898,6 +898,25 @@ def test_const_type_in_mul_canonizer():
f1(ival, wval, visbval, hidbval, betaval, aval)) f1(ival, wval, visbval, hidbval, betaval, aval))
def test_cast_in_mul_canonizer():
x, y = tensor.vectors('xy')
m = tensor.minimum(x, y)
o = m.sum()
go = tensor.fill(o, 1)
e = tensor.eq(go, x)
o1 = (1 - e) * go
o2 = e * go
mode = theano.compile.get_default_mode().excluding('fusion').including('fast_run')
f = theano.function([x, y], [o1, o2], mode=mode)
theano.printing.debugprint(f, print_type=True)
nodes = f.maker.fgraph.apply_nodes
assert len([n for n in nodes if isinstance(getattr(n.op, 'scalar_op', None),
theano.scalar.Identity)]) == 0
assert len([n for n in nodes if isinstance(getattr(n.op, 'scalar_op'),
theano.scalar.Cast)]) == 1
f([1], [1])
class test_fusion(unittest.TestCase): class test_fusion(unittest.TestCase):
mode = copy.copy(compile.mode.get_default_mode()) mode = copy.copy(compile.mode.get_default_mode())
_shared = staticmethod(shared) _shared = staticmethod(shared)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论