提交 efc5e1c0 authored 作者: James Bergstra's avatar James Bergstra

Modified local_pow_canonicalize and local_pow_specialize to not use fill()

上级 35635bf3
......@@ -103,6 +103,19 @@ def scalarconsts_rest(inputs):
nonconsts.append(i)
return consts, origconsts, nonconsts
def broadcast_like(value, template, env):
    """Build a Variable holding `value` with the type of `template`.

    The result is produced by an ``Alloc`` op constructed for
    ``template.dtype`` and applied to `value` together with the symbolic
    shape that ``env.shape_feature`` has recorded for `template`.
    `value` is expected to be cast to the template's dtype as needed
    (presumably by the ``Alloc`` op itself -- TODO confirm).

    :param value: what to broadcast into the new Variable.
    :param template: Variable whose shape and dtype the result must match.
    :param env: the env whose ``shape_feature.shape_of`` mapping supplies
        the shape of `template`.

    :raises NotImplementedError: if `template` has no entry in
        ``env.shape_feature.shape_of``.
    """
    known_shapes = env.shape_feature.shape_of
    if template in known_shapes:
        alloc_op = T.Alloc(template.dtype)
        filled = alloc_op(value, *known_shapes[template])
        # Sanity check: the replacement must be type-compatible with the
        # Variable it stands in for.
        assert filled.type == template.type
        return filled
    raise NotImplementedError('broadcast_like currently requires the template Variable to be in the env already')
@gof.optimizer
def insert_inplace_optimizer(env):
"""
......@@ -1265,11 +1278,10 @@ register_canonicalize(local_inv_canon)
@gof.local_optimizer([T.pow])
def local_pow_canonicalize(node):
    """Canonicalize ``pow(x, y)`` when the exponent is the constant 0 or 1.

    * ``pow(x, 0)`` is replaced by the constant 1 broadcast like the
      node's output.
    * ``pow(x, 1)`` is replaced by ``x`` broadcast like the node's output.

    Replacements are built with `broadcast_like` against
    ``node.outputs[0]`` instead of nested ``T.fill`` calls, so the
    replacement keeps the output's type.

    :param node: the Apply node being considered for rewriting.
    :returns: a one-element list with the replacement Variable when a
        rewrite applies; ``False`` when `node` is not a ``T.pow`` node;
        ``None`` (implicit fall-through) for a ``T.pow`` node whose
        exponent is neither 0 nor 1.
    """
    if node.op == T.pow:
        # Look the exponent up once instead of once per branch.  A
        # non-matching result (including None) simply fails both
        # comparisons below.
        exponent = local_mul_canonizer.get_constant(node.inputs[1])
        if N.all(exponent == 0):
            return [broadcast_like(1, node.outputs[0], node.env)]
        if N.all(exponent == 1):
            return [broadcast_like(node.inputs[0], node.outputs[0], node.env)]
    else:
        return False
register_canonicalize(local_pow_canonicalize)
......@@ -1279,25 +1291,33 @@ def local_pow_specialize(node):
#here, we are past the point of canonicalization, so we don't want to put in un-necessary fills.
if node.op == T.pow:
#the idea here is that we have pow(x, y)
odtype = node.outputs[0].dtype
xsym = node.inputs[0]
ysym = node.inputs[1]
y = local_mul_canonizer.get_constant(ysym)
if (y is not None) \
and encompasses_broadcastable(xsym.type.broadcastable, ysym.type.broadcastable):
if N.all(y == 2.0):
return [T.sqr(xsym)]
if N.all(y == 1.0):
return [xsym]
if N.all(y == 0.0):
return [T.fill(xsym, 1.0)]
rval = None
if N.all(y == 2):
rval = [T.sqr(xsym)]
if N.all(y == 1):
rval = [xsym]
if N.all(y == 0):
rval = [T.fill(xsym, numpy.asarray(1, dtype=odtype))]
if N.all(y == 0.5):
return [T.sqrt(xsym)]
rval = [T.sqrt(xsym)]
if N.all(y == -0.5):
return [T.inv(T.sqrt(xsym))]
if N.all(y == -1.0):
return [T.inv(xsym)]
if N.all(y == -2.0):
return [T.inv(T.sqr(xsym))]
rval = [T.inv(T.sqrt(xsym))]
if N.all(y == -1):
rval = [T.inv(xsym)]
if N.all(y == -2):
rval = [T.inv(T.sqr(xsym))]
if rval:
rval[0] = T.cast(rval[0], odtype)
assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
return rval
else:
return False
register_specialize(local_pow_specialize)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论