提交 19d28797 authored 作者: James Bergstra's avatar James Bergstra

Added canonicalization optimization to tensor.opt. It replaces multiplication

by zero with zero.
上级 6905a821
......@@ -10,6 +10,7 @@ from elemwise import Elemwise, DimShuffle
from theano import scalar
import basic as T
import inplace as I
import numpy
import numpy as N
import operator
import itertools
......@@ -38,6 +39,24 @@ def _fill_chain(new_out, orig_inputs):
new_out = T.fill(i, new_out)
return [new_out]
def get_constant_value(v):
"""return the constant value underlying variable `v`
If v is the output of dimshuffles, this function digs through them.
If `v` is not some view of constant data, then raise a TypeError.
:note: There may be another function similar to this one in the code, but I'm not sure where it
is.
"""
if not isinstance(v, gof.Variable):
return v # why would this happen?
if isinstance(v, gof.Constant):
return v.data
if v.owner and isinstance(v.owner.op, T.DimShuffle):
return get_constant_value(v.owner.inputs[0])
raise TypeError(v)
@gof.optimizer
......@@ -841,6 +860,24 @@ def local_mul_to_neg(node):
return False
register_specialize(local_mul_to_neg)
@gof.local_optimizer([T.mul])
def local_mul_zero(node):
"""As part of canonicalization, we replace multiplication by zero with zero.
"""
if node.op == T.mul:
otype = node.outputs[0].type
for i in node.inputs:
try:
value = get_constant_value(i)
except TypeError:
continue
#print 'MUL by value', value, node.inputs
if numpy.all(value == 0):
#print '... returning zeros'
return _fill_chain(numpy.asarray(0, dtype=otype.dtype), node.inputs)
register_canonicalize(local_mul_zero)
@gof.local_optimizer([T.true_div])
def local_div_to_inv(node):
if node.op == T.true_div and N.all(local_mul_canonizer.get_constant(node.inputs[0]) == 1.0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论