提交 92a4c457 authored 作者: David Warde-Farley's avatar David Warde-Farley

Merge pull request #7 from jaberg/small_stuff

Small fixes to AdvancedSubtensor1 and CVM.
......@@ -476,6 +476,16 @@ class Function(object):
self._value = ValueAttribute()
self._container = ContainerAttribute()
# Compute self.n_returned_outputs.
# This is used only when fn.need_update_inputs is False
# because we're using one of the VM objects and it is
# putting updates back into the input containers all by itself.
assert len(self.maker.expanded_inputs) == len(self.input_storage)
self.n_returned_outputs = len(self.output_storage)
for input in self.maker.expanded_inputs:
if input.update is not None:
self.n_returned_outputs -= 1
def __contains__(self, item):
return self.value.__contains__(item)
......@@ -636,6 +646,8 @@ class Function(object):
for input, storage in reversed(zip(self.maker.expanded_inputs, self.input_storage)):
if input.update is not None:
storage.data = outputs.pop()
else:
outputs = outputs[:self.n_returned_outputs]
# Put default values back in the storage
for i, (required, refeed, value) in enumerate(self.defaults):
......
......@@ -1270,9 +1270,12 @@ class _tensor_py_operators:
break
if advanced:
if len(args) == 1 and isinstance(args[0],
(list, TensorVariable,
theano.tensor.sharedvar.TensorSharedVariable)):
if (len(args) == 1
and isinstance(args[0], (
list,
TensorVariable,
TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable))):
return advanced_subtensor1(self, *args)
else:
return AdvancedSubtensor(args)(self, *args)
......@@ -4863,10 +4866,6 @@ class AdvancedSubtensor1(Op):
raise TypeError('index must be vector')
if x_.type.ndim == 0:
raise TypeError('cannot index into a scalar')
if x_.type.broadcastable[0]:
# the caller should have made a copy of x len(ilist) times
raise TypeError('cannot index into a broadcastable dimension')
return Apply(self, [x_, ilist_], [x_.type()])
def perform(self, node, inp, out_):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论