Skip to content

Commit

Permalink
[J2KT] Fix collection method mismatch between Java and Kotlin for met…
Browse files Browse the repository at this point in the history
…hods with `Collection<*>` parameter, including: `containsAll()`.

The `FixJavaKotlinCollectionMethodsMismatch` pass will convert selected method parameters to `readonly` versions of `Collection` and `Map` and insert casts if necessary.

PiperOrigin-RevId: 700085870
  • Loading branch information
Googler authored and copybara-github committed Nov 25, 2024
1 parent 90acf22 commit 18cd8cf
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ public class TypeDescriptors {
public DeclaredTypeDescriptor javaUtilList;
public DeclaredTypeDescriptor javaUtilObjects;

@Nullable
@QualifiedBinaryName("java.util.ReadonlyCollection")
public DeclaredTypeDescriptor javaUtilReadonlyCollection;

@Nullable
@QualifiedBinaryName("java.util.ReadonlyMap")
public DeclaredTypeDescriptor javaUtilReadonlyMap;

public DeclaredTypeDescriptor javaUtilFunctionSupplier;

@QualifiedBinaryName("java.util.function.BooleanSupplier")
Expand Down Expand Up @@ -712,4 +720,3 @@ private static String getClassBinaryName(Field field) {
String value();
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
*/
package com.google.j2cl.transpiler.passes;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.j2cl.transpiler.ast.PrimitiveTypes.INT;
import static com.google.j2cl.transpiler.ast.TypeVariable.createWildcardWithUpperBound;
import static com.google.j2cl.transpiler.passes.FixJavaKotlinCollectionMethodsMismatch.MethodMapping.methodMapping;
import static com.google.j2cl.transpiler.passes.FixJavaKotlinCollectionMethodsMismatch.ParameterMapping.parameterMapping;

import com.google.auto.value.AutoValue;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -45,43 +49,71 @@
/**
* Rewrites collection methods and method calls where Java method signature is different from
* Kotlin.
*
* <p>TODO(b/364506629): At this moment it handles methods which differ on `Object` parameter, like:
* {@code contains(Object)} -> {@code contains(T)}.
*/
public class FixJavaKotlinCollectionMethodsMismatch extends NormalizationPass {

private final DeclaredTypeDescriptor object = TypeDescriptors.get().javaLangObject;
private final DeclaredTypeDescriptor collection = TypeDescriptors.get().javaUtilCollection;
private final DeclaredTypeDescriptor map = TypeDescriptors.get().javaUtilMap;
private final DeclaredTypeDescriptor list = TypeDescriptors.get().javaUtilList;
private final TypeDescriptors types = TypeDescriptors.get();

private final DeclaredTypeDescriptor object = types.javaLangObject;
private final DeclaredTypeDescriptor collection = types.javaUtilCollection;
private final DeclaredTypeDescriptor map = types.javaUtilMap;
private final DeclaredTypeDescriptor list = types.javaUtilList;

private final DeclaredTypeDescriptor readonlyCollection =
checkNotNull(types.javaUtilReadonlyCollection);
private final DeclaredTypeDescriptor readonlyMap = checkNotNull(types.javaUtilReadonlyMap);

private final TypeVariable collectionElement = typeParameter(collection, 0);
private final TypeVariable listElement = typeParameter(list, 0);
private final TypeVariable mapKey = typeParameter(map, 0);
private final TypeVariable mapValue = typeParameter(map, 1);

private final TypeDescriptor readonlyCollectionOfElements =
readonlyCollection.withTypeArguments(ImmutableList.of(collectionElement)).toNonNullable();

private final TypeDescriptor readonlyMapOfWildcardKeysAndValues =
readonlyMap
.withTypeArguments(ImmutableList.of(createWildcardWithUpperBound(mapKey), mapValue))
.toNonNullable();

private final ImmutableList<MethodMapping> methodMappings =
ImmutableList.of(
methodMapping(
collection.getMethodDescriptor("contains", object),
collection.getTypeDeclaration().getTypeParameterDescriptors().get(0)),
methodMapping(
collection.getMethodDescriptor("remove", object),
collection.getTypeDeclaration().getTypeParameterDescriptors().get(0)),
collection, "addAll", parameterMapping(collection, readonlyCollectionOfElements)),
methodMapping(collection, "contains", parameterMapping(object, collectionElement)),
methodMapping(collection, "remove", parameterMapping(object, collectionElement)),
methodMapping(
map.getMethodDescriptor("containsKey", object),
map.getTypeDeclaration().getTypeParameterDescriptors().get(0)),
collection,
"containsAll",
parameterMapping(collection, readonlyCollectionOfElements)),
methodMapping(
map.getMethodDescriptor("containsValue", object),
map.getTypeDeclaration().getTypeParameterDescriptors().get(1)),
collection, "removeAll", parameterMapping(collection, readonlyCollectionOfElements)),
methodMapping(
map.getMethodDescriptor("get", object),
map.getTypeDeclaration().getTypeParameterDescriptors().get(0)),
collection, "retainAll", parameterMapping(collection, readonlyCollectionOfElements)),
methodMapping(map, "containsKey", parameterMapping(object, mapKey)),
methodMapping(map, "containsValue", parameterMapping(object, mapValue)),
methodMapping(map, "get", parameterMapping(object, mapKey)),
// TODO(b/364506629): This one needs return type conversion from V to V?
// methodMapping(
// map,
// "getOrDefault",
// parameterMapping(object, mapKey),
// parameterMapping(mapValue, mapValue)),
methodMapping(map, "putAll", parameterMapping(map, readonlyMapOfWildcardKeysAndValues)),
methodMapping(map, "remove", parameterMapping(object, mapKey)),
methodMapping(
map.getMethodDescriptor("remove", object),
map.getTypeDeclaration().getTypeParameterDescriptors().get(0)),
map, "remove", parameterMapping(object, mapKey), parameterMapping(object, mapValue)),
methodMapping(
list.getMethodDescriptor("indexOf", object),
list.getTypeDeclaration().getTypeParameterDescriptors().get(0)),
methodMapping(
list.getMethodDescriptor("lastIndexOf", object),
list.getTypeDeclaration().getTypeParameterDescriptors().get(0)));
list,
"addAll",
parameterMapping(INT, INT),
parameterMapping(collection, readonlyCollectionOfElements)),
methodMapping(list, "indexOf", parameterMapping(object, listElement)),
methodMapping(list, "lastIndexOf", parameterMapping(object, listElement)));

