diff --git a/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/api/ObjectSerializationStrategy.java b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/api/ObjectSerializationStrategy.java index 2c5fec0ff..ea23fb256 100644 --- a/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/api/ObjectSerializationStrategy.java +++ b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/api/ObjectSerializationStrategy.java @@ -21,9 +21,15 @@ import java.io.IOException; import java.io.ObjectOutputStream; import java.text.MessageFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; import org.apache.aries.rsa.provider.fastbin.FastBinProvider; import org.apache.aries.rsa.provider.fastbin.util.ClassLoaderObjectInputStream; +import org.apache.aries.rsa.provider.fastbin.util.FilteredClassLoaderObjectInputStream; import org.fusesource.hawtbuf.DataByteArrayInputStream; import org.fusesource.hawtbuf.DataByteArrayOutputStream; import org.osgi.framework.ServiceException; @@ -38,6 +44,50 @@ public class ObjectSerializationStrategy implements SerializationStrategy { private static final ObjectSerializationStrategy V1 = INSTANCE; private int protocolVersion = FastBinProvider.PROTOCOL_VERSION; + private static final Set ALLOWEDCLASSES; + private static final FilteredClassLoaderObjectInputStream.AllowlistPackagesPredicate ALLOWED_PACKAGES; + private static final String ADDITIONAL_ALLOWED_PACKAGE = System.getProperty( "org.apache.aries.rsa.provider.fastbin.api.DESERIALIZATION_PACKAGE_ALLOW_LIST", ""); + private static final String ADDITIONAL_ALLOWED_CLASSES = System.getProperty( "org.apache.aries.rsa.provider.fastbin.api.DESERIALIZATION_CLASS_ALLOW_LIST", ""); + + static + { + Set classes = new HashSet<>(); + classes.addAll(Arrays.asList( + "B", // byte + "C", // char + "D", // double + "F", // float + "I", // int + "J", // long + "S", // short + "Z", // boolean + "L" // Object type (LClassName;) + )); + final String[] customClasses = ADDITIONAL_ALLOWED_CLASSES.split(","); + if (customClasses.length > 0) + { + classes.addAll(Arrays.asList(customClasses)); + } + ALLOWEDCLASSES = classes; + + List packages = new ArrayList<>(); + packages.addAll(Arrays.asList( + "java", + "javax", + "Ljava", + "org.apache.aries.rsa", + "org.osgi.framework", + "com.seeburger")); + + final String[] customPackages = ADDITIONAL_ALLOWED_PACKAGE.split(","); + if (customPackages.length > 0) + { + packages.addAll(Arrays.asList(customPackages)); + } + ALLOWED_PACKAGES = new FilteredClassLoaderObjectInputStream.AllowlistPackagesPredicate(packages); + } + + public String name() { return "object"; @@ -50,7 +100,7 @@ public void encodeRequest(ClassLoader loader, Class[] types, Object[] args, D } public void decodeResponse(ClassLoader loader, Class type, DataByteArrayInputStream source, AsyncCallback result) throws IOException, ClassNotFoundException { - ClassLoaderObjectInputStream ois = new ClassLoaderObjectInputStream(source); + ClassLoaderObjectInputStream ois = new FilteredClassLoaderObjectInputStream(source, ALLOWEDCLASSES, ALLOWED_PACKAGES); ois.setClassLoader(loader); Throwable error = (Throwable) ois.readObject(); Object value = ois.readObject(); @@ -62,7 +112,7 @@ public void decodeResponse(ClassLoader loader, Class type, DataByteArrayInput } public void decodeRequest(ClassLoader loader, Class[] types, DataByteArrayInputStream source, Object[] target) throws IOException, ClassNotFoundException { - final ClassLoaderObjectInputStream ois = new ClassLoaderObjectInputStream(source); + ClassLoaderObjectInputStream ois = new FilteredClassLoaderObjectInputStream(source, ALLOWEDCLASSES, ALLOWED_PACKAGES); ois.setClassLoader(loader); final Object[] args = (Object[]) ois.readObject(); if( args!=null ) { diff --git a/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/util/FilteredClassLoaderObjectInputStream.java b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/util/FilteredClassLoaderObjectInputStream.java new file mode 100644 index 000000000..3fe1ffa92 --- /dev/null +++ b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/util/FilteredClassLoaderObjectInputStream.java @@ -0,0 +1,116 @@ +/* + * FilteredClassLoaderObjectInputStream.java + * + * created at 2024-09-27 by t.neykov + * + * Copyright (c) SEEBURGER AG, Germany. All Rights Reserved. + */ + +package org.apache.aries.rsa.provider.fastbin.util; + + +import java.io.IOException; +import java.io.InputStream; +import java.io.InvalidClassException; +import java.io.ObjectStreamClass; +import java.util.List; +import java.util.Set; +import java.util.function.Predicate; + +/** + * This class is a subclass of {@link ClassLoaderObjectInputStream} that only allows a specific set of classes to be + * deserialized. This is to prevent deserialization attacks. + */ +public class FilteredClassLoaderObjectInputStream extends ClassLoaderObjectInputStream +{ + /** + * Property to disable secure deserialization. If this property is set to true, then the class will not throw an + * exception if the class is not in the allowed classes list. This is useful for testing. + */ + static final String PROPERTY_USE_INSECURE_DESERIALIZATION = "org.apache.aries.rsa.provider.fastbin.util.useInsecureDeserialization"; + static boolean useInsecureDeserialization = Boolean.getBoolean(PROPERTY_USE_INSECURE_DESERIALIZATION); + + private final Set allowedClasses; + private Predicate allowedPackages; + + public FilteredClassLoaderObjectInputStream(InputStream s, Set allowedClasses) + throws IOException + { + super(s); + if (allowedClasses == null) + { + throw new IllegalArgumentException("allowedClasses must not be null"); + } + + this.allowedClasses = allowedClasses; + } + + public FilteredClassLoaderObjectInputStream(InputStream inArg, Set allowedClasses, Predicate allowedPackages) + throws IOException + { + super(inArg); + + if (allowedClasses == null) + { + throw new IllegalArgumentException("allowedClasses must not be null"); + } + + this.allowedClasses = allowedClasses; + this.allowedPackages = allowedPackages; + } + + @Override + protected Class< ? > resolveClass(ObjectStreamClass clsDescriptor) + throws IOException, ClassNotFoundException + { + String className = removeArrayMarkersFromClassName(clsDescriptor); + + if (!useInsecureDeserialization) + { + if (allowedClasses.contains(className)) + { + return super.resolveClass(clsDescriptor); + } + if (allowedPackages != null && allowedPackages.test(className)) + { + return super.resolveClass(clsDescriptor); + } + throw new InvalidClassException(className, "Invalid de-serialisation data. POSSIBLE ATTACK. Invalid class=" + className); + } + + return super.resolveClass(clsDescriptor); + } + + + /** + * Removes array markers from the class name. (could be more than one). + * @param clsDescriptor + * @return + */ + private static String removeArrayMarkersFromClassName(ObjectStreamClass clsDescriptor) + { + String className = clsDescriptor.getName(); + int leadingBrackets = 0; + while (className.charAt(leadingBrackets) == '[') { + leadingBrackets++; + } + return className.substring(leadingBrackets); + } + + + public static class AllowlistPackagesPredicate implements Predicate + { + private final List allowedPackagesList; + + public AllowlistPackagesPredicate(List allowedPackagesList) + { + this.allowedPackagesList = allowedPackagesList; + } + + @Override + public boolean test(String className) + { + return allowedPackagesList.stream().anyMatch(className::startsWith); + } + } +}