Fixed incsubtensor (now uses all required inputs). IncSubtensor's make_node now…

Fixed incsubtensor (now uses all required inputs). IncSubtensor's make_node now deals with slices that contain tensors. Added a basic test.
上级 bb09fd5c
......@@ -1830,9 +1830,9 @@ class Subtensor(Op):
return entry.type
elif isinstance(entry, gof.Type) and entry in scal_types:
return entry
if isinstance(entry, gof.Variable) and entry.type in tensor_types:
if isinstance(entry, gof.Variable) and entry.type in tensor_types and numpy.all(entry.type.broadcastable):
return scal.Scalar(entry.type.dtype)
elif isinstance(entry, gof.Type) and entry in tensor_types:
elif isinstance(entry, gof.Type) and entry in tensor_types and numpy.all(entry.broadcastable):
return scal.Scalar(entry.dtype)
elif slice_ok and isinstance(entry, slice):
a = entry.start
......@@ -1868,14 +1868,20 @@ class Subtensor(Op):
def __init__(self, idx_list):
self.idx_list = map(self.convert, idx_list)
@staticmethod
def my_as_scalar(a):
# Since scal.as_scalar does not know about tensor types (it would
# create a circular import) , this method converts either a
# TensorVariable or a ScalarVariable to a scalar.
if isinstance(a, gof.Variable) and isinstance(a.type, TensorType):
return scalar_from_tensor(a)
else:
return scal.as_scalar(a)
def make_node(self, x, *inputs):
x = as_tensor_variable(x)
def my_as_scalar(a):
if isinstance(a, gof.Variable) and isinstance(a.type, TensorType):
return scalar_from_tensor(a)
else:
return scal.as_scalar(a)
inputs = tuple(my_as_scalar(a) for a in inputs)
inputs = tuple(self.my_as_scalar(a) for a in inputs)
idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim:
......@@ -2008,7 +2014,8 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), S
def incsubtensor(x, y, idx_list, inplace=False, set_instead_of_inc=False):
the_op = IncSubtensor(idx_list, inplace, set_instead_of_inc)
return the_op(x, y)
return the_op(x, y, *Subtensor.collapse(idx_list, lambda entry: isinstance(entry, Variable)))
class IncSubtensor(Op):
"""Increment a subtensor.
......@@ -2073,7 +2080,7 @@ class IncSubtensor(Op):
def make_node(self, x, y, *inputs):
x, y = map(as_tensor_variable, [x, y])
inputs = tuple(map(scal.as_scalar, inputs))
inputs = tuple(map(Subtensor.my_as_scalar, inputs))
idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim:
......@@ -2084,8 +2091,8 @@ class IncSubtensor(Op):
padded = idx_list + [slice(0,sys.maxint,1)] * (x.type.ndim - len(idx_list))
broadcastable = [bc for p, bc in zip(padded, x.type.broadcastable) if isinstance(p, slice)]
if y.type.broadcastable != tuple(broadcastable):
raise TypeError("Invalid broadcastable pattern for y in IncSubtensor.make_node")
#if y.type.broadcastable != tuple(broadcastable):
# raise TypeError("Invalid broadcastable pattern for y in IncSubtensor.make_node")
input_types = Subtensor.collapse(idx_list, lambda entry: isinstance(entry, gof.Type))
if len(inputs) != len(input_types):
......
import numpy as N
import unittest
from theano.tests import unittest_tools as utt
import theano
import theano.tensor as T
class Test_incsubtensor(unittest.TestCase):
"""Partial testing.
What could be tested:
- increment vs set
- thing incremented: scalar, vector, matrix,
- increment/set: constant, scalar, vector, matrix
- indices: scalar vs slice, constant vs variable, out of bound, ...
- inplace
"""
def setUp(self):
utt.seed_rng()
def test_simple_ok(self):
"""Increments or sets part of a tensor by a scalar using full slice and
a partial slice depending on a scalar.
"""
a = T.dmatrix()
increment = T.dscalar()
sl1 = slice(None)
sl2_end = T.lscalar()
sl2 = slice(sl2_end)
for do_set in [False,True]:
a_incremented = T.incsubtensor(a, increment, [sl1, sl2], set_instead_of_inc=do_set)
f = theano.function([a, increment, sl2_end], a_incremented)
val_a = N.ones((5,5))
val_inc = 2.3
val_sl2_end = 2
result = f(val_a, val_inc, val_sl2_end)
expected_result = N.copy(val_a)
if do_set:
expected_result[:,:val_sl2_end] = val_inc
else:
expected_result[:,:val_sl2_end] += val_inc
self.failUnless(N.array_equal(result, expected_result))
return
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论