local_subtensor_make_vector: deal with case when idx is a type.

上级 7999e2d3
...@@ -7,7 +7,8 @@ import logging ...@@ -7,7 +7,8 @@ import logging
_logger = logging.getLogger('theano.tensor.opt') _logger = logging.getLogger('theano.tensor.opt')
from theano import gof from theano import gof
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph, Variable from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof import Variable, Constant
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.configparser import config from theano.configparser import config
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
...@@ -74,7 +75,7 @@ def get_constant_value(v): ...@@ -74,7 +75,7 @@ def get_constant_value(v):
is. is.
""" """
if isinstance(v, gof.Constant): if isinstance(v, Constant):
#TODO: consider checking for arrays of the form e.g. [1,1,1,1] where #TODO: consider checking for arrays of the form e.g. [1,1,1,1] where
# it is not a constant, but in some cases it *could* be replaced with one. # it is not a constant, but in some cases it *could* be replaced with one.
# Note that this would have an effect on the broadcasting of inputs and so on # Note that this would have an effect on the broadcasting of inputs and so on
...@@ -567,9 +568,16 @@ def local_subtensor_make_vector(node): ...@@ -567,9 +568,16 @@ def local_subtensor_make_vector(node):
except: except:
#'how can you have multiple indexes into a shape?' #'how can you have multiple indexes into a shape?'
raise raise
if isinstance(idx, (scalar.Scalar, T.TensorType)):
# The idx is a Scalar, ie a Type. This means the actual index
# is contained in node.inputs[1]
old_idx, idx = idx, node.inputs[1]
assert isinstance(idx, old_idx)
if isinstance(idx, (int, numpy.integer)): if isinstance(idx, (int, numpy.integer)):
return [x.owner.inputs[idx]] return [x.owner.inputs[idx]]
elif isinstance(idx, (T.TensorVariable, T.TensorConstant)): elif isinstance(idx, Variable):
# if it is a constant we can do something with it # if it is a constant we can do something with it
try: try:
v = get_constant_value(idx) v = get_constant_value(idx)
...@@ -1044,7 +1052,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -1044,7 +1052,7 @@ class Canonizer(gof.LocalOptimizer):
@staticmethod @staticmethod
def get_constant(v): def get_constant(v):
""" """
Returns a numeric constant if v is a gof.Constant or, well, a Returns a numeric constant if v is a Constant or, well, a
numeric constant. If v is a plain Variable, returns None. numeric constant. If v is a plain Variable, returns None.
""" """
if isinstance(v, Variable): if isinstance(v, Variable):
...@@ -1126,7 +1134,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -1126,7 +1134,7 @@ class Canonizer(gof.LocalOptimizer):
# we can't allow ct == [] # we can't allow ct == []
# TODO: why is this branch needed when merge_num_denum does it for us? # TODO: why is this branch needed when merge_num_denum does it for us?
ct = [self.calculate(numct, denumct, aslist = False, out_type=out_type)] ct = [self.calculate(numct, denumct, aslist = False, out_type=out_type)]
# TODO: why are we not wrapping ct in a gof.Constant right now? # TODO: why are we not wrapping ct in a Constant right now?
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct and N.all(ct == self.get_constant(orig_num[0])): if orig_num and len(numct) == 1 and len(denumct) == 0 and ct and N.all(ct == self.get_constant(orig_num[0])):
# this is an important trick :( if it so happens that: # this is an important trick :( if it so happens that:
...@@ -1352,7 +1360,7 @@ def local_neg_div_neg(node): ...@@ -1352,7 +1360,7 @@ def local_neg_div_neg(node):
# No other clients of the original division # No other clients of the original division
new_num = num.owner.inputs[0] new_num = num.owner.inputs[0]
return [T.true_div(new_num, denom)] return [T.true_div(new_num, denom)]
elif numpy.all(num.broadcastable) and isinstance(num, gof.Constant): elif numpy.all(num.broadcastable) and isinstance(num, Constant):
if len(frac.clients) == 1: if len(frac.clients) == 1:
new_num = -num.data new_num = -num.data
return [T.true_div(new_num, denom)] return [T.true_div(new_num, denom)]
...@@ -1715,7 +1723,7 @@ register_canonicalize(local_greedy_distributor) ...@@ -1715,7 +1723,7 @@ register_canonicalize(local_greedy_distributor)
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def constant_folding(node): def constant_folding(node):
for input in node.inputs: for input in node.inputs:
if not isinstance(input, gof.Constant): if not isinstance(input, Constant):
return False return False
try: try:
storage = [[None] for output in node.outputs] storage = [[None] for output in node.outputs]
...@@ -1735,7 +1743,7 @@ def constant_folding(node): ...@@ -1735,7 +1743,7 @@ def constant_folding(node):
try: try:
constant = output.type.Constant constant = output.type.Constant
except: except:
constant = gof.Constant constant = Constant
msg += [constant(output.type, s[0])] msg += [constant(output.type, s[0])]
return msg return msg
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论