diff --git a/src/StaticLint.jl b/src/StaticLint.jl index f19d39ea..69fa0594 100644 --- a/src/StaticLint.jl +++ b/src/StaticLint.jl @@ -70,6 +70,13 @@ function (state::Toplevel)(x::EXPR) end state.scope != s0 && (state.scope = s0) + + if state.file == state.targetfile && hasscope(x) && scopeof(x) !== state.scope && typof(x) !== CSTParser.ModuleH && typof(x) !== CSTParser.BareModule && typof(x) !== CSTParser.FileH && !CSTParser.defines_datatype(x) + for (n,b) in scopeof(x).names + infer_type_by_use(b, state.server) + end + end + return state.scope end @@ -88,6 +95,7 @@ function (state::Delayed)(x::EXPR) traverse(x, state) + # needs to call to add infer_type_by_use state.scope != s0 && (state.scope = s0) return state.scope end diff --git a/src/imports.jl b/src/imports.jl index 67dc134b..36f60f8b 100644 --- a/src/imports.jl +++ b/src/imports.jl @@ -50,7 +50,15 @@ function resolve_import(x, state::State) end end -function _mark_import_arg(arg, par, state, u) +function add_to_imported_modules(scope::Scope, name::Symbol, val) + if scope.modules isa Dict + scope.modules[name] = val + else + modules = Dict(name => val) + end +end + +function _mark_import_arg(arg, par, state, usinged) if par !== nothing && (typof(arg) === IDENTIFIER || typof(arg) === MacroName) if par isa Binding # mark reference to binding push!(par.refs, arg) @@ -65,29 +73,15 @@ function _mark_import_arg(arg, par, state, u) end arg.meta.binding = Binding(arg, par, _typeof(par, state), [], nothing, nothing) end - if u && par isa SymbolServer.ModuleStore - if state.scope.modules isa Dict - state.scope.modules[Symbol(valof(arg))] = par - else - state.scope.modules = Dict(Symbol(valof(arg)) => par) - end - elseif u && par isa Binding && par.val isa SymbolServer.ModuleStore - if state.scope.modules isa Dict - state.scope.modules[Symbol(valof(arg))] = par.val - else - state.scope.modules = Dict(Symbol(valof(arg)) => par.val) - end - elseif u && par isa Binding && par.val isa EXPR && (typof(par.val) === CSTParser.ModuleH || typof(par.val) === CSTParser.BareModule) - if state.scope.modules isa Dict - state.scope.modules[Symbol(valof(arg))] = scopeof(par.val) - else - state.scope.modules = Dict(Symbol(valof(arg)) => scopeof(par.val)) - end - elseif u && par isa Binding && par.val isa Binding && par.val.val isa EXPR && (typof(par.val.val) === CSTParser.ModuleH || typof(par.val.val) === CSTParser.BareModule) - if state.scope.modules isa Dict - state.scope.modules[Symbol(valof(arg))] = scopeof(par.val.val) - else - state.scope.modules = Dict(Symbol(valof(arg)) => scopeof(par.val.val)) + if usinged + if par isa SymbolServer.ModuleStore + add_to_imported_modules(state.scope, Symbol(valof(arg)), par) + elseif par isa Binding && par.val isa SymbolServer.ModuleStore + add_to_imported_modules(state.scope, Symbol(valof(arg)), par.val) + elseif par isa Binding && par.val isa EXPR && (typof(par.val) === CSTParser.ModuleH || typof(par.val) === CSTParser.BareModule) + add_to_imported_modules(state.scope, Symbol(valof(arg)), scopeof(par.val)) + elseif par isa Binding && par.val isa Binding && par.val.val isa EXPR && (typof(par.val.val) === CSTParser.ModuleH || typof(par.val.val) === CSTParser.BareModule) + add_to_imported_modules(state.scope, Symbol(valof(arg)), scopeof(par.val.val)) end end end diff --git a/src/server.jl b/src/server.jl index 46d38207..3cba9c36 100644 --- a/src/server.jl +++ b/src/server.jl @@ -14,8 +14,9 @@ mutable struct FileServer <: AbstractServer roots::Set{File} symbolserver::SymbolServer.EnvStore symbol_extends::Dict{SymbolServer.VarRef, Vector{SymbolServer.VarRef}} + symbol_fieldtypemap::Dict{Symbol, Vector{SymbolServer.VarRef}} end -FileServer() = FileServer(Dict{String,File}(), Set{File}(), deepcopy(SymbolServer.stdlibs), SymbolServer.collect_extended_methods(SymbolServer.stdlibs)) +FileServer() = FileServer(Dict{String,File}(), Set{File}(), deepcopy(SymbolServer.stdlibs), SymbolServer.collect_extended_methods(SymbolServer.stdlibs), fieldname_type_map(SymbolServer.stdlibs)) # Interface spec. # AbstractServer :-> (has/canload/load/set/get)file, getsymbolserver, getsymbolextends @@ -37,6 +38,7 @@ function loadfile(server::FileServer, path::String) end getsymbolserver(server::FileServer) = server.symbolserver getsymbolextendeds(server::FileServer) = server.symbol_extends +getsymbolfieldtypemap(server::FileServer) = server.symbol_fieldtypemap function scopepass(file, target = nothing) server = file.server diff --git a/src/type_inf.jl b/src/type_inf.jl index df089a06..81e78545 100644 --- a/src/type_inf.jl +++ b/src/type_inf.jl @@ -56,6 +56,343 @@ function infer_type(binding::Binding, scope, state) binding.type = refof(t) end end + elseif binding.val isa EXPR && parentof(binding.val) isa EXPR && typof(parentof(binding.val)) === CSTParser.WhereOpCall + binding.type = CoreTypes.DataType end end end + +""" + is_getfield_lhs(x::EXPR) +x the `a` in `a.b` +""" +is_getfield_lhs(x::EXPR) = is_getfield(parentof(x)) && x === parentof(x)[1] + +""" + is_getfield_lhs_as_chain(x::EXPR) +x the `b` in `a.b.c` +""" +is_getfield_lhs_as_chain(x::EXPR) = parentof(x) isa EXPR && typof(parentof(x)) === CSTParser.Quotenode && StaticLint.is_getfield(parentof(parentof(x))) && StaticLint.is_getfield(parentof(parentof(parentof(x)))) && x === parentof(parentof(x))[3][1] + +isemptyvect(x::EXPR) = typof(x) === CSTParser.Vect && length(x) == 2 + +function get_struct_fieldname(x::EXPR) + if _binary_assert(x, CSTParser.Tokens.DECLARATION) + return get_struct_fieldname(x[1]) + elseif typof(x) === CSTParser.InvisBrackets && length(x) == 3 + return get_struct_fieldname(x[2]) + elseif isidentifier(x) + return x + else + end + return nothing +end + +function cst_struct_fieldnames(x::EXPR) + fns = Symbol[] + if CSTParser.defines_mutable(x) + body = x[4] + elseif CSTParser.defines_struct(x) + body = x[3] + else + return fns + end + for arg in body + field_name = get_struct_fieldname(arg) + if field_name isa EXPR && isidentifier(field_name) + push!(fns, Symbol(CSTParser.str_value(field_name))) + end + end + return fns +end + + +""" + fieldname_type_map(s::Union{Scope,ModuleStore,EnvStore}, server, l = Dict()) + +Returns a Dict where a fieldname (key) points to a collection of types that +have that field. +""" +fieldname_type_map(s, server, l = Dict{Symbol,Any}()) = l # fallback +function fieldname_type_map(s::Scope, server, l = Dict()) + for (n,b) in s.names + b = get_root_method(b, server) + # Todo: Allow for const rebindings of datatypes (i.e. `const dt = DataType`) + if b isa Binding && b.val isa EXPR + if CSTParser.defines_datatype(b.val) + for f in cst_struct_fieldnames(b.val) + f = Symbol(f) + if haskey(l, f) + push!(l[f], b) + else + l[f] = [b] + end + end + elseif CSTParser.defines_function(b.val) && n == "get_property" + # need to check this overwrites Base.get_property + # need to iterate over all methods + sig = CSTParser.get_sig(b.val) + if length(sig) > 5 && _binary_assert(sig[3], CSTParser.Tokens.DECLARATION) && hasref(sig[3][3]) + t_binding = refof(sig[3][3]) + if t_binding isa Binding + if t_binding.type !== CoreTypes.DataType + t_binding = get_root_method(t_binding, server) + t_binding.type !== CoreTypes.DataType && continue + end + for f in get_property_shadow_fields(b.val) + f = Symbol(f) + if haskey(l, f) + push!(l[f], t_binding) + else + l[f] = [t_binding] + end + end + end + end + end + end + end + return l +end + +""" + get_property_shadow_fields(func) + +Assumes `func` is the definition of a function for `get_property`. Searches for +comparisons within the body between the second argument of the function and +symbols, returning a list of these symbols. + +e.g. +``` +function get_property(x::SomeType, f::Symbol) + if f === :asdf + elseif f == :sdgs + end +end +``` + +-> [:asdf, :sdgs] +""" +function get_property_shadow_fields(func) + # Get the argname for 2nd argument of get_property + str_sname = CSTParser.str_value(CSTParser.rem_decl(CSTParser.rem_where_decl(CSTParser.get_sig(func))[5])) + str_sname isa String || return [] + function trav(x, out = []) + if (_binary_assert(x, CSTParser.Tokens.EQEQEQ) || _binary_assert(x, CSTParser.Tokens.EQEQ)) && CSTParser.valof(x[1]) == str_sname && + CSTParser.typof(x[3]) === CSTParser.Quotenode && length(x[3]) ==2 && CSTParser.is_colon(x[3][1]) && CSTParser.isidentifier(x[3][2]) + push!(out, Expr(x[3][2])) + end + for a in x + trav(a, out) + end + out + end + trav(func) +end + +function fieldname_type_map(cache::SymbolServer.ModuleStore, l = Dict{Symbol,Any}()) + for (n,v) in cache.vals + if v isa SymbolServer.DataTypeStore + for f in v.fieldnames + if haskey(l, f) + push!(l[f], v.name.name) + else + l[f] = [v.name.name] + end + end + elseif v isa SymbolServer.ModuleStore + fieldname_type_map(v, l) + end + end + return l +end + +function fieldname_type_map(cache::SymbolServer.EnvStore, l = Dict{Symbol,Any}()) + for (_,m) in cache + fieldname_type_map(m, l) + end + return l +end + +""" + check_ref_against_fieldnames(ref, user_datatypes, new_possibles, server) + +Tries to infer the type of `ref` by looking at how getfield is used against it +and comparing these instances against the fields of all known datatypes. These +are pre-cached for packages in the server's EnvStore (`getsymbolfieldtypemap(server)`). +""" +function check_ref_against_fieldnames(ref, user_datatypes, new_possibles, server) + if is_getfield_lhs(ref) && typof(parentof(ref)[3]) === CSTParser.Quotenode + rhs = parentof(ref)[3][1] + elseif is_getfield_lhs_as_chain(ref) + rhs = parentof(parentof(parentof(ref)))[3][1] + else + return + end + if isidentifier(rhs) + rhs_sym = Symbol(CSTParser.str_value(rhs)) + for t in get(getsymbolfieldtypemap(server), rhs_sym, []) + push!(new_possibles, t) + end + for t in get(user_datatypes, rhs_sym, []) + push!(new_possibles, t) + end + end +end + +""" + is_arg_of_resolved_call(x) + +Checks whether x is the argument of a function call. +""" +is_arg_of_resolved_call(x::EXPR) = parentof(x) isa EXPR && typof(parentof(x)) === Call && parentof(x)[1] !== x && +(hasref(parentof(x)[1]) || (is_getfield(parentof(x)[1]) && typof(parentof(x)[1][3]) === CSTParser.Quotenode && hasref(parentof(x)[1][3][1]))) + + +""" + get_arg_position_in_call(call, arg) + get_arg_position_in_call(arg) + +Returns the position of `arg` in `call` ignoring the function name and punctuation. +The single argument method assumes `parentof(arg) == call` +""" +function get_arg_position_in_call(call::EXPR, arg) + for (i,a) in enumerate(call) + a == arg && return div(i-1, 2) + end +end + +function get_arg_position_in_call(arg) + get_arg_position_in_call(parentof(arg), arg) +end + + +""" + get_arg_type_at_position(f, argi, types) + +Pushes to `types` the argument type (if not `Core.Any`) of a function +at position `argi`. +""" +function get_arg_type_at_position(f, argi, types) end + +function get_arg_type_at_position(b::Binding, argi, types) + argi1 = argi*2 + 1 + if b.val isa EXPR + sig = CSTParser.get_sig(b.val) + if sig !== nothing && + argi1 < length(sig) && + hasbinding(sig[argi1]) && + (argb = bindingof(sig[argi1]); argb isa Binding && argb.type !== nothing) && + !(argb.type in types) + push!(types, argb.type) + return + end + elseif b.val isa SymbolServer.SymStore + return get_arg_type_at_position(b.val, argi, types) + end + return +end + +function get_arg_type_at_position(f::T, argi, types) where T <: Union{SymbolServer.DataTypeStore,SymbolServer.FunctionStore} + for m in f.methods + get_arg_type_at_position(m, argi, types) + end +end + +function get_arg_type_at_position(m::SymbolServer.MethodStore, argi, types) + if length(m.sig) >= argi && m.sig[argi][2] != SymbolServer.VarRef(SymbolServer.VarRef(nothing, :Core), :Any) && !(m.sig[argi][2] in types) + push!(types, m.sig[argi][2]) + end +end + +""" + check_ref_against_calls(x, visitedmethods, new_possibles, server) + +Pushes to `new_possibles` +""" +function check_ref_against_calls(x, visitedmethods, new_possibles, server) + if is_arg_of_resolved_call(x) + # x is argument of function call (func) and we know what that function is + if CSTParser.isidentifier(parentof(x)[1]) + func = refof(parentof(x)[1]) + else + func = refof(parentof(x)[1][3][1]) + end + # make sure we've got the last binding for func + if func isa Binding + func = get_last_method(func, server) + end + # what slot does ref sit in? + argi = get_arg_position_in_call(x) + tls = retrieve_toplevel_scope(x) + while (func isa Binding && func.type == CoreTypes.Function) || func isa SymbolServer.SymStore + !(func in visitedmethods) ? push!(visitedmethods, func) : return # check whether we've been here before + if func isa Binding + get_arg_type_at_position(func, argi, new_possibles) + func = prev_method(func) + else + tls === nothing && return + iterate_over_ss_methods(func, tls, server, m->(get_arg_type_at_position(m, argi, new_possibles);false)) + return + end + end + end +end + +""" + infer_type_by_use(b::Binding, server) + +Tries to infer the type of Binding `b` by looking at how it is used. +""" +function infer_type_by_use(b::Binding, server) + b.type !== nothing && return # b already has a type + user_datatypes = fieldname_type_map(retrieve_toplevel_scope(b.val), server) + possibletypes = [] + visitedmethods = [] + for ref in b.refs + new_possibles = [] + ref isa EXPR || continue # skip non-EXPR (i.e. used for handling of globals) + check_ref_against_fieldnames(ref, user_datatypes, new_possibles, server) + check_ref_against_calls(ref, visitedmethods, new_possibles, server) + + if isempty(possibletypes) + possibletypes = new_possibles + elseif !isempty(new_possibles) + possibletypes = intersect(possibletypes, new_possibles) + if isempty(possibletypes) + return + end + end + end + # Only do something if we're left with a set of 1 at the end. + if length(possibletypes) == 1 + type = first(possibletypes) + if type isa Binding + b.type = type + elseif type isa SymbolServer.DataTypeStore + b.type = type + elseif type isa SymbolServer.VarRef + b.type = SymbolServer._lookup(type, getsymbolserver(server)) # could be nothing + elseif type isa SymbolServer.FakeTypeName && isempty(type.parameters) + b.type = SymbolServer._lookup(type.name, getsymbolserver(server)) # could be nothing + end + end +end + +""" + isrebinding(b::Binding) + +Does `b` simply rebind another binding? +""" +function isrebinding(b::Binding) + b.val isa EXPR && CSTParser.is_assignment(b.val) && + b.val[1] == b.name && CSTParser.isidentifier(b.val[3]) && + hasbinding(b.val[3]) +end + +""" + getrebound(b::Binding) + +Assumes `isrebinding(b) == true` and gets the source binding (recursively). +""" +getrebound(b::Binding) = isrebinding(bindingof(b.val[3])) ? getrebound(bindingof(b.val[3])) : bindingof(b.val[3]) diff --git a/src/utils.jl b/src/utils.jl index 5f509ec9..f68dd93c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -119,6 +119,18 @@ function get_root_method(b::Binding, server, b1 = nothing, visited_bindings = Bi end end +function get_last_method(b::Binding, server, visited_bindings = Binding[]) + if b.next === nothing || b == b.next || !(b.next isa Binding) || b in visited_bindings + return b + end + push!(visited_bindings, b) + if b.type == b.next.type == CoreTypes.Function + return get_last_method(b.next, server, visited_bindings) + else + return b + end +end + function retrieve_delayed_scope(x) if (CSTParser.defines_function(x) || CSTParser.defines_macro(x)) && scopeof(x) !== nothing if parentof(scopeof(x)) !== nothing @@ -274,7 +286,8 @@ isexportedby(k::String, m::SymbolServer.ModuleStore) = isexportedby(Symbol(k), m isexportedby(x::EXPR, m::SymbolServer.ModuleStore) = isexportedby(valof(x), m) isexportedby(k, m::SymbolServer.ModuleStore) = false -function retrieve_toplevel_scope(x) +function retrieve_toplevel_scope(x) end +function retrieve_toplevel_scope(x::EXPR) if scopeof(x) !== nothing && (typof(x) === CSTParser.ModuleH || typof(x) === CSTParser.BareModule || typof(x) === CSTParser.FileH) return scopeof(x) elseif parentof(x) isa EXPR diff --git a/test/runtests.jl b/test/runtests.jl index a367b7f9..0a3833f0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,7 +20,7 @@ function parse_and_pass(s) f = StaticLint.File("", s, CSTParser.parse(s, true), nothing, server) StaticLint.setroot(f, f) StaticLint.setfile(server, "", f) - StaticLint.scopepass(f) + StaticLint.scopepass(f, f) return f.cst end @@ -821,6 +821,37 @@ end StaticLint.check_const_redef(cst[2]) @test cst[2].meta.error == nothing end + +@testset "expr fieldnames" begin + let cst = parse_and_pass(""" + struct T + end + + struct T + a + end + + struct T + a + b + end + + struct T + a::S + b::S + end + + mutable struct T + a::S + b::S + end + """) + @test StaticLint.cst_struct_fieldnames(cst[1]) == [] + @test StaticLint.cst_struct_fieldnames(cst[2]) == [:a] + @test StaticLint.cst_struct_fieldnames(cst[3]) == [:a, :b] + @test StaticLint.cst_struct_fieldnames(cst[4]) == [:a, :b] + @test StaticLint.cst_struct_fieldnames(cst[5]) == [:a, :b] + end end @testset "hoisting of inner constructors" begin @@ -835,6 +866,7 @@ end @test bindingof(cst[1]) === bindingof(cst[1][3][3]).prev @test bindingof(cst[1][3][3]) === bindingof(cst[2]).prev end +include("type_inf.jl") end @testset "using of self" begin # e.g. `using StaticLint: StaticLint` diff --git a/test/type_inf.jl b/test/type_inf.jl new file mode 100644 index 00000000..2df345ee --- /dev/null +++ b/test/type_inf.jl @@ -0,0 +1,138 @@ +@testset "fieldname inference" begin +# arg1 is inferred as T -> only a single (user defined) +# datatype has the field `fieldname1` +let cst = parse_and_pass(""" + struct T + fieldname1 + end + function f(arg1) + arg1.fieldname1 + end + """) + @test cst[2].meta.scope.names["arg1"].type === cst.meta.scope.names["T"] +end + +# arg1 inferred as above +# arg2 as above but for `S` +# arg3 field use is conflicting -> no type assigned +let cst = parse_and_pass(""" + struct T + fieldname1 + end + struct S + fieldname2 + end + function f(arg1, arg2, arg3) + arg1.fieldname1 + arg2.fieldname2 + arg3.fieldname1 + arg3.fieldname2 + end + """) + @test cst[3].meta.scope.names["arg1"].type === cst.meta.scope.names["T"] + @test cst[3].meta.scope.names["arg2"].type === cst.meta.scope.names["S"] + @test cst[3].meta.scope.names["arg3"].type === nothing +end + +# arg1 type inferred as above +# arg2 type not inferred as `sig` is also the fieldname of +# `Method` exported by Core. +let cst = parse_and_pass(""" + struct T + fieldname1 + sig + end + function f(arg1, arg2) + arg1.fieldname1 + arg2.sig + end + """) + @test cst[2].meta.scope.names["arg1"].type === cst.meta.scope.names["T"] + @test cst[2].meta.scope.names["arg2"].type === nothing +end + +let cst = parse_and_pass(""" + struct T + fieldname1 + end + struct S + fieldname2 + end + function f(arg1) + if arg1 isa T + arg1.fieldname1 + elseif arg1 isa S + arg1.fieldname2 + end + end + """) + @test cst[3].meta.scope.names["arg1"].type === nothing +end +end + +@testset "inference by use as function argument" begin +# single method function with user defined datatype +let cst = parse_and_pass(""" + struct T end + function f(arg::T) end + function g(arg) end + let arg1 = unknownvalue, arg2 = unknownvalue + f(arg1) + g(arg1) + end + """) + @test cst[4].meta.scope.names["arg1"].type === cst.meta.scope.names["T"] + @test cst[4].meta.scope.names["arg2"].type === nothing +end + +# as above against imported (symbolserver) types +let cst = parse_and_pass(""" + function f(arg::Int) end + let arg = unknownvalue + f(arg) + end + """) + @test cst[2].meta.scope.names["arg"].type.name == SymbolServer.FakeTypeName(Int) +end + +# 2 methods, conflicting types so no inference +let cst = parse_and_pass(""" + function f(arg::Int) end + function f(arg::Float64) end + let arg = unknownvalue + f(arg) + end + """) + @test cst[3].meta.scope.names["arg"].type === nothing +end + +# 2 functions, 1 with two methods. +let cst = parse_and_pass(""" + function f(arg::Int) end + function f(arg::Float64) end + function g(arg::Int) end + let arg = unknownvalue + f(arg) + g(arg) + end + """) + @test cst[4].meta.scope.names["arg"].type.name == SymbolServer.FakeTypeName(Int) +end + +# SymServer function w/ single method +let cst = parse_and_pass(""" + let arg = unknownvalue + dirname(arg) + end + """) + @test cst[1].meta.scope.names["arg"].type.name == SymbolServer.FakeTypeName(AbstractString) +end +# As above but qualified name for function. +let cst = parse_and_pass(""" + let arg = unknownvalue + Base.dirname(arg) + end + """) + @test cst[1].meta.scope.names["arg"].type.name == SymbolServer.FakeTypeName(AbstractString) +end +end \ No newline at end of file