KernelAbstractions.jl

@index cannot be used on the CPU if it is not the direct rhs of an assignment

Open 101001000 opened this issue 1 year ago • 1 comments

# This works flawlessly
@kernel function f()
    a = @index(Global, Cartesian)
    @print(a[1])
end

# This doesn't compile on the CPU
@kernel function f()
    a = let
        @index(Global, Cartesian)
    end
    @print(a[1])
end

My guess is that the @kernel macro only inserts the idx parameter required for CPU indexing when the @index call is the direct right-hand side of an assignment. Maybe more general logic for locating the @index calls should be applied?
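To make the guess concrete, here is a minimal sketch in plain Julia (no MacroTools) of a "direct assignment" check. The helper name is_direct_index_assignment is hypothetical and is not part of KernelAbstractions; it only approximates the shape that the pattern in macros.jl appears to match.

```julia
# Hypothetical helper, not library code: true only when `stmt` has the
# shape `lhs = @index(...)`, i.e. the macro call is the whole rhs.
is_direct_index_assignment(stmt) =
    stmt isa Expr && stmt.head === :(=) &&
    stmt.args[2] isa Expr && stmt.args[2].head === :macrocall &&
    stmt.args[2].args[1] === Symbol("@index")

# Quoting and Meta.parse do not expand macros, so @index need not be defined.
direct  = :(a = @index(Global, Cartesian))
nested  = Meta.parse("a = let\n    @index(Global, Cartesian)\nend")
indexed = :(a = @index(Global, Cartesian)[1])

for (name, stmt) in (("direct", direct), ("nested", nested), ("indexed", indexed))
    println(name, " => ", is_direct_index_assignment(stmt))
end
```

Under this check only the first statement qualifies: the rhs of the second is a let block and the rhs of the third is an indexing (:ref) expression, with the @index call buried one level deeper in both.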

# macros.jl line 290:
elseif @capture(stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
    if @capture(rhs, @index(args__))
        push!(indicies, stmt)
        continue

# macros.jl line 242:
function emit(loop)
    idx = gensym(:I)
    for stmt in loop.indicies
        # splice index into the i = @index(Cartesian, $idx)
        @assert stmt.head === :(=)
        rhs = stmt.args[2]
        push!(rhs.args, idx)
    end
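As a rough illustration of what that splice does, the sketch below applies the same push! to a quoted statement outside of any kernel. It is a standalone example, not the library's code path, and the variable names are chosen only for illustration.

```julia
# A statement of the shape the CPU path collects: lhs = @index(...)
stmt = :(i = @index(Global, Cartesian))
idx = gensym(:I)              # fresh symbol standing in for the loop index

@assert stmt.head === :(=)
rhs = stmt.args[2]            # the :macrocall node for @index
push!(rhs.args, idx)          # idx becomes the last macro argument

println(stmt)
```

The printed statement shows the generated symbol as an extra trailing argument of @index. Note that the splice assumes stmt.args[2] is the macro call itself, which is why it cannot be applied unchanged when @index sits deeper inside the right-hand side.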

This also means there is some inconsistency in how indexing is handled on the CPU and on the GPU:

# This code is valid on the CPU
@kernel function f()
    a = @index(Global, Cartesian)
    @print(a[1])
end

# But this code is invalid on the CPU and valid on the GPU
@kernel function f()
    a = @index(Global, Cartesian)[1]
    @print(a)
end

101001000 Apr 08 '24 14:04

I had some success by modifying the macros.jl file. Instead of calling push!(rhs.args, idx), I wrote a custom function that appends the index to every "@index()" expression inside the statement, and I call that instead. I don't know the implications or whether it breaks the semantics, but it works in the kernels I tried. My only concern is that it still only works at the assignment-statement level, but at least I can index now.

This is my quick sketch:

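# Render an expression (or symbol) as source text.
function expr_to_string(expr)
    io = IOBuffer()
    Base.show_unquoted(io, expr, 0, -1)
    return String(take!(io))
end

# True if the first argument of `expr` prints as `str`; false for anything
# that has no usable `args[1]`.
function expr_identify_1(expr, str)
    try
        return expr_to_string(expr.args[1]) == str
    catch
        return false
    end
end

# Recursively append `idx` to every @index macro call found inside `stmt`.
function append_idx!(stmt, idx)
    if stmt isa Expr
        if expr_identify_1(stmt, """var\"@index\"""")
            push!(stmt.args, idx)
        end
        for arg in stmt.args
            append_idx!(arg, idx)
        end
    end
end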

Then, instead of push!(rhs.args, idx), I call append_idx!(stmt, idx).
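For comparison, the same traversal can probably be written without going through strings, by checking the expression head directly. This is only a sketch under that assumption: append_idx_structural! is a hypothetical name, it is not the library's implementation, and it has not been tested inside @kernel.

```julia
# Hypothetical variant of append_idx!: finds @index calls structurally
# (head === :macrocall) instead of comparing printed text.
function append_idx_structural!(ex, idx)
    ex isa Expr || return ex
    # Recurse into the children first, then patch this node.
    foreach(arg -> append_idx_structural!(arg, idx), ex.args)
    if ex.head === :macrocall && ex.args[1] === Symbol("@index")
        push!(ex.args, idx)
    end
    return ex
end

idx = gensym(:I)

# The two shapes from this issue that the direct-assignment splice cannot handle.
indexed = :(a = @index(Global, Cartesian)[1])
nested  = Meta.parse("a = let\n    @index(Global, Cartesian)\nend")

append_idx_structural!(indexed, idx)
append_idx_structural!(nested, idx)

println(indexed)
println(nested)
```

In both printed statements the generated symbol should appear as the last argument of the @index call. Whether every such call should receive the index, and whether statements other than assignments need the same treatment, is still open, as noted above.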

101001000 Apr 08 '24 15:04