提交 7f593947 authored 作者: Frederic's avatar Frederic

Make SliceConstant to have its signature hashable.

上级 0e9fe3a9
......@@ -414,12 +414,15 @@ class TestEquilibrium(object):
def test_pre_constant_merge_slice():
ms = theano.tensor.type_other.MakeSlice()(1)
pre_constant_merge([ms])
const_slice = theano.gof.graph.Constant(
const_slice = theano.tensor.type_other.SliceConstant(
type=theano.tensor.type_other.slicetype,
data=slice(1, None, 2))
adv = theano.tensor.subtensor.AdvancedSubtensor()(theano.tensor.matrix(),
[2, 3], const_slice)
pre_constant_merge(adv)
cst = pre_greedy_local_optimizer([theano.tensor.opt.constant_folding], ms)
assert isinstance(cst, theano.tensor.type_other.SliceConstant)
# Make sure constant of slice signature is hashable.
hash(const_slice.signature())
hash(cst.signature())
#
# Slice type and Op. None Type and NoneConst.
#
import numpy
import theano
from theano.gof import Apply, Constant, Generic, Op, Type, hashtype
from theano.gradient import DisconnectedType
......@@ -76,6 +79,35 @@ class SliceType(Type):
slicetype = SliceType()
class SliceConstant(Constant):
def __init__(self, type, data, name=None):
assert isinstance(data, slice)
# Numpy ndarray aren't hashable, so get rid of them.
if isinstance(data.start, numpy.ndarray):
assert data.start.ndim == 0
assert "int" in str(data.start.dtype)
data = slice(int(data.start), data.stop, data.step)
elif isinstance(data.stop, numpy.ndarray):
assert data.stop.ndim == 0
assert "int" in str(data.stop.dtype)
data = slice(data.start, int(data.stop), data.step)
elif isinstance(data.step, numpy.ndarray):
assert data.step.ndim == 0
assert "int" in str(data.step.dtype)
data = slice(data.start, int(data.stop), data.step)
Constant.__init__(self, type, data, name)
def signature(self):
return (SliceConstant, self.data.start, self.data.stop, self.data.step)
def __str__(self):
return "%s{%s, %s, %s}" % (self.__class__.__name__,
self.data.start,
self.data.stop,
self.data.step)
SliceType.Constant = SliceConstant
class NoneTypeT(Generic):
"""
Inherit from Generic to have c code working.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论