提交 68fb4f02 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

GpuAdvancedIncSubtensor1 supports mixed dtypes. The dev20 version doesn't but that's accessory.

上级 2f0ab791
......@@ -994,19 +994,12 @@ def local_gpua_advanced_incsubtensor(op, context_name, inputs, outputs):
x, y, ilist = inputs
# Gpu Ops needs both inputs to have the same dtype
if (x.type.dtype != y.type.dtype):
dtype = scalar.upcast(x.type.dtype, y.type.dtype)
if x.type.dtype != dtype:
x = tensor.cast(x, dtype)
if y.type.dtype != dtype:
y = tensor.cast(y, dtype)
set_instead_of_inc = op.set_instead_of_inc
compute_capability = int(context.bin_id[-2])
if (compute_capability < 2 or x.ndim != 2 or y.ndim != 2):
if (compute_capability < 2 or x.ndim != 2 or y.ndim != 2 or
x.type.dtype != y.type.dtype):
return GpuAdvancedIncSubtensor1(
set_instead_of_inc=set_instead_of_inc)
else:
......
......@@ -599,7 +599,6 @@ class GpuAdvancedIncSubtensor1(Op):
y_ = as_gpuarray_variable(y, ctx_name)
ilist_ = tensor.as_tensor_variable(ilist)
assert x_.type.dtype == y_.type.dtype
assert x_.type.ndim >= y_.type.ndim
if ilist_.type.dtype[:3] not in ('int', 'uin'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论