提交 ea3e7426 authored 作者: Melanie Ducoffe's avatar Melanie Ducoffe

en cours

Conflicts: theano/tensor/tests/test_basic.py
上级 a6a7afb4
......@@ -17,7 +17,7 @@ from theano.tensor import elemwise
from theano.tensor.var import (AsTensorError, TensorVariable,
TensorConstant,
_tensor_py_operators)
from theano.tensor.type import TensorType
from theano.tensor.type import TensorType, values_eq_approx_always_true
from theano.tensor.type_other import NoneConst
from theano import scalar as scal
from theano.compat import partial
......@@ -5478,28 +5478,27 @@ class AllocEmpty(gof.Op):
# specify the type of the data
def __init__(self, dtype):
    """Remember the dtype of the uninitialized array this Op allocates.

    :param dtype: name of a numpy dtype, e.g. ``'float32'``.
    """
    # Only plain string dtype names are accepted.
    assert isinstance(dtype, str)
    # Stored lowercased; c_code() derives the C enum name from it
    # as "NPY_" + self.dtype.upper().
    self.dtype = dtype.lower()
def validate_shape(self, shape):
    """Validate a requested shape and build the symbolic output variable.

    :param shape: iterable of scalar values/variables giving the shape
        of the array to allocate; each entry must have an integer dtype.
    :return: tuple ``(sh, output)`` where ``sh`` is the list of shape
        entries converted to tensor variables and ``output`` is a fresh,
        unowned variable of the matching TensorType.
    :raises TypeError: if a shape entry does not have an integer dtype.
    """
    sh = [as_tensor_variable(s) for s in shape]
    bcast = []
    for s in sh:
        if s.type.dtype[:3] not in ('int', 'uin'):
            raise TypeError('Shape arguments must be integers', s)
        # If s is the constant 1, the output is broadcastable in that
        # dimension; otherwise (non-constant or != 1) it is not.
        try:
            const_shp = get_scalar_constant_value(s)
        except NotScalarConstantError:
            const_shp = None
        bcast.append(numpy.all(1 == const_shp))
    otype = TensorType(dtype=self.dtype, broadcastable=bcast)
    output = otype()
    return sh, output
def make_node(self, *shape):
    """Build the Apply node that allocates an uninitialized tensor.

    :param shape: scalar shape entries (see ``validate_shape``).
    :return: an ``Apply`` with the shape scalars as inputs and one
        uninitialized tensor output.
    """
    shape, output = self.validate_shape(shape)
    # The output holds uninitialized memory, so its contents are
    # meaningless: make any value compare as approximately equal.
    output.tag.values_eq_approx = values_eq_approx_always_true
    return Apply(self, shape, [output])
def perform(self, node, inputs, out_):
......@@ -5513,15 +5512,17 @@ class AllocEmpty(gof.Op):
return False
def c_code(self, node, name, inputs, out_, sub):
dtype = self.dtype
dtype = "NPY_"+self.dtype.upper()
out, = out_
fail = sub['fail']
shps = inputs
nd = len(shps)
str = "int dims[%(nd)s];\n" % locals()
for idx, sh in enumerate(shps):
str += "dims[%(idx)s] =" \
"PyInt_AsLong((PyObject*)%(sh)s);\n" % locals()
# Validate that the output storage exists
str += "if(%(out)s==NULL\n" % locals()
for idx, sh in enumerate(shps):
......@@ -5533,7 +5534,7 @@ class AllocEmpty(gof.Op):
output variable */
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s,
PyArray_DIMS(dims),
dims,
%(dtype)s,
0);
if (!%(out)s)
......
......@@ -46,7 +46,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
itensor3, Tile, switch, Diagonal, Diag,
nonzero, flatnonzero, nonzero_values,
stacklists, DimShuffle, hessian, ptp, power,
swapaxes, choose, Choose, NoneConst,
swapaxes, choose, Choose, NoneConst, AllocEmpty
)
from theano.tests import unittest_tools as utt
......@@ -7536,6 +7536,25 @@ class T_Choose(utt.InferShapeTester):
# Op that should be removed from the graph.
self.op_class)
def test_allocempty():
    """AllocEmpty allocates the requested shape/dtype and is never merged."""
    # Single allocation: shape and dtype come through as requested, and
    # the compiled graph contains exactly the one AllocEmpty node.
    f = theano.function([], AllocEmpty("float32")(2, 3))
    assert len(f.maker.fgraph.apply_nodes) == 1
    out = f()
    assert out.shape == (2, 3)
    assert out.dtype == 'float32'

    # Two identical AllocEmpty applications must NOT be merged by the
    # optimizer: each returns distinct uninitialized memory.
    f = theano.function([], [AllocEmpty("float32")(2, 3),
                             AllocEmpty("float32")(2, 3)])
    out = f()
    assert out[0].shape == (2, 3)
    assert out[0].dtype == 'float32'
    assert out[1].shape == (2, 3)
    assert out[1].dtype == 'float32'
    assert len([node for node in f.maker.fgraph.apply_nodes
                if isinstance(node.op, AllocEmpty)]) == 2
"""
if __name__ == '__main__':
......
Markdown 格式
0%
您将 0 人添加到了此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论