提交 5127f6a9 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Remove dtype as an argument of T.Alloc

上级 6080bef5
...@@ -1759,17 +1759,17 @@ class Alloc(gof.Op): ...@@ -1759,17 +1759,17 @@ class Alloc(gof.Op):
This Op is used to replace fill() during optimizations because after shapes are lifted, This Op is used to replace fill() during optimizations because after shapes are lifted,
the first argument to fill can often be pruned from the graph. the first argument to fill can often be pruned from the graph.
""" """
def __init__(self, dtype): def __init__(self):
self.dtype = dtype pass
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.dtype == other.dtype return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ hash(self.dtype) return hash(type(self))
def __str__(self): def __str__(self):
return '%s{%s}' % (self.__class__.__name__, self.dtype) return self.__class__.__name__
def make_node(self, value, *shape): def make_node(self, value, *shape):
v = as_tensor_variable(value) v = as_tensor_variable(value)
...@@ -1780,19 +1780,21 @@ class Alloc(gof.Op): ...@@ -1780,19 +1780,21 @@ class Alloc(gof.Op):
raise TypeError('Shape arguments must be integers', s) raise TypeError('Shape arguments must be integers', s)
# if s is constant 1, then we're broadcastable in that dim # if s is constant 1, then we're broadcastable in that dim
bcast.append(isinstance(s, TensorConstant) and (s.data == 1)) bcast.append(isinstance(s, TensorConstant) and (s.data == 1))
otype = TensorType(dtype=self.dtype, broadcastable=bcast) otype = TensorType(dtype=v.dtype, broadcastable=bcast)
return gof.Apply(self, [v]+sh, [otype()]) return gof.Apply(self, [v]+sh, [otype()])
def perform(self, node, inputs, (out,)): def perform(self, node, inputs, (out,)):
v = inputs[0] v = inputs[0]
sh = tuple([int(i) for i in inputs[1:]]) sh = tuple([int(i) for i in inputs[1:]])
if out[0] is None or out[0].shape != sh: if out[0] is None or out[0].shape != sh:
out[0] = numpy.zeros(sh, dtype=self.dtype) out[0] = numpy.zeros(sh, dtype=v.dtype)
out[0][...] += v # broadcast v to fill us up out[0][...] += v # broadcast v to fill us up
def grad(self, inputs, (gout,)): def grad(self, inputs, (gout,)):
return [None for i in inputs] return [None for i in inputs]
alloc = Alloc()
pprint.assign(alloc, printing.FunctionPrinter('alloc'))
@_redefine(elemwise.Elemwise(scal.identity)) @_redefine(elemwise.Elemwise(scal.identity))
def tensor_copy(a): def tensor_copy(a):
......
...@@ -114,7 +114,7 @@ def broadcast_like(value, template, env): ...@@ -114,7 +114,7 @@ def broadcast_like(value, template, env):
shape_of = env.shape_feature.shape_of shape_of = env.shape_feature.shape_of
if template not in shape_of: if template not in shape_of:
raise NotImplementedError('broadcast_like currently requires the template Variable to be in the env already') raise NotImplementedError('broadcast_like currently requires the template Variable to be in the env already')
rval = T.Alloc(template.dtype)(value, *shape_of[template]) rval = T.alloc(T.cast(value, template.dtype), *shape_of[template])
assert rval.type == template.type assert rval.type == template.type
return rval return rval
...@@ -486,7 +486,7 @@ def local_fill_to_alloc(node): ...@@ -486,7 +486,7 @@ def local_fill_to_alloc(node):
# we are broadcasting v somehow # we are broadcasting v somehow
shape_of = node.env.shape_feature.shape_of shape_of = node.env.shape_feature.shape_of
# TODO: cut out un-necessary dimshuffles of v # TODO: cut out un-necessary dimshuffles of v
rval = [T.Alloc(node.outputs[0].dtype)(v, *shape_of[node.outputs[0]])] rval = [T.alloc(T.cast(v, node.outputs[0].dtype), *shape_of[node.outputs[0]])]
assert rval[0].type == node.outputs[0].type assert rval[0].type == node.outputs[0].type
return rval return rval
...@@ -542,12 +542,12 @@ def local_alloc_unary(node): ...@@ -542,12 +542,12 @@ def local_alloc_unary(node):
"""unary(alloc(x, shp)) -> alloc(unary(x), shp) """unary(alloc(x, shp)) -> alloc(unary(x), shp)
""" """
if isinstance(node.op, T.Elemwise) and len(node.inputs)==1: if isinstance(node.op, T.Elemwise) and len(node.inputs)==1:
x = node.inputs[0] a = node.inputs[0]
if x.owner and isinstance(x.owner.op, T.Alloc): if a.owner and isinstance(a.owner.op, T.Alloc):
return [T.Alloc(node.outputs[0].dtype)( x = a.owner.inputs[0]
node.op(T.cast(x.owner.inputs[0], x.dtype)), shp = a.owner.inputs[1:]
*x.owner.inputs[1:] v = node.op(x)
)] return [T.alloc(T.cast(v, node.outputs[0].dtype), *shp)]
################## ##################
......
...@@ -389,9 +389,9 @@ class test_canonize(unittest.TestCase): ...@@ -389,9 +389,9 @@ class test_canonize(unittest.TestCase):
#must broadcast as their is a dimshuffle in the computation #must broadcast as their is a dimshuffle in the computation
((dx/dv)/dx,[dx,dv],[dxv,dvv],1,'float64'), ((dx/dv)/dx,[dx,dv],[dxv,dvv],1,'float64'),
#topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float64, row)>), Alloc(...)] #topo: [Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float64, row)>), Alloc]
((fx/fv)/fx,[fx,fv],[fxv,fvv],1,'float32'), ((fx/fv)/fx,[fx,fv],[fxv,fvv],1,'float32'),
#topo:[Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float32, row)>), Alloc(...)] #topo:[Shape_i, Shape_i, Elemwise{inv,no_inplace}(<TensorType(float32, row)>), Alloc]
]): ]):
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g,
mode=mode) mode=mode)
...@@ -906,13 +906,13 @@ def test_log1p(): ...@@ -906,13 +906,13 @@ def test_log1p():
print f.maker.env.toposort() print f.maker.env.toposort()
# the first three ops are Shape_i, Shape_i, and Dimshuffle # the first three ops are Shape_i, Shape_i, and Dimshuffle
assert [node.op for node in f.maker.env.toposort()][3:] \ assert [node.op for node in f.maker.env.toposort()][3:] \
== [T.log1p, Alloc('float64')] == [T.log1p, alloc]
f = function([x,y], T.log(0+(x) + fill(y,1.0)), mode=m) f = function([x,y], T.log(0+(x) + fill(y,1.0)), mode=m)
assert [node.op for node in f.maker.env.toposort()][3:] \ assert [node.op for node in f.maker.env.toposort()][3:] \
== [T.log1p, Alloc('float64')] == [T.log1p, alloc]
f = function([x,y], T.log(2+(x) - fill(y,1.0)), mode=m) f = function([x,y], T.log(2+(x) - fill(y,1.0)), mode=m)
assert [node.op for node in f.maker.env.toposort()][3:] \ assert [node.op for node in f.maker.env.toposort()][3:] \
== [T.log1p, Alloc('float64')] == [T.log1p, alloc]
f([1e-7, 10], [[0, 0], [0, 0]]) #debugmode will verify values f([1e-7, 10], [[0, 0], [0, 0]]) #debugmode will verify values
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论