提交 0e9fe3a9 authored 作者: Frederic's avatar Frederic

In a work around, warn about the problem so that the real cause get fixed.

The added test currently fail.
上级 77c4f4d1
......@@ -8,6 +8,7 @@ import logging
import pdb
import sys
import time
import warnings
import numpy
......@@ -731,7 +732,8 @@ def pre_constant_merge(vars):
seen_var = set()
# signature -> variable (for constants)
const_sig_inv = {}
if isinstance(vars, graph.Variable):
vars = [vars]
def recursive_merge(var):
if var in seen_var:
return var
......@@ -747,6 +749,10 @@ def pre_constant_merge(vars):
return const_sig_inv[sig]
const_sig_inv[sig] = var
except TypeError: # unhashable type
warnings.warn(
"We work around a problem, the following variable"
" signature isn't hashable. Please, report this to"
" theano-dev so that the better fix is done. %s" % var)
# Some python object like slice aren't hashable. So
# don't merge them here.
pass
......
......@@ -409,3 +409,17 @@ class TestEquilibrium(object):
_logger.setLevel(oldlevel)
#print 'after', g
assert str(g) == '[Op1(x, y)]'
def test_pre_constant_merge_slice():
ms = theano.tensor.type_other.MakeSlice()(1)
pre_constant_merge([ms])
const_slice = theano.gof.graph.Constant(
type=theano.tensor.type_other.slicetype,
data=slice(1, None, 2))
adv = theano.tensor.subtensor.AdvancedSubtensor()(theano.tensor.matrix(),
[2, 3], const_slice)
pre_constant_merge(adv)
# Make sure constant of slice signature is hashable.
hash(const_slice.signature())
......@@ -1807,6 +1807,8 @@ def as_index_variable(idx):
return NoneConst.clone()
if isinstance(idx, slice):
return make_slice(idx)
if isinstance(idx, gof.Variable) and isinstance(idx.type, SliceType):
return idx
idx = theano.tensor.as_tensor_variable(idx)
if idx.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论