提交 0e9fe3a9 authored 作者: Frederic's avatar Frederic

In a work around, warn about the problem so that the real cause get fixed.

The added test currently fail.
上级 77c4f4d1
...@@ -8,6 +8,7 @@ import logging ...@@ -8,6 +8,7 @@ import logging
import pdb import pdb
import sys import sys
import time import time
import warnings
import numpy import numpy
...@@ -731,7 +732,8 @@ def pre_constant_merge(vars): ...@@ -731,7 +732,8 @@ def pre_constant_merge(vars):
seen_var = set() seen_var = set()
# signature -> variable (for constants) # signature -> variable (for constants)
const_sig_inv = {} const_sig_inv = {}
if isinstance(vars, graph.Variable):
vars = [vars]
def recursive_merge(var): def recursive_merge(var):
if var in seen_var: if var in seen_var:
return var return var
...@@ -747,6 +749,10 @@ def pre_constant_merge(vars): ...@@ -747,6 +749,10 @@ def pre_constant_merge(vars):
return const_sig_inv[sig] return const_sig_inv[sig]
const_sig_inv[sig] = var const_sig_inv[sig] = var
except TypeError: # unhashable type except TypeError: # unhashable type
warnings.warn(
"We work around a problem, the following variable"
" signature isn't hashable. Please, report this to"
" theano-dev so that the better fix is done. %s" % var)
# Some python object like slice aren't hashable. So # Some python object like slice aren't hashable. So
# don't merge them here. # don't merge them here.
pass pass
......
...@@ -409,3 +409,17 @@ class TestEquilibrium(object): ...@@ -409,3 +409,17 @@ class TestEquilibrium(object):
_logger.setLevel(oldlevel) _logger.setLevel(oldlevel)
#print 'after', g #print 'after', g
assert str(g) == '[Op1(x, y)]' assert str(g) == '[Op1(x, y)]'
def test_pre_constant_merge_slice():
ms = theano.tensor.type_other.MakeSlice()(1)
pre_constant_merge([ms])
const_slice = theano.gof.graph.Constant(
type=theano.tensor.type_other.slicetype,
data=slice(1, None, 2))
adv = theano.tensor.subtensor.AdvancedSubtensor()(theano.tensor.matrix(),
[2, 3], const_slice)
pre_constant_merge(adv)
# Make sure constant of slice signature is hashable.
hash(const_slice.signature())
...@@ -1807,6 +1807,8 @@ def as_index_variable(idx): ...@@ -1807,6 +1807,8 @@ def as_index_variable(idx):
return NoneConst.clone() return NoneConst.clone()
if isinstance(idx, slice): if isinstance(idx, slice):
return make_slice(idx) return make_slice(idx)
if isinstance(idx, gof.Variable) and isinstance(idx.type, SliceType):
return idx
idx = theano.tensor.as_tensor_variable(idx) idx = theano.tensor.as_tensor_variable(idx)
if idx.type.dtype[:3] not in ('int', 'uin'): if idx.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers') raise TypeError('index must be integers')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论