提交 dc32d57a authored 作者: Frederic's avatar Frederic

Make MakeSlice.make_node accept the node inputs as parameter.

This is needed for scan optimization to rebuild the node outside the scan.
上级 6e758c8f
...@@ -9,6 +9,8 @@ from theano.gradient import DisconnectedType ...@@ -9,6 +9,8 @@ from theano.gradient import DisconnectedType
def as_int_none_variable(x): def as_int_none_variable(x):
if x is None: if x is None:
return NoneConst return NoneConst
elif NoneConst.equals(x):
return x
x = theano.tensor.as_tensor_variable(x, ndim=0) x = theano.tensor.as_tensor_variable(x, ndim=0)
if x.type.dtype[:3] not in ('int', 'uin'): if x.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers') raise TypeError('index must be integers')
...@@ -16,10 +18,18 @@ def as_int_none_variable(x): ...@@ -16,10 +18,18 @@ def as_int_none_variable(x):
class MakeSlice(Op): class MakeSlice(Op):
def make_node(self, slc): def make_node(self, slc, stop=None, step=None):
# We need to accept and handle in make_node inputs the node
# inputs to allow redoing a new op elsewhere in the graph by
# optimization.
if isinstance(slc, slice):
assert stop is None
assert step is None
inp = [slc.start, slc.stop, slc.step]
else:
inp = [slc, stop, step]
return Apply(self, return Apply(self,
map(as_int_none_variable, map(as_int_none_variable, inp),
[slc.start, slc.stop, slc.step]),
[slicetype()]) [slicetype()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论