提交 6b9cca50 · 作者: Pascal Lamblin

merge

......@@ -2172,6 +2172,31 @@ class Subtensor(Op):
cdata = cdata[0]
out[0] = numpy.asarray(x.__getitem__(cdata))
def infer_shape(self, node, shapes):
    """Infer the symbolic output shape of this Subtensor from the input shape.

    :param node: the Apply node this Op belongs to; node.inputs[0] is the
        tensor being indexed, node.outputs[0] the sliced result.
    :param shapes: list of input shapes; shapes[0] is the shape of the
        indexed tensor, one symbolic length per dimension.
    :returns: a one-element list containing the output shape (one entry
        per surviving dimension of the result).
    """
    xshp = shapes[0]
    assert len(xshp) == node.inputs[0].ndim
    outshp = []
    # Pad idx_list with full slices so that it has one entry per input
    # dimension: trailing dimensions that are not explicitly indexed are
    # taken whole.
    padded = self.idx_list + [slice(None, None, None)] * (len(xshp) - len(self.idx_list))
    # i counts only slice entries, i.e. the dimensions that survive into
    # the output (integer indices drop their dimension).
    i = 0
    for idx, xl in zip(padded, xshp):
        if isinstance(idx, slice):
            # If it is the default (None, None, None) slice, or a variant
            # (start 0, stop sys.maxint, step 1), the whole dimension is
            # kept and the output length equals the input length xl.
            if (idx.start is None or idx.start == 0)\
                    and (idx.stop is None or idx.stop == sys.maxint)\
                    and (idx.step is None or idx.step == 1):
                outshp.append(xl)
            else:
                # No easy way to compute the shape symbolically from the
                # slice bounds; fall back to reading dimension i of the
                # actual output at runtime via Shape_i.
                outshp.append(Shape_i(i)(node.outputs[0]))
            i += 1
        else:
            # Integer index: that dimension is dropped from the output.
            pass
    # Every output dimension must have been accounted for.
    assert i == node.outputs[0].ndim
    assert len(outshp) == node.outputs[0].ndim
    return [outshp]
def grad(self, inputs, (gz,)):
x = inputs[0]
rest = inputs[1:]
......
......@@ -905,28 +905,19 @@ def _check_rows_is_arange_len_labels(rows, labels):
return False
if getattr(step, 'data', None) != 1: # constant step will have data
return False
if stop.owner and isinstance(stop.owner.op, tensor.Subtensor):
if not stop.owner:
return False
# Not sure if that case happens any more after the introduction
# of ShapeOptimizer
if isinstance(stop.owner.op, tensor.Subtensor):
shape_subtensor = stop.owner
if shape_subtensor.op.idx_list == [0]:
shape_var, = shape_subtensor.inputs
if shape_var.owner and shape_var.owner.op == tensor._shape:
return shape_var.owner.inputs[0] is labels
@gof.local_optimizer([tensor._shape])
def local_shape_lift_advanced_indexing_arange(node):
    """shape(a[arange(len(y)), y]) -> shape(y) (conditions apply)

    Rewrites the shape of an AdvancedSubtensor result: when the row index
    is arange(len(labels)) and the result is 1-d (a is 2-d, labels 1-d),
    the output has exactly the shape of ``labels``, so the AdvancedSubtensor
    need not be computed just to know its shape.

    :param node: the ``_shape`` Apply node being considered.
    :returns: a one-element tuple with the replacement shape graph, or
        None (implicitly) when the pattern does not match.
    """
    if node.op == tensor._shape:
        if node.inputs[0].owner and \
                isinstance(node.inputs[0].owner.op, tensor.AdvancedSubtensor):
            try:
                a, rows, labels = node.inputs[0].owner.inputs
            # Only a mismatched unpacking (not exactly 3 inputs) is
            # expected here; a bare except would also swallow real bugs
            # and KeyboardInterrupt.
            except ValueError:
                return
            if _check_rows_is_arange_len_labels(rows, labels):
                if labels.ndim == 1 and a.ndim == 2:
                    return tensor._shape(labels),
opt.register_specialize(local_shape_lift_advanced_indexing_arange, 'shape_lift')
else:
shape_of = stop.owner.env.shape_feature.shape_of
return shape_of[labels][0] is stop
@opt.register_specialize
@gof.local_optimizer([])
......
......@@ -269,6 +269,8 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase):
def test_get_rid_of_advanced_indexing_version_of_xent(self):
verbose = 0
if verbose:
from theano.printing import pprint
# TODO: add the optimization in FAST_COMPILE?
# In the mean time, run it as 'FAST_RUN' instead
mode = theano.compile.mode.get_default_mode()
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论