提交 e937d599 authored 作者: James Bergstra's avatar James Bergstra

Fixed get_constant function in Canonicalizer to handle numeric constants in

addition to Variables.
上级 8cf2e21d
...@@ -7,7 +7,7 @@ import logging ...@@ -7,7 +7,7 @@ import logging
_logger = logging.getLogger('theano.tensor.opt') _logger = logging.getLogger('theano.tensor.opt')
from theano import gof from theano import gof
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph from theano.gof import opt, InconsistencyError, TopoOptimizer, graph, Variable
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.configparser import config from theano.configparser import config
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
...@@ -977,10 +977,13 @@ class Canonizer(gof.LocalOptimizer): ...@@ -977,10 +977,13 @@ class Canonizer(gof.LocalOptimizer):
Returns a numeric constant if v is a gof.Constant or, well, a Returns a numeric constant if v is a gof.Constant or, well, a
numeric constant. If v is a plain Variable, returns None. numeric constant. If v is a plain Variable, returns None.
""" """
if isinstance(v, Variable):
try: try:
return get_constant_value(v) return get_constant_value(v)
except TypeError: except TypeError:
return None return None
else:
return v
def simplify(self, num, denum): def simplify(self, num, denum):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论