提交 8d5c698d authored 作者: Razvan Pascanu's avatar Razvan Pascanu

make make_node to cast and filter_variable inputs to scan

上级 8d52b596
...@@ -214,7 +214,7 @@ class Scan(PureOp): ...@@ -214,7 +214,7 @@ class Scan(PureOp):
assert n_outer_ins == n_inner_ins, \ assert n_outer_ins == n_inner_ins, \
("The number of inputs given to the inner function of scan" ("The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan.") " does not match the number of inputs given to scan.")
new_inputs = [inputs[0]]
# assert dtype is consistent # assert dtype is consistent
err_msg1 = ('When compiling the inner function of scan the ' err_msg1 = ('When compiling the inner function of scan the '
'following error has been encountered: The ' 'following error has been encountered: The '
...@@ -252,24 +252,31 @@ class Scan(PureOp): ...@@ -252,24 +252,31 @@ class Scan(PureOp):
inputs[1 + self.n_seqs: (1 + self.n_seqs + inputs[1 + self.n_seqs: (1 + self.n_seqs +
self.n_outs)]] self.n_outs)]]
self.vector_outs += [False] * self.n_nit_sot self.vector_outs += [False] * self.n_nit_sot
def format(var, as_var):
""" This functions ensures that ``out`` has the same dtype as
``inp`` as well as calling filter_variable to make sure they are
both TensorType or CudaNdarrayType. It internally deals with the
corner case where inp.ndim + 1 = out.ndim
"""
if not hasattr(var, 'dtype'):
return var
rval = var
if rval.type.dtype != as_var.type.dtype:
rval = rval.astype(as_var.type.dtype)
if rval.ndim == as_var.ndim:
rval = as_var.type.filter_variable(rval)
else:
tmp = tensor.unbroadcast(tensor.shape_padleft(as_var), 0)
rval = tmp.type.filter_variable(rval)
return rval
# Check if input sequences and variables representing a slice of # Check if input sequences and variables representing a slice of
# them have the same dtype # them have the same dtype
argoffset = 0 argoffset = 0
for idx, (inner_seq, outer_seq) in enumerate( for inner_seq, outer_seq in zip(self.inner_seqs(self.inputs),
zip(self.inner_seqs(self.inputs), self.outer_seqs(inputs)):
self.outer_seqs(inputs))): new_inputs.append(format(outer_seq, as_var=inner_seq))
if inner_seq.type.dtype != outer_seq[0].type.dtype:
assert isinstance(idx, int)
raise ValueError(err_msg1 % ('sequence',
str(outer_seq),
idx,
outer_seq.type.dtype,
outer_seq.ndim,
str(inner_seq),
inner_seq.type.dtype,
inner_seq.ndim))
argoffset += len(self.outer_seqs(inputs)) argoffset += len(self.outer_seqs(inputs))
# Check that this 3 things have the same dtype for mit_mot: # Check that this 3 things have the same dtype for mit_mot:
# - initial state of the output # - initial state of the output
...@@ -279,10 +286,12 @@ class Scan(PureOp): ...@@ -279,10 +286,12 @@ class Scan(PureOp):
opos = 0 opos = 0
inner_mitmot = self.inner_mitmot(self.inputs) inner_mitmot = self.inner_mitmot(self.inputs)
inner_mitmot_outs = self.inner_mitmot_outs(self.outputs) inner_mitmot_outs = self.inner_mitmot_outs(self.outputs)
for idx, (itaps, otaps, outer_mitmot) in enumerate( for idx, (itaps, otaps, _outer_mitmot) in enumerate(
zip(self.mitmot_taps(), zip(self.mitmot_taps(),
self.mitmot_out_taps(), self.mitmot_out_taps(),
self.outer_mitmot(inputs))): self.outer_mitmot(inputs))):
outer_mitmot = format(_outer_mitmot, as_var=inner_mitmot[ipos])
new_inputs.append(outer_mitmot)
for k in xrange(len(itaps)): for k in xrange(len(itaps)):
if (inner_mitmot[ipos + k].type.dtype != if (inner_mitmot[ipos + k].type.dtype !=
outer_mitmot.type.dtype or outer_mitmot.type.dtype or
...@@ -315,13 +324,16 @@ class Scan(PureOp): ...@@ -315,13 +324,16 @@ class Scan(PureOp):
# Same checks as above but for outputs of type mit_sot # Same checks as above but for outputs of type mit_sot
ipos = 0 ipos = 0
inner_mitsots = self.inner_mitsot(self.inputs) inner_mitsots = self.inner_mitsot(self.inputs)
for idx, (itaps, outer_mitsot, inner_mitsot_out) in enumerate( for idx, (itaps, _outer_mitsot, inner_mitsot_out) in enumerate(
zip(self.mitsot_taps(), zip(self.mitsot_taps(),
self.outer_mitsot(inputs), self.outer_mitsot(inputs),
self.inner_mitsot_outs(self.outputs))): self.inner_mitsot_outs(self.outputs))):
outer_mitsot = format(_outer_mitsot, as_var=inner_mitsots[ipos])
new_inputs.append(outer_mitsot)
for k in xrange(len(itaps)): for k in xrange(len(itaps)):
if (inner_mitsots[ipos + k].type.dtype != \ if (inner_mitsots[ipos + k].type.dtype != \
outer_mitsot.type.dtype or outer_mitsot.type.dtype or
inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1): inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
...@@ -345,12 +357,13 @@ class Scan(PureOp): ...@@ -345,12 +357,13 @@ class Scan(PureOp):
argoffset += len(self.outer_mitsot(inputs)) argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot # Same checks as above but for outputs of type sit_sot
for idx, (inner_sitsot, outer_sitsot, inner_sitsot_out) in enumerate( for idx, (inner_sitsot, _outer_sitsot, inner_sitsot_out) in enumerate(
zip(self.inner_sitsot(self.inputs), zip(self.inner_sitsot(self.inputs),
self.outer_sitsot(inputs), self.outer_sitsot(inputs),
self.inner_sitsot_outs(self.outputs))): self.inner_sitsot_outs(self.outputs))):
if (inner_sitsot.type.dtype != outer_sitsot.type.dtype or outer_sitsot = format(_outer_sitsot, as_var=inner_sitsot)
inner_sitsot.ndim != outer_sitsot.ndim - 1): new_inputs.append(outer_sitsot)
if (inner_sitsot.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_sitsot), str(outer_sitsot),
...@@ -373,10 +386,12 @@ class Scan(PureOp): ...@@ -373,10 +386,12 @@ class Scan(PureOp):
argoffset += len(self.outer_sitsot(inputs)) argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same # Check that the shared variable and their update rule have the same
# dtype. Maybe even same type ?! # dtype. Maybe even same type ?!
for idx, (inner_shared, inner_shared_out, outer_shared) in enumerate( for idx, (inner_shared, inner_shared_out, _outer_shared) in enumerate(
zip(self.inner_shared(self.inputs), zip(self.inner_shared(self.inputs),
self.inner_shared_outs(self.outputs), self.inner_shared_outs(self.outputs),
self.outer_shared(inputs))): self.outer_shared(inputs))):
outer_shared = format(_outer_shared, as_var=inner_shared)
new_inputs.append(outer_shared)
if (hasattr(outer_shared, 'dtype') and if (hasattr(outer_shared, 'dtype') and
(outer_shared.dtype != inner_shared_out.dtype or (outer_shared.dtype != inner_shared_out.dtype or
outer_shared.ndim != inner_shared_out.ndim)): outer_shared.ndim != inner_shared_out.ndim)):
...@@ -399,13 +414,17 @@ class Scan(PureOp): ...@@ -399,13 +414,17 @@ class Scan(PureOp):
str(inner_shared), str(inner_shared),
inner_shared.dtype, inner_shared.dtype,
inner_shared.ndim)) inner_shared.ndim))
for inner_nonseq, outer_nonseq in zip( new_inputs += self.outer_nitsot(inputs)
for inner_nonseq, _outer_nonseq in zip(
self.inner_non_seqs(self.inputs), self.inner_non_seqs(self.inputs),
self.outer_non_seqs(inputs)): self.outer_non_seqs(inputs)):
outer_nonseq = format(_outer_nonseq, as_var=inner_nonseq)
new_inputs.append(outer_nonseq)
if inner_nonseq.type != outer_nonseq.type: if inner_nonseq.type != outer_nonseq.type:
raise ValueError(('Argument %s given to scan node does not' raise ValueError(('Argument %s given to scan node does not'
' match its correspondance %s') % ' match its correspondance %s') %
(str(outer_nonseq), str(inner_nonseq))) (str(outer_nonseq), str(inner_nonseq)))
for outer_nitsot in self.outer_nitsot(inputs): for outer_nitsot in self.outer_nitsot(inputs):
# For every nit_sot input we get as input a int/uint that # For every nit_sot input we get as input a int/uint that
# depicts the size in memory for that sequence. This feature is # depicts the size in memory for that sequence. This feature is
...@@ -416,7 +435,7 @@ class Scan(PureOp): ...@@ -416,7 +435,7 @@ class Scan(PureOp):
'scalar int !', str(outer_nitsot)) 'scalar int !', str(outer_nitsot))
apply_node = Apply(self, apply_node = Apply(self,
inputs, new_inputs,
[t() for t in self.output_types]) [t() for t in self.output_types])
return apply_node return apply_node
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论