提交 efc5e1c0 authored 作者: James Bergstra's avatar James Bergstra

Modified local_pow_canonicalize and local_pow_specialize to not use fill()

上级 35635bf3
...@@ -103,6 +103,19 @@ def scalarconsts_rest(inputs): ...@@ -103,6 +103,19 @@ def scalarconsts_rest(inputs):
nonconsts.append(i) nonconsts.append(i)
return consts, origconsts, nonconsts return consts, origconsts, nonconsts
def broadcast_like(value, template, env):
"""Return a Variable with the same shape and dtype as the template,
filled by broadcasting value through it. `value` will be casted as necessary.
"""
shape_of = env.shape_feature.shape_of
if template not in shape_of:
raise NotImplementedError('broadcast_like currently requires the template Variable to be in the env already')
rval = T.Alloc(template.dtype)(value, *shape_of[template])
assert rval.type == template.type
return rval
@gof.optimizer @gof.optimizer
def insert_inplace_optimizer(env): def insert_inplace_optimizer(env):
""" """
...@@ -1265,11 +1278,10 @@ register_canonicalize(local_inv_canon) ...@@ -1265,11 +1278,10 @@ register_canonicalize(local_inv_canon)
@gof.local_optimizer([T.pow]) @gof.local_optimizer([T.pow])
def local_pow_canonicalize(node): def local_pow_canonicalize(node):
if node.op == T.pow: if node.op == T.pow:
if N.all(local_mul_canonizer.get_constant(node.inputs[1]) == 1.0): if N.all(local_mul_canonizer.get_constant(node.inputs[1]) == 0):
return [T.fill(node.inputs[1], node.inputs[0])] return [broadcast_like(1, node.outputs[0], node.env)]
if N.all(local_mul_canonizer.get_constant(node.inputs[1]) == 0.0): if N.all(local_mul_canonizer.get_constant(node.inputs[1]) == 1):
#extra fills here are to make sure the size of the output stays constant. return [broadcast_like(node.inputs[0], node.outputs[0], node.env)]
return [T.fill(node.inputs[0], T.fill(node.inputs[1], 1.0))]
else: else:
return False return False
register_canonicalize(local_pow_canonicalize) register_canonicalize(local_pow_canonicalize)
...@@ -1279,25 +1291,33 @@ def local_pow_specialize(node): ...@@ -1279,25 +1291,33 @@ def local_pow_specialize(node):
#here, we are past the point of canonicalization, so we don't want to put in un-necessary fills. #here, we are past the point of canonicalization, so we don't want to put in un-necessary fills.
if node.op == T.pow: if node.op == T.pow:
#the idea here is that we have pow(x, y) #the idea here is that we have pow(x, y)
odtype = node.outputs[0].dtype
xsym = node.inputs[0] xsym = node.inputs[0]
ysym = node.inputs[1] ysym = node.inputs[1]
y = local_mul_canonizer.get_constant(ysym) y = local_mul_canonizer.get_constant(ysym)
if (y is not None) \ if (y is not None) \
and encompasses_broadcastable(xsym.type.broadcastable, ysym.type.broadcastable): and encompasses_broadcastable(xsym.type.broadcastable, ysym.type.broadcastable):
if N.all(y == 2.0): rval = None
return [T.sqr(xsym)]
if N.all(y == 1.0): if N.all(y == 2):
return [xsym] rval = [T.sqr(xsym)]
if N.all(y == 0.0): if N.all(y == 1):
return [T.fill(xsym, 1.0)] rval = [xsym]
if N.all(y == 0):
rval = [T.fill(xsym, numpy.asarray(1, dtype=odtype))]
if N.all(y == 0.5): if N.all(y == 0.5):
return [T.sqrt(xsym)] rval = [T.sqrt(xsym)]
if N.all(y == -0.5): if N.all(y == -0.5):
return [T.inv(T.sqrt(xsym))] rval = [T.inv(T.sqrt(xsym))]
if N.all(y == -1.0): if N.all(y == -1):
return [T.inv(xsym)] rval = [T.inv(xsym)]
if N.all(y == -2.0): if N.all(y == -2):
return [T.inv(T.sqr(xsym))] rval = [T.inv(T.sqr(xsym))]
if rval:
rval[0] = T.cast(rval[0], odtype)
assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
return rval
else: else:
return False return False
register_specialize(local_pow_specialize) register_specialize(local_pow_specialize)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论