提交 ac0b0a32 authored 作者: khaotik's avatar khaotik

no more runtime check/assign for vector_outs

上级 4a9d1cf8
......@@ -434,8 +434,8 @@ class Scan(PureOp):
argoffset += len(self.outer_seqs(inputs))
# Check that these 3 things have the same dtype for mit_mot:
# - initial state of the output
# - variable representing an input slice of the otuput
# - variable representing an output slice of the otuput
# - variable representing an input slice of the output
# - variable representing an output slice of the output
ipos = 0
opos = 0
inner_mitmot = self.inner_mitmot(self.inputs)
......@@ -610,16 +610,16 @@ class Scan(PureOp):
# The vector_seqs and vector_outs are just a workaround for
# strange NumPy behavior: vector_ndarray[int] returns a NumPy
# scalar and not a NumPy ndarray of 0 dimensions.
self.vector_seqs = [isinstance(seq, (tensor.TensorVariable,
tensor.TensorConstant)) and
seq.ndim == 1 for seq in
new_inputs[1:1 + self.n_seqs]]
self.vector_outs = [isinstance(arg, (tensor.TensorVariable,
tensor.TensorConstant)) and
arg.ndim == 1 for arg in
new_inputs[1 + self.n_seqs: (1 + self.n_seqs +
self.n_outs)]]
self.vector_outs += [False] * self.n_nit_sot
is_cpu_vector = lambda s: isinstance(s.type, tensor.TensorType) \
and s.ndim == 1
self.vector_seqs = [
is_cpu_vector(seq) for seq in new_inputs[1:1 + self.n_seqs]]
self.vector_outs = [
is_cpu_vector(arg) for arg in new_inputs[
1 + self.n_seqs: (1 + self.n_seqs + self.n_outs)]]
self.vector_outs += [
isinstance(t.type, tensor.TensorType) and t.ndim == 0
for t in self.outer_nitsot_outs(self.outputs)]
apply_node = Apply(self,
new_inputs,
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -514,9 +514,6 @@ def perform(
if i == 0:
jout = j+offset_out
shape = (store_steps[j],) + output_storage[jout].storage[0].shape
if output_storage[jout].storage[0].ndim == 0 and \
isinstance(output_storage[jout].storage[0], numpy.ndarray):
vector_outs[j] = 1
dtype = output_storage[jout].storage[0].dtype
if (outs[j][0] is None or
outs[j][0].shape[0] < store_steps[j] or
......
......@@ -24,7 +24,7 @@ from theano.gof import cmodule
_logger = logging.getLogger('theano.scan_module.scan_perform')
version = 0.295 # must match constant returned in function get_version()
version = 0.296 # must match constant returned in function get_version()
need_reload = False
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论