local_subtensor_make_vector: deal with case when idx is a type.

上级 7999e2d3
......@@ -7,7 +7,8 @@ import logging
_logger = logging.getLogger('theano.tensor.opt')
from theano import gof
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph, Variable
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof import Variable, Constant
from theano.gof.utils import MethodNotDefined
from theano.configparser import config
from elemwise import Elemwise, DimShuffle
......@@ -74,7 +75,7 @@ def get_constant_value(v):
is.
"""
if isinstance(v, gof.Constant):
if isinstance(v, Constant):
#TODO: consider checking for arrays of the form e.g. [1,1,1,1] where
# it is not a constant, but in some cases it *could* be replaced with one.
# Note that this would have an effect on the broadcasting of inputs and so on
......@@ -567,9 +568,16 @@ def local_subtensor_make_vector(node):
except:
#'how can you have multiple indexes into a shape?'
raise
if isinstance(idx, (scalar.Scalar, T.TensorType)):
# The idx is a Scalar, ie a Type. This means the actual index
# is contained in node.inputs[1]
old_idx, idx = idx, node.inputs[1]
assert isinstance(idx, old_idx)
if isinstance(idx, (int, numpy.integer)):
return [x.owner.inputs[idx]]
elif isinstance(idx, (T.TensorVariable, T.TensorConstant)):
elif isinstance(idx, Variable):
# if it is a constant we can do something with it
try:
v = get_constant_value(idx)
......@@ -1044,7 +1052,7 @@ class Canonizer(gof.LocalOptimizer):
@staticmethod
def get_constant(v):
"""
Returns a numeric constant if v is a gof.Constant or, well, a
Returns a numeric constant if v is a Constant or, well, a
numeric constant. If v is a plain Variable, returns None.
"""
if isinstance(v, Variable):
......@@ -1126,7 +1134,7 @@ class Canonizer(gof.LocalOptimizer):
# we can't allow ct == []
# TODO: why is this branch needed when merge_num_denum does it for us?
ct = [self.calculate(numct, denumct, aslist = False, out_type=out_type)]
# TODO: why are we not wrapping ct in a gof.Constant right now?
# TODO: why are we not wrapping ct in a Constant right now?
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct and N.all(ct == self.get_constant(orig_num[0])):
# this is an important trick :( if it so happens that:
......@@ -1352,7 +1360,7 @@ def local_neg_div_neg(node):
# No other clients of the original division
new_num = num.owner.inputs[0]
return [T.true_div(new_num, denom)]
elif numpy.all(num.broadcastable) and isinstance(num, gof.Constant):
elif numpy.all(num.broadcastable) and isinstance(num, Constant):
if len(frac.clients) == 1:
new_num = -num.data
return [T.true_div(new_num, denom)]
......@@ -1715,7 +1723,7 @@ register_canonicalize(local_greedy_distributor)
@gof.local_optimizer([None])
def constant_folding(node):
for input in node.inputs:
if not isinstance(input, gof.Constant):
if not isinstance(input, Constant):
return False
try:
storage = [[None] for output in node.outputs]
......@@ -1735,7 +1743,7 @@ def constant_folding(node):
try:
constant = output.type.Constant
except:
constant = gof.Constant
constant = Constant
msg += [constant(output.type, s[0])]
return msg
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论