diff --git a/python/akg/ms/gpu/cast.py b/python/akg/ms/gpu/cast.py index 0478a20f94a1d16709617f70aea313cab0daf735..71bd2a0b3f47e42dd7bd40b8af07f589240bb91b 100644 --- a/python/akg/ms/gpu/cast.py +++ b/python/akg/ms/gpu/cast.py @@ -22,4 +22,6 @@ def Cast(x, dst_type): """cast.""" if x.dtype == "int64" and dst_type == "float16": x = cast.cast(x, "float32") + if x.dtype == "float16" and dst_type == "int32": + x = topi.trunc(x) return cast.cast(x, dst_type)