提交 8010cfcb authored 作者: Amjad Almahairi's avatar Amjad Almahairi

making optimization compatible with old version

上级 7a88ea6d
......@@ -377,7 +377,11 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
@local_optimizer([MultinomialFromUniform])
def local_gpu_multinomial(node):
if type(node.op) is MultinomialFromUniform:
p, u, n_samples = node.inputs
if len(node.inputs) == 2:
p, u = node.inputs
n_samples = 1
else:
p, u, n_samples = node.inputs
try:
if get_scalar_constant_value(n_samples) != 1:
return None
......@@ -395,7 +399,11 @@ def local_gpu_multinomial(node):
node.inputs[0].owner and
type(node.inputs[0].owner.op) is MultinomialFromUniform):
multi = node.inputs[0].owner
p, u, n_samples = multi.inputs
if len(node.inputs) == 2:
p, u = node.inputs
n_samples = 1
else:
p, u, n_samples = node.inputs
try:
if get_scalar_constant_value(n_samples) != 1:
return None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论