提交 ea3e7426 authored 作者: Melanie Ducoffe's avatar Melanie Ducoffe

en cours

Conflicts: theano/tensor/tests/test_basic.py
上级 a6a7afb4
...@@ -17,7 +17,7 @@ from theano.tensor import elemwise ...@@ -17,7 +17,7 @@ from theano.tensor import elemwise
from theano.tensor.var import (AsTensorError, TensorVariable, from theano.tensor.var import (AsTensorError, TensorVariable,
TensorConstant, TensorConstant,
_tensor_py_operators) _tensor_py_operators)
from theano.tensor.type import TensorType from theano.tensor.type import TensorType, values_eq_approx_always_true
from theano.tensor.type_other import NoneConst from theano.tensor.type_other import NoneConst
from theano import scalar as scal from theano import scalar as scal
from theano.compat import partial from theano.compat import partial
...@@ -5478,28 +5478,27 @@ class AllocEmpty(gof.Op): ...@@ -5478,28 +5478,27 @@ class AllocEmpty(gof.Op):
# specify the type of the data # specify the type of the data
def __init__(self, dtype): def __init__(self, dtype):
assert isinstance(dtype, str) assert isinstance(dtype, str)
self.dtype = 'NPY_' + dtype.upper() self.dtype = dtype.lower()
@staticmethod def validate_shape(self, shape):
def validate_shape(shape): sh = [as_tensor_variable(s) for s in shape]
sh = [tensor.as_tensor_variable(s) for s in shape]
bcast = [] bcast = []
for s in sh: for s in sh:
if s.type.dtype[:3] not in ('int', 'uin'): if s.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('Shape arguments must be integers', s) raise TypeError('Shape arguments must be integers', s)
# if s is constant 1, then we're broadcastable in that dim # if s is constant 1, then we're broadcastable in that dim
try: try:
const_shp = tensor.get_scalar_constant_value(s) const_shp = get_scalar_constant_value(s)
except tensor.NotScalarConstantError: except NotScalarConstantError:
const_shp = None const_shp = None
bcast.append(numpy.all(1 == const_shp)) bcast.append(numpy.all(1 == const_shp))
otype = tensor.TensorType(dtype=self.dtype, broadcastable=bcast) otype = TensorType(dtype=self.dtype, broadcastable=bcast)
output = otype() output = otype()
return sh, output return sh, output
def make_node(self, *shape): def make_node(self, *shape):
shape, output = self.validate_shape(shape) shape, output = self.validate_shape(shape)
output.tag.values_eq_approx = tensor.type.values_eq_approx_always_true output.tag.values_eq_approx = values_eq_approx_always_true
return Apply(self, shape, [output]) return Apply(self, shape, [output])
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
...@@ -5513,15 +5512,17 @@ class AllocEmpty(gof.Op): ...@@ -5513,15 +5512,17 @@ class AllocEmpty(gof.Op):
return False return False
def c_code(self, node, name, inputs, out_, sub): def c_code(self, node, name, inputs, out_, sub):
dtype = self.dtype dtype = "NPY_"+self.dtype.upper()
out, = out_ out, = out_
fail = sub['fail'] fail = sub['fail']
shps = inputs shps = inputs
nd = len(shps) nd = len(shps)
str = "int dims[%(nd)s];\n" % locals() str = "int dims[%(nd)s];\n" % locals()
for idx, sh in enumerate(shps): for idx, sh in enumerate(shps):
str += "dims[%(idx)s] =" \ str += "dims[%(idx)s] =" \
"PyInt_AsLong((PyObject*)%(sh)s);\n" % locals() "PyInt_AsLong((PyObject*)%(sh)s);\n" % locals()
# Validate that the output storage exists # Validate that the output storage exists
str += "if(%(out)s==NULL\n" % locals() str += "if(%(out)s==NULL\n" % locals()
for idx, sh in enumerate(shps): for idx, sh in enumerate(shps):
...@@ -5533,7 +5534,7 @@ class AllocEmpty(gof.Op): ...@@ -5533,7 +5534,7 @@ class AllocEmpty(gof.Op):
output variable */ output variable */
Py_XDECREF(%(out)s); Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s, %(out)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s,
PyArray_DIMS(dims), dims,
%(dtype)s, %(dtype)s,
0); 0);
if (!%(out)s) if (!%(out)s)
......
...@@ -46,7 +46,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -46,7 +46,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
itensor3, Tile, switch, Diagonal, Diag, itensor3, Tile, switch, Diagonal, Diag,
nonzero, flatnonzero, nonzero_values, nonzero, flatnonzero, nonzero_values,
stacklists, DimShuffle, hessian, ptp, power, stacklists, DimShuffle, hessian, ptp, power,
swapaxes, choose, Choose, NoneConst, swapaxes, choose, Choose, NoneConst, AllocEmpty
) )
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -7536,6 +7536,25 @@ class T_Choose(utt.InferShapeTester): ...@@ -7536,6 +7536,25 @@ class T_Choose(utt.InferShapeTester):
# Op that should be removed from the graph. # Op that should be removed from the graph.
self.op_class) self.op_class)
def test_allocempty():
# Test that we allocated correctly
f = theano.function([], AllocEmpty("float32")(2, 3)) # change
assert len(f.maker.fgraph.apply_nodes) == 1
out = f()
assert out.shape == (2, 3)
assert out.dtype == 'float32'
# Test that we do not merge them.
f = theano.function([], [AllocEmpty("float32")(2, 3),
AllocEmpty("float32")(2, 3)])
out = f()
assert out[0].shape == (2, 3)
assert out[0].dtype == 'float32'
assert out[1].shape == (2, 3)
assert out[1].dtype == 'float32'
assert len([node for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, AllocEmpty)]) == 2
""" """
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论