Skip to content

Commit

Permalink
Dyno: Fixes to c_ptr representation (chapel-lang#24749)
Browse files Browse the repository at this point in the history
This PR contains various improvements to the representation of ``c_ptr``
and ``c_ptrConst`` in the frontend.

The most significant change is to stop treating ``c_ptr[Const]`` as a
builtin and instead rely on the ``c_ptr`` and ``c_ptrConst`` classes in
``CTypes``. This allows us to use a real ID as a point of reference to
find methods on ``c_ptr`` when ``CTypes`` is not in scope. Tests have
also been updated to now use ``CTypes``.

Another notable change is to allow ``eltType`` to be called on both
types and values of ``c_ptr``. This is achieved by updating
``getCompilerGeneratedMethodQuery`` to take a ``QualifiedType`` instead
of a ``Type*``, so that we can preserve the relevant ``Kind``.

Other changes:
- implement PRIM_ARRAY_GET for ``c_ptr``
- allow ``c_ptr`` in PRIM_STRING_LENGTH_BYTES

[reviewed-by @mppf]
  • Loading branch information
benharsh authored Apr 3, 2024
2 parents 63a9a86 + 71aa824 commit d0e3add
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 29 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: make test-dyno-with-asserts
# 'make modules' so the ChapelSysCTypes module is available
run: |
CHPL_HOME=$PWD make modules
CHPL_HOME=$PWD make DYNO_ENABLE_ASSERTIONS=1 test-dyno -j`util/buildRelease/chpl-make-cpu_count`
- name: run dyno linters
run: |
Expand Down
4 changes: 4 additions & 0 deletions frontend/include/chpl/types/CPtrType.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ class CPtrType final : public Type {
static const ID& getId(Context* context);
static const ID& getConstId(Context* context);

const ID& id(Context* context) const {
return isConst() ? getConstId(context) : getId(context);
}

const Type* eltType() const {
return eltType_;
}
Expand Down
35 changes: 25 additions & 10 deletions frontend/lib/resolution/default-functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,24 +685,29 @@ generateRecordComparison(Context* context, const CompositeType* lhsType) {
}

static const TypedFnSignature*
generateCPtrMethod(Context* context, const CPtrType * cpt, UniqueString name) {
generateCPtrMethod(Context* context, QualifiedType receiverType,
UniqueString name) {
// Build a basic function signature for methods on a cptr
// TODO: we should really have a way to just set the return type here
const CPtrType* cpt = receiverType.type()->toCPtrType();
const TypedFnSignature* result = nullptr;
std::vector<UntypedFnSignature::FormalDetail> formals;
std::vector<QualifiedType> formalTypes;

formals.push_back(UntypedFnSignature::FormalDetail(USTR("this"), false, nullptr));
formalTypes.push_back(QualifiedType(QualifiedType::CONST_REF, cpt));

// Allow calling 'eltType' on either a type or value
auto qual = receiverType.isType() ? QualifiedType::TYPE : QualifiedType::CONST_REF;
formalTypes.push_back(QualifiedType(qual, cpt));

auto ufs = UntypedFnSignature::get(context,
/*id*/ cpt->getId(context),
/*id*/ cpt->id(context),
/*name*/ name,
/*isMethod*/ true,
/*isTypeConstructor*/ false,
/*isCompilerGenerated*/ true,
/*throws*/ false,
/*idTag*/ parsing::idToTag(context, cpt->getId(context)),
/*idTag*/ asttags::Class,
/*kind*/ uast::Function::Kind::PROC,
/*formals*/ std::move(formals),
/*whereClause*/ nullptr);
Expand All @@ -719,9 +724,11 @@ generateCPtrMethod(Context* context, const CPtrType * cpt, UniqueString name) {
}

static const TypedFnSignature* const&
getCompilerGeneratedMethodQuery(Context* context, const Type* type,
getCompilerGeneratedMethodQuery(Context* context, QualifiedType receiverType,
UniqueString name, bool parenless) {
QUERY_BEGIN(getCompilerGeneratedMethodQuery, context, type, name, parenless);
QUERY_BEGIN(getCompilerGeneratedMethodQuery, context, receiverType, name, parenless);

const Type* type = receiverType.type();

const TypedFnSignature* result = nullptr;

Expand Down Expand Up @@ -749,8 +756,8 @@ getCompilerGeneratedMethodQuery(Context* context, const Type* type,
} else {
CHPL_UNIMPL("record method not implemented yet!");
}
} else if (auto cPtrType = type->toCPtrType()) {
result = generateCPtrMethod(context, cPtrType, name);
} else if (type->isCPtrType()) {
result = generateCPtrMethod(context, receiverType, name);
} else {
CHPL_UNIMPL("should not be reachable");
}
Expand Down Expand Up @@ -855,9 +862,17 @@ generateCastToEnum(Context* context,
If no method was generated, returns nullptr.
*/
const TypedFnSignature*
getCompilerGeneratedMethod(Context* context, const Type* type,
getCompilerGeneratedMethod(Context* context, const QualifiedType receiverType,
UniqueString name, bool parenless) {
return getCompilerGeneratedMethodQuery(context, type, name, parenless);
// Normalize recieverType to allow TYPE methods on c_ptr, and to otherwise
// use the VAR Kind. The Param* value is also stripped away to reduce
// queries.
auto qt = receiverType;
bool isCPtr = qt.hasTypePtr() ? qt.type()->isCPtrType() : false;
if (!(qt.isType() && isCPtr)) {
qt = QualifiedType(QualifiedType::VAR, qt.type());
}
return getCompilerGeneratedMethodQuery(context, qt, name, parenless);
}

static const TypedFnSignature* const&
Expand Down
3 changes: 2 additions & 1 deletion frontend/lib/resolution/default-functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ bool needCompilerGeneratedMethod(Context* context, const types::Type* type,
If no method was generated, returns nullptr.
*/
const TypedFnSignature*
getCompilerGeneratedMethod(Context* context, const types::Type* type,
getCompilerGeneratedMethod(Context* context,
const types::QualifiedType receiverType,
UniqueString name, bool parenless);

/**
Expand Down
20 changes: 18 additions & 2 deletions frontend/lib/resolution/prims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,8 @@ CallResolutionResult resolvePrimCall(Context* context,
break;
} else if (actualType.type()->isStringType() ||
actualType.type()->isBytesType() ||
actualType.type()->isCStringType()) {
actualType.type()->isCStringType() ||
actualType.type()->isCPtrType()) {
// for non-param string/bytes, the return type is just a default int
type = QualifiedType(QualifiedType::CONST_VAR,
IntType::get(context, 0));
Expand Down Expand Up @@ -1640,10 +1641,25 @@ CallResolutionResult resolvePrimCall(Context* context,
type = QualifiedType(QualifiedType::CONST_VAR,
CompositeType::getLocaleIDType(context));
break;

case PRIM_ARRAY_GET: {
if (ci.numActuals() == 2 && ci.actual(0).type().hasTypePtr()) {
auto index = ci.actual(1).type();
if (index.hasTypePtr() && index.type()->isIntegralType()) {
auto act = ci.actual(0).type();
if (auto ptr = act.type()->toCPtrType()) {
type = QualifiedType(QualifiedType::REF, ptr->eltType());
}
} else {
context->error(call, "bad call to primitive \"%s\": second argument must be an integral type", primTagToName(prim));
}
}
}
break;

case PRIM_USED_MODULES_LIST:
case PRIM_REFERENCED_MODULES_LIST:
case PRIM_TUPLE_EXPAND:
case PRIM_ARRAY_GET:
case PRIM_MAYBE_LOCAL_THIS:
case PRIM_MAYBE_LOCAL_ARR_ELEM:
case PRIM_MAYBE_AGGREGATE_ASSIGN:
Expand Down
17 changes: 13 additions & 4 deletions frontend/lib/resolution/resolution-queries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3036,6 +3036,12 @@ static const Type* getCPtrType(Context* context,
UniqueString name = ci.name();
bool isConst;

// 'typeForId' should have prepared this for us if 'CTypes' was in scope.
auto called = ci.calledType();
if (!(called.hasTypePtr() && called.type()->isCPtrType())) {
return nullptr;
}

if (name == USTR("c_ptr")) {
isConst = false;
} else if (name == USTR("c_ptrConst")) {
Expand Down Expand Up @@ -3403,11 +3409,10 @@ considerCompilerGeneratedMethods(Context* context,
// fetch the receiver type info
CHPL_ASSERT(ci.numActuals() >= 1);
auto& receiver = ci.actual(0);
// TODO: This should be the QualifiedType in case of type methods
auto receiverType = receiver.type().type();
auto receiverType = receiver.type();

// if not compiler-generated, then nothing to do
if (!needCompilerGeneratedMethod(context, receiverType, ci.name(),
if (!needCompilerGeneratedMethod(context, receiverType.type(), ci.name(),
ci.isParenless())) {
return nullptr;
}
Expand Down Expand Up @@ -3502,6 +3507,8 @@ lookupCalledExpr(Context* context,
if (auto compType = t->getCompositeType()) {
receiverScopes =
Resolver::gatherReceiverAndParentScopesForType(context, compType);
} else if (auto cptr = t->toCPtrType()) {
receiverScopes.push_back(scopeForId(context, cptr->id(context)));
}
}
}
Expand Down Expand Up @@ -4066,7 +4073,9 @@ findMostSpecificAndCheck(Context* context,
}

// note any most-specific candidates from POI in poiInfo.
{
// TODO: This can be the case for generated calls, but is skipping the POI
// accumulation safe?
if (call != nullptr) {
size_t n = candidates.size();
for (size_t i = firstPoiCandidate; i < n; i++) {
for (const MostSpecificCandidate& candidate : mostSpecific) {
Expand Down
9 changes: 9 additions & 0 deletions frontend/lib/resolution/scope-queries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,15 @@ struct GatherDecls {
// since dyno handles tuple types directly rather
// than through a record.
skip = true;
} else if (d->name() == "eltType" &&
atFieldLevel && tagParent == asttags::Class &&
(d->id().symbolPath().startsWith("CTypes.c_ptr") ||
d->id().symbolPath().startsWith("Ctypes.c_ptrConst"))) {
// skip gathering the 'eltType' field of the dummy c_ptr[Const] classes,
// since we're representing those types entirely within the frontend.
//
// TODO: Remove this once we have replaced those classes.
skip = true;
}

if (skip == false) {
Expand Down
3 changes: 2 additions & 1 deletion frontend/lib/types/CompositeType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ bool CompositeType::isMissingBundledClassType(Context* context, ID id) {
auto path = id.symbolPath();
return path == "ChapelReduce.ReduceScanOp" ||
path == "Errors.Error" ||
path == "CTypes.c_ptr";
path == "CTypes.c_ptr" ||
path == "CTypes.c_ptrConst";
}

return false;
Expand Down
4 changes: 0 additions & 4 deletions frontend/lib/types/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,6 @@ void Type::gatherBuiltins(Context* context,
auto genericUnmanaged = ClassType::get(context, AnyClassType::get(context), nullptr, ClassTypeDecorator(ClassTypeDecorator::UNMANAGED));
gatherType(context, map, "unmanaged", genericUnmanaged);

gatherType(context, map, "c_ptr", CPtrType::get(context));

gatherType(context, map, "c_ptrConst", CPtrType::getConst(context));

BuiltinType::gatherBuiltins(context, map);
}

Expand Down
59 changes: 54 additions & 5 deletions frontend/test/resolution/testCPtr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,31 @@

#include <sstream>

static Context* context;

// Use a single context with revisions to get this test running faster.
static void setupContext() {
std::string chpl_home;
if (const char* chpl_home_env = getenv("CHPL_HOME")) {
chpl_home = chpl_home_env;
} else {
printf("CHPL_HOME must be set");
exit(1);
}
Context::Configuration config;
config.chplHome = chpl_home;
context = new Context(config);
}

template <typename F>
void testCPtrArg(const char* formalType, const char* actualType, F&& test) {
Context ctx;
Context* context = &ctx;
context->advanceToNextRevision(false);
setupModuleSearchPaths(context, false, false, {}, {});
ErrorGuard guard(context);

std::stringstream ss;

ss << "use CTypes;" << std::endl;
ss << "record rec { type someType; }" << std::endl;
ss << "proc f(x: " << formalType << ") {}" << std::endl;
ss << "var arg: " << actualType << ";" << std::endl;
Expand All @@ -53,14 +70,14 @@ void testCPtrArg(const char* formalType, const char* actualType, F&& test) {

assert(modules.size() == 1);
auto mainMod = modules[0];
assert(mainMod->numStmts() == 4);
assert(mainMod->numStmts() == 5);

auto fChild = mainMod->child(1);
auto fChild = mainMod->child(2);
assert(fChild->isFunction());
auto fFn = fChild->toFunction();
assert(fFn->name() == "f");

auto fCallVar = mainMod->child(3);
auto fCallVar = mainMod->child(4);
assert(fCallVar->isVariable());

auto& modResResult = resolveModule(context, mainMod->id());
Expand Down Expand Up @@ -293,7 +310,35 @@ static void test22() {
});
}

static void test23() {
ErrorGuard guard(context);

std::string program = R"""(
module M{
module X {
proc foo() {
use CTypes;
var ret : c_ptr(int);
return ret;
}
}
use X;
var ptr = foo();
var x = ptr[0];
}
)""";

auto vars = resolveTypesOfVariables(context, program, {"ptr", "x"});
assert(vars["ptr"].type()->isCPtrType());
assert(vars["x"].type()->isIntType());
}

int main() {
setupContext();

test1();
test2();
test3();
Expand All @@ -317,5 +362,9 @@ int main() {
test21();
test22();

test23();

delete context;

return 0;
}
17 changes: 15 additions & 2 deletions frontend/test/resolution/testReturnTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1095,15 +1095,26 @@ static void testSelectParams() {
}

static void testCPtrEltType() {
std::string chpl_home;
if (const char* chpl_home_env = getenv("CHPL_HOME")) {
chpl_home = chpl_home_env;
} else {
printf("CHPL_HOME must be set");
exit(1);
}
Context::Configuration config;
config.chplHome = chpl_home;
{
//works for c_ptr
std::string program = ops + R"""(
use CTypes;
var y: c_ptr(uint(8));
type x = y.eltType;
)""";

Context ctx;
Context ctx(config);
Context* context = &ctx;
setupModuleSearchPaths(context, false, false, {}, {});
ErrorGuard guard(context);
auto qt = resolveTypeOfXInit(context, program);
assert(qt.type()->isUintType());
Expand All @@ -1113,15 +1124,17 @@ static void testCPtrEltType() {
{
//works for user-defined class
std::string program = ops + R"""(
use CTypes;
class c_ptr2 {
type eltType;
}
var y: c_ptr2(uint(8));
type x = y.eltType;
)""";

Context ctx;
Context ctx(config);
Context* context = &ctx;
setupModuleSearchPaths(context, false, false, {}, {});
ErrorGuard guard(context);
auto qt = resolveTypeOfXInit(context, program);
assert(qt.type()->isUintType());
Expand Down

0 comments on commit d0e3add

Please sign in to comment.