private static TypeVariable typeParameter(DeclaredTypeDescriptor typeDescriptor, int index) {
return typeDescriptor.getTypeDeclaration().getTypeParameterDescriptors().get(index);
}

@Override
public void applyTo(CompilationUnit compilationUnit) {
Expand All @@ -98,7 +130,8 @@ public Node rewriteMethod(Method method) {

@Override
public Node rewriteMethodCall(MethodCall methodCall) {
MethodMapping methodMapping = findMethodMapping(methodCall.getTarget());
MethodMapping methodMapping =
findMethodMapping(methodCall.getTarget().getDeclarationDescriptor());
if (methodMapping == null) {
return methodCall;
}
Expand All @@ -107,33 +140,62 @@ public Node rewriteMethodCall(MethodCall methodCall) {
});
}

/** A mapping from Java parameter signature to Kotlin parameter type. */
@AutoValue
abstract static class ParameterMapping {
static ParameterMapping parameterMapping(
TypeDescriptor javaRawTypeDescriptor, TypeDescriptor kotlinTypeDescriptor) {
return new AutoValue_FixJavaKotlinCollectionMethodsMismatch_ParameterMapping(
javaRawTypeDescriptor, kotlinTypeDescriptor);
}

abstract TypeDescriptor getJavaSignatureTypeDescriptor();

abstract TypeDescriptor getKotlinTypeDescriptor();
}

/** A mapping from Java method to Kotlin method. */
@AutoValue
abstract static class MethodMapping {
static MethodMapping methodMapping(
MethodDescriptor javaMethodDescriptor, TypeDescriptor... kotlinParameterTypeDescriptors) {
checkArgument(
javaMethodDescriptor.getParameterTypeDescriptors().size()
== kotlinParameterTypeDescriptors.length);
DeclaredTypeDescriptor enclosingTypeDescriptor,
String methodName,
ParameterMapping... parameterTypeMappings) {
ImmutableList<ParameterMapping> parameterTypeMappingsList =
ImmutableList.copyOf(parameterTypeMappings);
ImmutableList<TypeDescriptor> parameterSignatureTypeDescriptors =
parameterTypeMappingsList.stream()
.map(ParameterMapping::getJavaSignatureTypeDescriptor)
.collect(toImmutableList());
MethodDescriptor javaMethodDescriptor =
enclosingTypeDescriptor.getMethodDescriptor(
methodName, parameterSignatureTypeDescriptors.toArray(new TypeDescriptor[] {}));
ImmutableList<TypeDescriptor> kotlinParameterTypeDescriptors =
parameterTypeMappingsList.stream()
.map(ParameterMapping::getKotlinTypeDescriptor)
.collect(toImmutableList());
MethodDescriptor kotlinMethodDescriptor =
MethodDescriptor.Builder.from(javaMethodDescriptor)
.updateParameterTypeDescriptors(kotlinParameterTypeDescriptors)
.build();
return new AutoValue_FixJavaKotlinCollectionMethodsMismatch_MethodMapping(
javaMethodDescriptor, ImmutableList.copyOf(kotlinParameterTypeDescriptors));
javaMethodDescriptor, kotlinMethodDescriptor);
}

/** The original Java method descriptor. */
abstract MethodDescriptor getJavaMethodDescriptor();

/** Kotlin parameter type descriptors. */
abstract ImmutableList<TypeDescriptor> getKotlinParameterTypeDescriptors();
abstract MethodDescriptor getKotlinMethodDescriptor();

final String getBridgeMethodName() {
return "java_" + getJavaMethodDescriptor().getName();
}

final boolean isOverride(MethodDescriptor methodDescriptor) {
return methodDescriptor
.getEnclosingTypeDescriptor()
.isSubtypeOf(getJavaMethodDescriptor().getEnclosingTypeDescriptor())
&& methodDescriptor.isOverride(getJavaMethodDescriptor());
final boolean isOrOverrides(MethodDescriptor methodDescriptor) {
return methodDescriptor.equals(getJavaMethodDescriptor())
|| (methodDescriptor
.getEnclosingTypeDescriptor()
.isSubtypeOf(getJavaMethodDescriptor().getEnclosingTypeDescriptor())
&& methodDescriptor.isOverride(getJavaMethodDescriptor()));
}

/**
Expand All @@ -150,23 +212,41 @@ final Method fixMethodParameters(Method method) {
new ArrayList<>(methodDescriptor.getParameterTypeDescriptors());
Block body = method.getBody();
for (int i = 0; i < getJavaMethodDescriptor().getParameterTypeDescriptors().size(); i++) {
TypeDescriptor javaParameterTypeDescriptor =
TypeDescriptor javaDeclarationParameterTypeDescriptor =
getJavaMethodDescriptor().getParameterTypeDescriptors().get(i);
TypeDescriptor kotlinParameterTypeDescriptor = getKotlinParameterTypeDescriptors().get(i);
if (!javaParameterTypeDescriptor.equals(kotlinParameterTypeDescriptor)) {
TypeDescriptor parameterTypeDescriptor =
kotlinParameterTypeDescriptor.specializeTypeVariables(parameterization);
parameterTypeDescriptors.set(i, parameterTypeDescriptor);
TypeDescriptor kotlinDeclarationParameterTypeDescriptor =
getKotlinMethodDescriptor().getParameterTypeDescriptors().get(i);
if (!javaDeclarationParameterTypeDescriptor.equals(
kotlinDeclarationParameterTypeDescriptor)) {
TypeDescriptor kotlinParameterTypeDescriptor =
kotlinDeclarationParameterTypeDescriptor.specializeTypeVariables(parameterization);
parameterTypeDescriptors.set(i, kotlinParameterTypeDescriptor);

// Create a new parameter with the type expected by kotlin overrides and move the old
// parameter variable to local declaration to avoid rewriting all uses.
Variable parameter = method.getParameters().get(i);
Variable newParameter =
Variable.Builder.from(parameter).setTypeDescriptor(parameterTypeDescriptor).build();
method.getParameters().set(i, newParameter);
TypeDescriptor parameterTypeDescriptor = parameter.getTypeDescriptor();

Variable kotlinParameter =
Variable.Builder.from(parameter)
.setTypeDescriptor(kotlinParameterTypeDescriptor)
.build();
method.getParameters().set(i, kotlinParameter);

Expression initializer = kotlinParameter.createReference();

// Insert cast from read-only to mutable if necessary.
if (!kotlinParameterTypeDescriptor.isAssignableTo(parameterTypeDescriptor)) {
initializer =
CastExpression.newBuilder()
.setExpression(initializer)
.setCastTypeDescriptor(parameterTypeDescriptor)
.build();
}

Statement declarationStatement =
VariableDeclarationExpression.newBuilder()
.addVariableDeclaration(parameter, newParameter.createReference())
.addVariableDeclaration(parameter, initializer)
.build()
.makeStatement(parameter.getSourcePosition());
body =
Expand Down Expand Up @@ -195,21 +275,29 @@ final Method fixMethodParameters(Method method) {
* bridge extension functions, and explicit casts are inserted for arguments to super calls.
*/
MethodCall bridgeOrInsertCastIfNeeded(MethodCall methodCall) {
ImmutableList<TypeDescriptor> javaParameterTypeDescriptors =
getJavaMethodDescriptor().getParameterTypeDescriptors();
ImmutableList<TypeDescriptor> kotlinParameterTypeDescriptors =
getKotlinMethodDescriptor().getParameterTypeDescriptors();
MethodDescriptor methodDescriptor = methodCall.getTarget();
Expression qualifier = methodCall.getQualifier();

Map<TypeVariable, TypeDescriptor> parameterization =
methodDescriptor.getEnclosingTypeDescriptor().getParameterization();

if (qualifier instanceof SuperReference) {
if (javaParameterTypeDescriptors.equals(kotlinParameterTypeDescriptors)) {
return methodCall;
}

List<Expression> arguments = methodCall.getArguments();
for (int i = 0; i < arguments.size(); i++) {
arguments.set(
i,
insertCastIfNeeded(
arguments.get(i),
getJavaMethodDescriptor().getParameterTypeDescriptors().get(i),
getKotlinParameterTypeDescriptors().get(i),
javaParameterTypeDescriptors.get(i),
kotlinParameterTypeDescriptors.get(i),
parameterization));
}
return methodCall;
Expand All @@ -225,6 +313,7 @@ static Expression insertCastIfNeeded(
TypeDescriptor javaParameterTypeDescriptor,
TypeDescriptor kotlinParameterTypeDescriptor,
Map<TypeVariable, TypeDescriptor> parameterization) {
// TODO(b/380235439): Use generic-aware isAssignableTo() when it's implemented.
if (!javaParameterTypeDescriptor.equals(kotlinParameterTypeDescriptor)) {
return CastExpression.newBuilder()
.setExpression(argument)
Expand All @@ -239,7 +328,7 @@ static Expression insertCastIfNeeded(
@Nullable
MethodMapping findMethodMapping(MethodDescriptor methodDescriptor) {
return methodMappings.stream()
.filter(it -> it.isOverride(methodDescriptor))
.filter(it -> it.isOrOverrides(methodDescriptor))
.findFirst()
.orElse(null);
}
Expand Down
Loading

0 comments on commit 18cd8cf

Please sign in to comment.