提交 9d9d9020 authored 作者: nouiz's avatar nouiz

Merge pull request #298 from pascanur/scan_check

Scan check
......@@ -119,6 +119,7 @@ class Scalar(Type):
TODO: refactor to be named ScalarType for consistency with TensorType
"""
ndim = 0
def __init__(self, dtype):
if dtype == 'floatX':
......@@ -441,6 +442,9 @@ all_types = discrete_types + continuous_types
class _scalar_py_operators:
# So that we can simplify checking code when we have a mixture of Scalar
# variables and Tensor variables
ndim = 0
#UNARY
def __abs__(self): return abs_(self)
......
......@@ -200,8 +200,7 @@ class PushOutNonSeqScan(gof.Optimizer):
not isinstance(nd.op, theano.compile.ViewOp) and
not isinstance(nd.op, theano.compile.DeepCopyOp) and
# and we didn't already looked at this node
not nd in to_remove
):
not nd in to_remove):
# We have a candidate node to removable
# Step 1. Reconstruct it on outside
......@@ -317,12 +316,12 @@ def scan_make_inplace(node):
info['inplace'] = True
# inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs]
ls = op.outer_mitmot(node)
ls += op.outer_mitsot(node)
ls += op.outer_sitsot(node)
ls_end = op.outer_shared(node)
ls_end += op.outer_nitsot(node)
ls_end += op.outer_non_seqs(node)
ls = op.outer_mitmot(node.inputs)
ls += op.outer_mitsot(node.inputs)
ls += op.outer_sitsot(node.inputs)
ls_end = op.outer_shared(node.inputs)
ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node.inputs)
n_outs = len(ls)
for idx in xrange(n_outs):
if ls[idx] in ls[:idx]:
......@@ -717,8 +716,7 @@ class ScanSaveMem(gof.Optimizer):
fslice = slice(
sanitize(cnf_slice[0].start),
sanitize(cnf_slice[0].stop),
sanitize(cnf_slice[0].step)
)
sanitize(cnf_slice[0].step))
else:
fslice = sanitize(cnf_slice[0])
......@@ -850,54 +848,54 @@ class ScanMerge(gof.Optimizer):
for idx, nd in enumerate(nodes):
# Seq
inner_ins += rename(nd.op.inner_seqs(), idx)
outer_ins += rename(nd.op.outer_seqs(nd), idx)
inner_ins += rename(nd.op.inner_seqs(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
for idx, nd in enumerate(nodes):
# MitMot
inner_ins += rename(nd.op.inner_mitmot(), idx)
inner_outs += nd.op.inner_mitmot_outs()
inner_ins += rename(nd.op.inner_mitmot(nd.op.inputs), idx)
inner_outs += nd.op.inner_mitmot_outs(nd.op.outputs)
info['tap_array'] += nd.op.mitmot_taps()
info['mit_mot_out_slices'] += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd), idx)
outer_outs += nd.op.outer_mitmot_outs(nd)
outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
outer_outs += nd.op.outer_mitmot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# MitSot
inner_ins += rename(nd.op.inner_mitsot(), idx)
inner_outs += nd.op.inner_mitsot_outs()
inner_ins += rename(nd.op.inner_mitsot(nd.op.inputs), idx)
inner_outs += nd.op.inner_mitsot_outs(nd.op.outputs)
info['tap_array'] += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd), idx)
outer_outs += nd.op.outer_mitsot_outs(nd)
outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# SitSot
inner_ins += rename(nd.op.inner_sitsot(), idx)
inner_ins += rename(nd.op.inner_sitsot(nd.op.inputs), idx)
info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)]
inner_outs += nd.op.inner_sitsot_outs()
outer_ins += rename(nd.op.outer_sitsot(nd), idx)
outer_outs += nd.op.outer_sitsot_outs(nd)
inner_outs += nd.op.inner_sitsot_outs(nd.op.outputs)
outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# Shared
inner_ins += rename(nd.op.inner_shared(), idx)
outer_ins += rename(nd.op.outer_shared(nd), idx)
inner_ins += rename(nd.op.inner_shared(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_shared(nd.inputs), idx)
for idx, nd in enumerate(nodes):
# NitSot
inner_outs += nd.op.inner_nitsot_outs()
outer_ins += rename(nd.op.outer_nitsot(nd), idx)
outer_outs += nd.op.outer_nitsot_outs(nd)
inner_outs += nd.op.inner_nitsot_outs(nd.op.outputs)
outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx)
outer_outs += nd.op.outer_nitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes):
# Shared
outer_outs += nd.op.outer_shared_outs(nd)
inner_outs += nd.op.inner_shared_outs()
outer_outs += nd.op.outer_shared_outs(nd.outputs)
inner_outs += nd.op.inner_shared_outs(nd.op.outputs)
for idx, nd in enumerate(nodes):
# Non Seqs
inner_ins += rename(nd.op.inner_non_seqs(), idx)
outer_ins += rename(nd.op.outer_non_seqs(nd), idx)
inner_ins += rename(nd.op.inner_non_seqs(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx)
# Add back the number of steps
outer_ins = [nodes[0].inputs[0]] + outer_ins
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论