提交 9d9d9020 authored 作者: nouiz's avatar nouiz

Merge pull request #298 from pascanur/scan_check

Scan check
...@@ -119,6 +119,7 @@ class Scalar(Type): ...@@ -119,6 +119,7 @@ class Scalar(Type):
TODO: refactor to be named ScalarType for consistency with TensorType TODO: refactor to be named ScalarType for consistency with TensorType
""" """
ndim = 0
def __init__(self, dtype): def __init__(self, dtype):
if dtype == 'floatX': if dtype == 'floatX':
...@@ -441,6 +442,9 @@ all_types = discrete_types + continuous_types ...@@ -441,6 +442,9 @@ all_types = discrete_types + continuous_types
class _scalar_py_operators: class _scalar_py_operators:
# So that we can simplify checking code when we have a mixture of Scalar
# variables and Tensor variables
ndim = 0
#UNARY #UNARY
def __abs__(self): return abs_(self) def __abs__(self): return abs_(self)
......
...@@ -200,8 +200,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -200,8 +200,7 @@ class PushOutNonSeqScan(gof.Optimizer):
not isinstance(nd.op, theano.compile.ViewOp) and not isinstance(nd.op, theano.compile.ViewOp) and
not isinstance(nd.op, theano.compile.DeepCopyOp) and not isinstance(nd.op, theano.compile.DeepCopyOp) and
# and we didn't already looked at this node # and we didn't already looked at this node
not nd in to_remove not nd in to_remove):
):
# We have a candidate node to removable # We have a candidate node to removable
# Step 1. Reconstruct it on outside # Step 1. Reconstruct it on outside
...@@ -317,12 +316,12 @@ def scan_make_inplace(node): ...@@ -317,12 +316,12 @@ def scan_make_inplace(node):
info['inplace'] = True info['inplace'] = True
# inputs corresponding to sequences and n_steps # inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs] ls_begin = node.inputs[:1 + op.n_seqs]
ls = op.outer_mitmot(node) ls = op.outer_mitmot(node.inputs)
ls += op.outer_mitsot(node) ls += op.outer_mitsot(node.inputs)
ls += op.outer_sitsot(node) ls += op.outer_sitsot(node.inputs)
ls_end = op.outer_shared(node) ls_end = op.outer_shared(node.inputs)
ls_end += op.outer_nitsot(node) ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node) ls_end += op.outer_non_seqs(node.inputs)
n_outs = len(ls) n_outs = len(ls)
for idx in xrange(n_outs): for idx in xrange(n_outs):
if ls[idx] in ls[:idx]: if ls[idx] in ls[:idx]:
...@@ -717,8 +716,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -717,8 +716,7 @@ class ScanSaveMem(gof.Optimizer):
fslice = slice( fslice = slice(
sanitize(cnf_slice[0].start), sanitize(cnf_slice[0].start),
sanitize(cnf_slice[0].stop), sanitize(cnf_slice[0].stop),
sanitize(cnf_slice[0].step) sanitize(cnf_slice[0].step))
)
else: else:
fslice = sanitize(cnf_slice[0]) fslice = sanitize(cnf_slice[0])
...@@ -850,54 +848,54 @@ class ScanMerge(gof.Optimizer): ...@@ -850,54 +848,54 @@ class ScanMerge(gof.Optimizer):
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Seq # Seq
inner_ins += rename(nd.op.inner_seqs(), idx) inner_ins += rename(nd.op.inner_seqs(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_seqs(nd), idx) outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# MitMot # MitMot
inner_ins += rename(nd.op.inner_mitmot(), idx) inner_ins += rename(nd.op.inner_mitmot(nd.op.inputs), idx)
inner_outs += nd.op.inner_mitmot_outs() inner_outs += nd.op.inner_mitmot_outs(nd.op.outputs)
info['tap_array'] += nd.op.mitmot_taps() info['tap_array'] += nd.op.mitmot_taps()
info['mit_mot_out_slices'] += nd.op.mitmot_out_taps() info['mit_mot_out_slices'] += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd), idx) outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx)
outer_outs += nd.op.outer_mitmot_outs(nd) outer_outs += nd.op.outer_mitmot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# MitSot # MitSot
inner_ins += rename(nd.op.inner_mitsot(), idx) inner_ins += rename(nd.op.inner_mitsot(nd.op.inputs), idx)
inner_outs += nd.op.inner_mitsot_outs() inner_outs += nd.op.inner_mitsot_outs(nd.op.outputs)
info['tap_array'] += nd.op.mitsot_taps() info['tap_array'] += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd), idx) outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx)
outer_outs += nd.op.outer_mitsot_outs(nd) outer_outs += nd.op.outer_mitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# SitSot # SitSot
inner_ins += rename(nd.op.inner_sitsot(), idx) inner_ins += rename(nd.op.inner_sitsot(nd.op.inputs), idx)
info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)] info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)]
inner_outs += nd.op.inner_sitsot_outs() inner_outs += nd.op.inner_sitsot_outs(nd.op.outputs)
outer_ins += rename(nd.op.outer_sitsot(nd), idx) outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx)
outer_outs += nd.op.outer_sitsot_outs(nd) outer_outs += nd.op.outer_sitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Shared # Shared
inner_ins += rename(nd.op.inner_shared(), idx) inner_ins += rename(nd.op.inner_shared(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_shared(nd), idx) outer_ins += rename(nd.op.outer_shared(nd.inputs), idx)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# NitSot # NitSot
inner_outs += nd.op.inner_nitsot_outs() inner_outs += nd.op.inner_nitsot_outs(nd.op.outputs)
outer_ins += rename(nd.op.outer_nitsot(nd), idx) outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx)
outer_outs += nd.op.outer_nitsot_outs(nd) outer_outs += nd.op.outer_nitsot_outs(nd.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Shared # Shared
outer_outs += nd.op.outer_shared_outs(nd) outer_outs += nd.op.outer_shared_outs(nd.outputs)
inner_outs += nd.op.inner_shared_outs() inner_outs += nd.op.inner_shared_outs(nd.op.outputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Non Seqs # Non Seqs
inner_ins += rename(nd.op.inner_non_seqs(), idx) inner_ins += rename(nd.op.inner_non_seqs(nd.op.inputs), idx)
outer_ins += rename(nd.op.outer_non_seqs(nd), idx) outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx)
# Add back the number of steps # Add back the number of steps
outer_ins = [nodes[0].inputs[0]] + outer_ins outer_ins = [nodes[0].inputs[0]] + outer_ins
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论