提交 7f593947 authored 作者: Frederic's avatar Frederic

Make SliceConstant to have its signature hashable.

上级 0e9fe3a9
...@@ -414,12 +414,15 @@ class TestEquilibrium(object): ...@@ -414,12 +414,15 @@ class TestEquilibrium(object):
def test_pre_constant_merge_slice(): def test_pre_constant_merge_slice():
ms = theano.tensor.type_other.MakeSlice()(1) ms = theano.tensor.type_other.MakeSlice()(1)
pre_constant_merge([ms]) pre_constant_merge([ms])
const_slice = theano.gof.graph.Constant( const_slice = theano.tensor.type_other.SliceConstant(
type=theano.tensor.type_other.slicetype, type=theano.tensor.type_other.slicetype,
data=slice(1, None, 2)) data=slice(1, None, 2))
adv = theano.tensor.subtensor.AdvancedSubtensor()(theano.tensor.matrix(), adv = theano.tensor.subtensor.AdvancedSubtensor()(theano.tensor.matrix(),
[2, 3], const_slice) [2, 3], const_slice)
pre_constant_merge(adv) pre_constant_merge(adv)
cst = pre_greedy_local_optimizer([theano.tensor.opt.constant_folding], ms)
assert isinstance(cst, theano.tensor.type_other.SliceConstant)
# Make sure constant of slice signature is hashable. # Make sure constant of slice signature is hashable.
hash(const_slice.signature()) hash(cst.signature())
# #
# Slice type and Op. None Type and NoneConst. # Slice type and Op. None Type and NoneConst.
# #
import numpy
import theano import theano
from theano.gof import Apply, Constant, Generic, Op, Type, hashtype from theano.gof import Apply, Constant, Generic, Op, Type, hashtype
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
...@@ -76,6 +79,35 @@ class SliceType(Type): ...@@ -76,6 +79,35 @@ class SliceType(Type):
slicetype = SliceType() slicetype = SliceType()
class SliceConstant(Constant):
def __init__(self, type, data, name=None):
assert isinstance(data, slice)
# Numpy ndarray aren't hashable, so get rid of them.
if isinstance(data.start, numpy.ndarray):
assert data.start.ndim == 0
assert "int" in str(data.start.dtype)
data = slice(int(data.start), data.stop, data.step)
elif isinstance(data.stop, numpy.ndarray):
assert data.stop.ndim == 0
assert "int" in str(data.stop.dtype)
data = slice(data.start, int(data.stop), data.step)
elif isinstance(data.step, numpy.ndarray):
assert data.step.ndim == 0
assert "int" in str(data.step.dtype)
data = slice(data.start, int(data.stop), data.step)
Constant.__init__(self, type, data, name)
def signature(self):
return (SliceConstant, self.data.start, self.data.stop, self.data.step)
def __str__(self):
return "%s{%s, %s, %s}" % (self.__class__.__name__,
self.data.start,
self.data.stop,
self.data.step)
SliceType.Constant = SliceConstant
class NoneTypeT(Generic): class NoneTypeT(Generic):
""" """
Inherit from Generic to have c code working. Inherit from Generic to have c code working.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论