提交 1a0122fe authored 作者: Pascal Lamblin's avatar Pascal Lamblin

New ARange Op, arange (symbolic) function, and tests for them.

上级 2feb66f4
...@@ -1590,7 +1590,6 @@ def one(): ...@@ -1590,7 +1590,6 @@ def one():
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 0, printing.FunctionPrinter('zeros')) pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 0, printing.FunctionPrinter('zeros'))
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 1, printing.FunctionPrinter('ones')) pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 1, printing.FunctionPrinter('ones'))
@_redefine(elemwise.Elemwise(scal.identity)) @_redefine(elemwise.Elemwise(scal.identity))
def tensor_copy(a): def tensor_copy(a):
"""Create a duplicate of `a` (with duplicated storage)""" """Create a duplicate of `a` (with duplicated storage)"""
...@@ -2714,6 +2713,59 @@ def tile(x, reps, ndim=None): ...@@ -2714,6 +2713,59 @@ def tile(x, reps, ndim=None):
tile.op[ndim] = Tile(ndim) tile.op[ndim] = Tile(ndim)
return tile.op[ndim](x, reps) return tile.op[ndim](x, reps)
class ARange(Op):
"""Create an array containing evenly spaced values within a given interval.
Parameters and behaviour are the same as numpy.arange().
"""
def __init__(self, dtype):
self.dtype = dtype
def __eq__(self, other):
return type(self) == type(other) and self.dtype == other.dtype
def __hash__(self):
return hash(self.dtype)
def make_node(self, start, stop, step):
start, stop, step = map(as_tensor_variable, (start, stop, step))
assert start.ndim == 0
assert stop.ndim == 0
assert step.ndim == 0
inputs = [start, stop, step]
outputs = [tensor(self.dtype, (False,))]
return Apply(self, inputs, outputs)
def perform(self, node, (start, stop, step), (out,)):
print repr(start), repr(stop), repr(step)
start = start.item()
stop = stop.item()
step = step.item()
out[0] = numpy.arange(start, stop, step, dtype=self.dtype)
def grad(self, inputs, (gz,)):
return [None] * len(inputs)
_arange = {}
def arange(start, stop=None, step=1, dtype=None):
# If only one argument is provided, it is in fact the "stop" argument,
# and start is 0.
if stop is None:
start, stop = 0, start
start, stop, step = map(as_tensor_variable, (start, stop, step))
# If dtype is not provided, infer it from the other arguments
if dtype is None:
dtype = scal.upcast(start.type.dtype, stop.type.dtype, step.type.dtype)
if dtype not in _arange:
_arange[dtype] = ARange(dtype)
return _arange[dtype](start, stop, step)
class InversePermutation(Op): class InversePermutation(Op):
"""Computes the inverse of permutations. """Computes the inverse of permutations.
......
...@@ -1790,6 +1790,133 @@ def test_tile(): ...@@ -1790,6 +1790,133 @@ def test_tile():
print >> sys.stderr, "WARNING: No testcase for Tile" print >> sys.stderr, "WARNING: No testcase for Tile"
pass pass
class TestARange(unittest.TestCase):
def setUp(self):
utt.seed_rng()
def test_Op_integers(self):
"""Test behaviour of ARange Op on integer inputs"""
start, stop, step = iscalars('start', 'stop', 'step')
out = ARange(start.type.dtype)(start, stop, step)
f = function([start, stop, step], out)
assert numpy.all(f(0,5,1) == numpy.arange(0,5,1))
assert numpy.all(f(2,11,4) == numpy.arange(2,11,4))
assert numpy.all(f(-5,1,1) == numpy.arange(-5,1,1))
assert numpy.all(f(10,2,-2) == numpy.arange(10,2,-2))
assert numpy.all(f(10,2,2) == numpy.arange(10,2,2))
assert numpy.all(f(0,0,1) == numpy.arange(0,0,1))
def test_integers(self):
"""Test arange constructor, on integer outputs"""
start, stop, step = iscalars('start', 'stop', 'step')
out = arange(start, stop, step)
f = function([start, stop, step], out)
assert out.dtype == start.type.dtype
assert numpy.all(f(0,5,1) == numpy.arange(0,5,1))
assert numpy.all(f(2,11,4) == numpy.arange(2,11,4))
assert numpy.all(f(-5,1,1) == numpy.arange(-5,1,1))
assert numpy.all(f(10,2,-2) == numpy.arange(10,2,-2))
assert numpy.all(f(10,2,2) == numpy.arange(10,2,2))
assert numpy.all(f(0,0,1) == numpy.arange(0,0,1))
def test_float32(self):
"""Test arange constructor, on integer outputs"""
start, stop, step = fscalars('start', 'stop', 'step')
out = arange(start, stop, step)
f = function([start, stop, step], out)
assert out.dtype == start.type.dtype
assert numpy.all(f(0,5,1) == numpy.arange(0,5,1, dtype=start.type.dtype))
assert numpy.all(f(2,11,4) == numpy.arange(2,11,4, dtype=start.type.dtype))
assert numpy.all(f(-5,1.1,1.2) == numpy.arange(-5,1.1,1.2, dtype=start.type.dtype))
assert numpy.all(f(1.3,2,-2.1) == numpy.arange(1.3,2,-2.1, dtype=start.type.dtype))
assert numpy.all(f(10,2,2) == numpy.arange(10,2,2, dtype=start.type.dtype))
def test_float64(self):
"""Test arange constructor, on integer outputs"""
start, stop, step = dscalars('start', 'stop', 'step')
out = arange(start, stop, step)
f = function([start, stop, step], out)
assert out.dtype == start.type.dtype
assert numpy.all(f(0,5,1) == numpy.arange(0,5,1, dtype=start.type.dtype))
assert numpy.all(f(2,11,4) == numpy.arange(2,11,4, dtype=start.type.dtype))
assert numpy.all(f(-5,1.1,1.2) == numpy.arange(-5,1.1,1.2, dtype=start.type.dtype))
assert numpy.all(f(1.3,2,-2.1) == numpy.arange(1.3,2,-2.1, dtype=start.type.dtype))
assert numpy.all(f(10,2,2) == numpy.arange(10,2,2, dtype=start.type.dtype))
def test_default_step(self):
"""Test that arange constructor uses the correct default step"""
start, stop = iscalars('start', 'stop')
out = arange(start, stop)
f = function([start, stop], out)
assert out.dtype == start.type.dtype
assert numpy.all(f(0,5) == numpy.arange(0,5))
assert numpy.all(f(-5,1) == numpy.arange(-5,1))
assert numpy.all(f(0,0) == numpy.arange(0,0))
dstart, dstop = dscalars('start', 'stop')
dout = arange(dstart, dstop)
df = function([dstart, dstop], dout)
assert dout.dtype == dstart.type.dtype
print df(0.2, 5.3)
print numpy.arange(0.2, 5.3)
assert numpy.all(df(0.2, 5.3) == numpy.arange(0.2, 5.3))
assert numpy.all(df(0.8, 5.3) == numpy.arange(0.8, 5.3))
assert numpy.all(df(-0.7, 5.3) == numpy.arange(-0.7, 5.3))
def test_default_start(self):
"""Test that arange constructor uses the correct default start"""
stop = iscalar('stop')
out = arange(stop)
f = function([stop], out)
assert out.dtype == stop.type.dtype
assert numpy.all(f(8) == numpy.arange(8))
assert numpy.all(f(-2) == numpy.arange(-2))
fstop = fscalar('stop')
fout = arange(fstop)
ff = function([fstop], fout)
assert fout.dtype == fstop.type.dtype
assert numpy.all(ff(0.2) == numpy.arange(0.2))
assert numpy.all(ff(-0.7) == numpy.arange(-0.7))
assert numpy.all(ff(8.5) == numpy.arange(8.5))
def test_upcast(self):
"""Test that arange compute output type adequately"""
assert arange(iscalar()).dtype == iscalar().dtype
assert arange(fscalar()).dtype == fscalar().dtype
assert arange(dscalar()).dtype == dscalar().dtype
# int32 + float32 -> float64
assert arange(iscalar(), fscalar()).dtype == dscalar().dtype
assert arange(iscalar(), dscalar()).dtype == dscalar().dtype
assert arange(fscalar(), dscalar()).dtype == dscalar().dtype
assert arange(iscalar(), fscalar(), dscalar()).dtype == dscalar().dtype
def test_dtype_cache(self):
"""Checks that the same Op is returned on repeated calls to arange
using the same dtype, but not for different dtypes."""
start, stop, step = iscalars('start', 'stop', 'step')
out1 = arange(start, stop, step)
out2 = arange(start, stop, step, dtype=start.type.dtype)
out3 = arange(start, stop, 2., dtype=start.type.dtype)
out4 = arange(start, stop, 2.)
assert out1.owner.op is out2.owner.op
assert out2.owner.op is out3.owner.op
assert out3.owner.op is not out4.owner.op
class TestInversePermutation(unittest.TestCase): class TestInversePermutation(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论