提交 3ef4a040 作者: lamblin

Merge pull request #857 from nouiz/shape_tensor_basic2

Shape tensor basic2
差异被折叠。
......@@ -1546,6 +1546,7 @@ class CAReduceDtype(CAReduce):
# We need to redefine make_node so that, if self.dtype is None,
# we can infer what dtype should be, and create a node from an Op
# of the appropriate dtype.
input = as_tensor_variable(input)
dtype = self._output_dtype(input.dtype)
assert dtype is not None
if dtype == self.dtype:
......
......@@ -360,6 +360,10 @@ class RepeatOp(theano.Op):
repeats = node.inputs[1]
out_shape = list(i0_shapes)
#uint64 shape are not supported.
dtype = None
if repeats.dtype in ['uint8', 'uint16', 'uint32']:
dtype = 'int64'
if self.axis is None:
if repeats.ndim == 0:
if len(i0_shapes) == 0:
......@@ -370,12 +374,12 @@ class RepeatOp(theano.Op):
res = res * d
out_shape = (res * repeats, )
else:
out_shape = [theano.tensor.sum(repeats)]
out_shape = [theano.tensor.sum(repeats, dtype=dtype)]
else:
if repeats.ndim == 0:
out_shape[self.axis] = out_shape[self.axis] * repeats
else:
out_shape[self.axis] = theano.tensor.sum(repeats)
out_shape[self.axis] = theano.tensor.sum(repeats, dtype=dtype)
return [out_shape]
def __str__(self):
......
......@@ -869,7 +869,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
# typically we should not need the gradient w.r.t. dy).
y_idx_range = tensor.arange(y_idx.shape[0])
g_dy = tensor.sum(
g_dx * tensor.AdvancedIncSubtensor((y_idx_range, y_idx))(
g_dx * tensor.AdvancedIncSubtensor()(
sm, tensor.fill(dy, -1), y_idx_range, y_idx),
axis=1)
g_sm = dy.dimshuffle(0, 'x') * g_dx
......
......@@ -769,7 +769,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
theano.printing.debugprint(f)
raise
g = theano.function([x, y], T.grad(expr, x), mode=mode)
print_graph(g)
if verbose:
print_graph(g)
try:
ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) == 4
......@@ -829,7 +830,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
finally:
config.warn.sum_div_dimshuffle_bug = backup
print_graph(g)
if verbose:
print_graph(g)
try:
ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) <= 6
......
......@@ -971,6 +971,7 @@ class ShapeFeature(object):
# keep it this way. See #266 for a better long-term fix.
if getattr(d, 'dtype', 'int64') != 'int64':
assert d.dtype in theano.tensor.discrete_dtypes, d.dtype
assert str(d.dtype) != 'uint64'
new_shape += sh[len(new_shape):i + 1]
new_shape[i] = theano.tensor.cast(d, 'int64')
if new_shape:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论