Fixed incsubtensor (now uses all required inputs). IncSubtensor's make_node now deals with slices that contain tensors. Added a basic test.
上级 bb09fd5c
...@@ -1830,9 +1830,9 @@ class Subtensor(Op): ...@@ -1830,9 +1830,9 @@ class Subtensor(Op):
return entry.type return entry.type
elif isinstance(entry, gof.Type) and entry in scal_types: elif isinstance(entry, gof.Type) and entry in scal_types:
return entry return entry
if isinstance(entry, gof.Variable) and entry.type in tensor_types: if isinstance(entry, gof.Variable) and entry.type in tensor_types and numpy.all(entry.type.broadcastable):
return scal.Scalar(entry.type.dtype) return scal.Scalar(entry.type.dtype)
elif isinstance(entry, gof.Type) and entry in tensor_types: elif isinstance(entry, gof.Type) and entry in tensor_types and numpy.all(entry.broadcastable):
return scal.Scalar(entry.dtype) return scal.Scalar(entry.dtype)
elif slice_ok and isinstance(entry, slice): elif slice_ok and isinstance(entry, slice):
a = entry.start a = entry.start
...@@ -1868,14 +1868,20 @@ class Subtensor(Op): ...@@ -1868,14 +1868,20 @@ class Subtensor(Op):
def __init__(self, idx_list): def __init__(self, idx_list):
self.idx_list = map(self.convert, idx_list) self.idx_list = map(self.convert, idx_list)
def make_node(self, x, *inputs): @staticmethod
x = as_tensor_variable(x)
def my_as_scalar(a): def my_as_scalar(a):
# Since scal.as_scalar does not know about tensor types (it would
# create a circular import) , this method converts either a
# TensorVariable or a ScalarVariable to a scalar.
if isinstance(a, gof.Variable) and isinstance(a.type, TensorType): if isinstance(a, gof.Variable) and isinstance(a.type, TensorType):
return scalar_from_tensor(a) return scalar_from_tensor(a)
else: else:
return scal.as_scalar(a) return scal.as_scalar(a)
inputs = tuple(my_as_scalar(a) for a in inputs)
def make_node(self, x, *inputs):
x = as_tensor_variable(x)
inputs = tuple(self.my_as_scalar(a) for a in inputs)
idx_list = list(self.idx_list) idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim: if len(idx_list) > x.type.ndim:
...@@ -2008,7 +2014,8 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), S ...@@ -2008,7 +2014,8 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), S
def incsubtensor(x, y, idx_list, inplace=False, set_instead_of_inc=False): def incsubtensor(x, y, idx_list, inplace=False, set_instead_of_inc=False):
the_op = IncSubtensor(idx_list, inplace, set_instead_of_inc) the_op = IncSubtensor(idx_list, inplace, set_instead_of_inc)
return the_op(x, y)
return the_op(x, y, *Subtensor.collapse(idx_list, lambda entry: isinstance(entry, Variable)))
class IncSubtensor(Op): class IncSubtensor(Op):
"""Increment a subtensor. """Increment a subtensor.
...@@ -2073,7 +2080,7 @@ class IncSubtensor(Op): ...@@ -2073,7 +2080,7 @@ class IncSubtensor(Op):
def make_node(self, x, y, *inputs): def make_node(self, x, y, *inputs):
x, y = map(as_tensor_variable, [x, y]) x, y = map(as_tensor_variable, [x, y])
inputs = tuple(map(scal.as_scalar, inputs)) inputs = tuple(map(Subtensor.my_as_scalar, inputs))
idx_list = list(self.idx_list) idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim: if len(idx_list) > x.type.ndim:
...@@ -2084,8 +2091,8 @@ class IncSubtensor(Op): ...@@ -2084,8 +2091,8 @@ class IncSubtensor(Op):
padded = idx_list + [slice(0,sys.maxint,1)] * (x.type.ndim - len(idx_list)) padded = idx_list + [slice(0,sys.maxint,1)] * (x.type.ndim - len(idx_list))
broadcastable = [bc for p, bc in zip(padded, x.type.broadcastable) if isinstance(p, slice)] broadcastable = [bc for p, bc in zip(padded, x.type.broadcastable) if isinstance(p, slice)]
if y.type.broadcastable != tuple(broadcastable): #if y.type.broadcastable != tuple(broadcastable):
raise TypeError("Invalid broadcastable pattern for y in IncSubtensor.make_node") # raise TypeError("Invalid broadcastable pattern for y in IncSubtensor.make_node")
input_types = Subtensor.collapse(idx_list, lambda entry: isinstance(entry, gof.Type)) input_types = Subtensor.collapse(idx_list, lambda entry: isinstance(entry, gof.Type))
if len(inputs) != len(input_types): if len(inputs) != len(input_types):
......
import numpy as N
import unittest
from theano.tests import unittest_tools as utt
import theano
import theano.tensor as T
class Test_incsubtensor(unittest.TestCase):
    """Partial testing of T.incsubtensor.

    What could be tested:
    - increment vs set
    - thing incremented: scalar, vector, matrix,
    - increment/set: constant, scalar, vector, matrix
    - indices: scalar vs slice, constant vs variable, out of bound, ...
    - inplace
    """

    def setUp(self):
        # Seed the shared RNG so any randomized checks are reproducible.
        utt.seed_rng()

    def test_simple_ok(self):
        """Increments or sets part of a tensor by a scalar using full slice and
        a partial slice depending on a scalar.
        """
        a = T.dmatrix()
        increment = T.dscalar()
        sl1 = slice(None)       # full slice over the first dimension
        sl2_end = T.lscalar()   # symbolic end index of the second slice
        sl2 = slice(sl2_end)

        # Exercise both the "increment" and the "set" code paths.
        for do_set in [False, True]:
            a_incremented = T.incsubtensor(a, increment, [sl1, sl2],
                                           set_instead_of_inc=do_set)
            f = theano.function([a, increment, sl2_end], a_incremented)

            val_a = N.ones((5, 5))
            val_inc = 2.3
            val_sl2_end = 2

            result = f(val_a, val_inc, val_sl2_end)
            expected_result = N.copy(val_a)
            if do_set:
                expected_result[:, :val_sl2_end] = val_inc
            else:
                expected_result[:, :val_sl2_end] += val_inc
            # assertTrue replaces the deprecated failUnless alias.
            self.assertTrue(N.array_equal(result, expected_result))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论