提交 5127f6a9 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Remove dtype as an argument of T.Alloc

上级 6080bef5
......@@ -1759,17 +1759,17 @@ class Alloc(gof.Op):
This Op is used to replace fill() during optimizations because after shapes are lifted,
the first argument to fill can often be pruned from the graph.
"""
def __init__(self, dtype):
self.dtype = dtype
def __init__(self):
pass
def __eq__(self, other):
return type(self) == type(other) and self.dtype == other.dtype
return type(self) == type(other)
def __hash__(self):
return hash(type(self)) ^ hash(self.dtype)
return hash(type(self))
def __str__(self):
return '%s{%s}' % (self.__class__.__name__, self.dtype)
return self.__class__.__name__
def make_node(self, value, *shape):
v = as_tensor_variable(value)
......@@ -1780,19 +1780,21 @@ class Alloc(gof.Op):
raise TypeError('Shape arguments must be integers', s)
# if s is constant 1, then we're broadcastable in that dim
bcast.append(isinstance(s, TensorConstant) and (s.data == 1))
otype = TensorType(dtype=self.dtype, broadcastable=bcast)
otype = TensorType(dtype=v.dtype, broadcastable=bcast)
return gof.Apply(self, [v]+sh, [otype()])
def perform(self, node, inputs, (out,)):
v = inputs[0]
sh = tuple([int(i) for i in inputs[1:]])
if out[0] is None or out[0].shape != sh:
out[0] = numpy.zeros(sh, dtype=self.dtype)
out[0] = numpy.zeros(sh, dtype=v.dtype)
out[0][...] += v # broadcast v to fill us up
def grad(self, inputs, (gout,)):
return [None for i in inputs]
alloc = Alloc()
pprint.assign(alloc, printing.FunctionPrinter('alloc'))
@_redefine(elemwise.Elemwise(scal.identity))
def tensor_copy(a):
......
......@@ -114,7 +114,7 @@ def broadcast_like(value, template, env):
shape_of = env.shape_feature.shape_of
if template not in shape_of:
raise NotImplementedError('broadcast_like currently requires the template Variable to be in the env already')
rval = T.Alloc(template.dtype)(value, *shape_of[template])
rval = T.alloc(T.cast(value, template.dtype), *shape_of[template])
assert rval.type == template.type
return rval
......@@ -486,7 +486,7 @@ def local_fill_to_alloc(node):
# we are broadcasting v somehow
shape_of = node.env.shape_feature.shape_of
# TODO: cut out un-necessary dimshuffles of v
rval = [T.Alloc(node.outputs[0].dtype)(v, *shape_of[node.outputs[0]])]
rval = [T.alloc(T.cast(v, node.outputs[0].dtype), *shape_of[node.outputs[0]])]
assert rval[0].type == node.outputs[0].type
return rval
......@@ -542,12 +542,12 @@ def local_alloc_unary(node):
"""unary(alloc(x, shp)) -> alloc(unary(x), shp)
"""
if isinstance(node.op, T.Elemwise) and len(node.inputs)==1:
x = node.inputs[0]
if x.owner and isinstance(x.owner.op, T.Alloc):
return [T.Alloc(node.outputs[0].dtype)(
node.op(T.cast(x.owner.inputs[0], x.dtype)),
*x.owner.inputs[1:]
)]
a = node.inputs[0]
if a.owner and isinstance(a.owner.op, T.Alloc):
x = a.owner.inputs[0]
shp = a.owner.inputs[1:]
v = node.op(x)
return [T.alloc(T.cast(v, node.outputs[0].dtype), *shp)]
##################
......
......@@ -389,9 +389,9 @@ class test_canonize(unittest.TestCase):
#must broadcast as there is a dimshuffle in the computation
((dx/dv)/dx,[dx,dv],[dxv,dvv],1,'float64'),
#topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float64, row)>), Alloc(...)]
#topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float64, row)>), Alloc]
((fx/fv)/fx,[fx,fv],[fxv,fvv],1,'float32'),
#topo:[Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float32, row)>), Alloc(...)]
#topo:[Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float32, row)>), Alloc]
]):
f = compile.function(list(sym_inputs), g,
mode=mode)
......@@ -906,13 +906,13 @@ def test_log1p():
print f.maker.env.toposort()
# the first three ops are Shape_i, Shape_i, and Dimshuffle
assert [node.op for node in f.maker.env.toposort()][3:] \
== [T.log1p, Alloc('float64')]
== [T.log1p, alloc]
f = function([x,y], T.log(0+(x) + fill(y,1.0)), mode=m)
assert [node.op for node in f.maker.env.toposort()][3:] \
== [T.log1p, Alloc('float64')]
== [T.log1p, alloc]
f = function([x,y], T.log(2+(x) - fill(y,1.0)), mode=m)
assert [node.op for node in f.maker.env.toposort()][3:] \
== [T.log1p, Alloc('float64')]
== [T.log1p, alloc]
f([1e-7, 10], [[0, 0], [0, 0]]) #debugmode will verify values
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论