Skip to content

Commit

Permalink
Prevent decoding VP8 videos on ARC devices, which can put the device…
Browse files Browse the repository at this point in the history
… into a bad state.

PiperOrigin-RevId: 483541133
  • Loading branch information
sjudd authored and glide-copybara-robot committed Oct 25, 2022
1 parent 4298bb7 commit 4bfda58
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.media.MediaDataSource;
import android.media.MediaExtractor;
import android.media.MediaFormat;
import android.media.MediaMetadataRetriever;
import android.os.Build;
Expand Down Expand Up @@ -125,7 +126,9 @@ public void update(
private static final List<String> PIXEL_T_BUILD_ID_PREFIXES_REQUIRING_HDR_180_ROTATION_FIX =
Collections.unmodifiableList(Arrays.asList("TP1A", "TD1A.220804.031"));

private final MediaMetadataRetrieverInitializer<T> initializer;
private static final String WEBM_MIME_TYPE = "video/webm";

private final MediaInitializer<T> initializer;
private final BitmapPool bitmapPool;
private final MediaMetadataRetrieverFactory factory;

Expand All @@ -142,14 +145,14 @@ public static ResourceDecoder<ByteBuffer, Bitmap> byteBuffer(BitmapPool bitmapPo
return new VideoDecoder<>(bitmapPool, new ByteBufferInitializer());
}

VideoDecoder(BitmapPool bitmapPool, MediaMetadataRetrieverInitializer<T> initializer) {
VideoDecoder(BitmapPool bitmapPool, MediaInitializer<T> initializer) {
this(bitmapPool, initializer, DEFAULT_FACTORY);
}

@VisibleForTesting
VideoDecoder(
BitmapPool bitmapPool,
MediaMetadataRetrieverInitializer<T> initializer,
MediaInitializer<T> initializer,
MediaMetadataRetrieverFactory factory) {
this.bitmapPool = bitmapPool;
this.initializer = initializer;
Expand Down Expand Up @@ -185,9 +188,10 @@ public Resource<Bitmap> decode(
final Bitmap result;
MediaMetadataRetriever mediaMetadataRetriever = factory.build();
try {
initializer.initialize(mediaMetadataRetriever, resource);
initializer.initializeRetriever(mediaMetadataRetriever, resource);
result =
decodeFrame(
resource,
mediaMetadataRetriever,
frameTimeMicros,
frameOption,
Expand All @@ -206,13 +210,18 @@ public Resource<Bitmap> decode(
}

@Nullable
private static Bitmap decodeFrame(
private Bitmap decodeFrame(
@NonNull T resource,
MediaMetadataRetriever mediaMetadataRetriever,
long frameTimeMicros,
int frameOption,
int outWidth,
int outHeight,
DownsampleStrategy strategy) {
if (isUnsupportedFormat(resource, mediaMetadataRetriever)) {
throw new IllegalStateException("Cannot decode VP8 video on CrOS.");
}

Bitmap result = null;
// Arguably we should handle the case where just width or just height is set to
// Target.SIZE_ORIGINAL. Up to and including OMR1, MediaMetadataRetriever defaults to setting
Expand Down Expand Up @@ -402,6 +411,54 @@ private static Bitmap decodeOriginalFrame(
return mediaMetadataRetriever.getFrameAtTime(frameTimeMicros, frameOption);
}

/** Returns true if the format type is unsupported on the device. */
private boolean isUnsupportedFormat(
@NonNull T resource, MediaMetadataRetriever mediaMetadataRetriever) {
// MediaFormat.KEY_MIME check below requires at least JELLY_BEAN
if (Build.VERSION.SDK_INT < VERSION_CODES.JELLY_BEAN) {
return false;
}

// The primary known problem is vp8 video on ChromeOS (ARC) devices.
boolean isArc = Build.DEVICE != null && Build.DEVICE.matches(".+_cheets|cheets_.+");
if (!isArc) {
return false;
}

MediaExtractor mediaExtractor = null;
try {
// Include the MediaMetadataRetriever extract in the try block out of an abundance of caution.
String mimeType =
mediaMetadataRetriever.extractMetadata(MediaMetadataRetriever.METADATA_KEY_MIMETYPE);
if (!WEBM_MIME_TYPE.equals(mimeType)) {
return false;
}

// Only construct a MediaExtractor for webm files, since the constructor makes a JNI call
mediaExtractor = new MediaExtractor();
initializer.initializeExtractor(mediaExtractor, resource);
int numTracks = mediaExtractor.getTrackCount();
for (int i = 0; i < numTracks; ++i) {
MediaFormat mediaformat = mediaExtractor.getTrackFormat(i);
String trackMimeType = mediaformat.getString(MediaFormat.KEY_MIME);
if (MediaFormat.MIMETYPE_VIDEO_VP8.equals(trackMimeType)) {
return true;
}
}
} catch (Throwable t) {
// Catching everything here out of an abundance of caution
if (Log.isLoggable(TAG, Log.DEBUG)) {
Log.d(TAG, "Exception trying to extract track info for a webm video on CrOS.", t);
}
} finally {
if (mediaExtractor != null) {
mediaExtractor.release();
}
}

return false;
}

@VisibleForTesting
static class MediaMetadataRetrieverFactory {
public MediaMetadataRetriever build() {
Expand All @@ -410,56 +467,78 @@ public MediaMetadataRetriever build() {
}

@VisibleForTesting
interface MediaMetadataRetrieverInitializer<T> {
void initialize(MediaMetadataRetriever retriever, T data);
interface MediaInitializer<T> {
void initializeRetriever(MediaMetadataRetriever retriever, T data);

void initializeExtractor(MediaExtractor extractor, T data) throws IOException;
}

private static final class AssetFileDescriptorInitializer
implements MediaMetadataRetrieverInitializer<AssetFileDescriptor> {
implements MediaInitializer<AssetFileDescriptor> {

@Override
public void initialize(MediaMetadataRetriever retriever, AssetFileDescriptor data) {
public void initializeRetriever(MediaMetadataRetriever retriever, AssetFileDescriptor data) {
retriever.setDataSource(data.getFileDescriptor(), data.getStartOffset(), data.getLength());
}

@Override
public void initializeExtractor(MediaExtractor extractor, AssetFileDescriptor data)
throws IOException {
extractor.setDataSource(data.getFileDescriptor(), data.getStartOffset(), data.getLength());
}
}

// Visible for VideoBitmapDecoder.
static final class ParcelFileDescriptorInitializer
implements MediaMetadataRetrieverInitializer<ParcelFileDescriptor> {
implements MediaInitializer<ParcelFileDescriptor> {

@Override
public void initialize(MediaMetadataRetriever retriever, ParcelFileDescriptor data) {
public void initializeRetriever(MediaMetadataRetriever retriever, ParcelFileDescriptor data) {
retriever.setDataSource(data.getFileDescriptor());
}

@Override
public void initializeExtractor(MediaExtractor extractor, ParcelFileDescriptor data)
throws IOException {
extractor.setDataSource(data.getFileDescriptor());
}
}

@RequiresApi(Build.VERSION_CODES.M)
static final class ByteBufferInitializer
implements MediaMetadataRetrieverInitializer<ByteBuffer> {
static final class ByteBufferInitializer implements MediaInitializer<ByteBuffer> {

@Override
public void initialize(MediaMetadataRetriever retriever, final ByteBuffer data) {
retriever.setDataSource(
new MediaDataSource() {
@Override
public int readAt(long position, byte[] buffer, int offset, int size) {
if (position >= data.limit()) {
return -1;
}
data.position((int) position);
int numBytesRead = Math.min(size, data.remaining());
data.get(buffer, offset, numBytesRead);
return numBytesRead;
}
public void initializeRetriever(MediaMetadataRetriever retriever, final ByteBuffer data) {
retriever.setDataSource(getMediaDataSource(data));
}

@Override
public long getSize() {
return data.limit();
}
@Override
public void initializeExtractor(MediaExtractor extractor, final ByteBuffer data)
throws IOException {
extractor.setDataSource(getMediaDataSource(data));
}

@Override
public void close() {}
});
private MediaDataSource getMediaDataSource(final ByteBuffer data) {
return new MediaDataSource() {
@Override
public int readAt(long position, byte[] buffer, int offset, int size) {
if (position >= data.limit()) {
return -1;
}
data.position((int) position);
int numBytesRead = Math.min(size, data.remaining());
data.get(buffer, offset, numBytesRead);
return numBytesRead;
}

@Override
public long getSize() {
return data.limit();
}

@Override
public void close() {}
};
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.never;
Expand Down Expand Up @@ -37,7 +38,7 @@
public class VideoDecoderTest {
@Mock private ParcelFileDescriptor resource;
@Mock private VideoDecoder.MediaMetadataRetrieverFactory factory;
@Mock private VideoDecoder.MediaMetadataRetrieverInitializer<ParcelFileDescriptor> initializer;
@Mock private VideoDecoder.MediaInitializer<ParcelFileDescriptor> initializer;
@Mock private MediaMetadataRetriever retriever;
@Mock private BitmapPool bitmapPool;
private VideoDecoder<ParcelFileDescriptor> decoder;
Expand All @@ -46,6 +47,7 @@ public class VideoDecoderTest {
private String initialMake;
private String initialModel;
private String initialBuildId;
private String initialDevice;

@Before
public void setup() {
Expand All @@ -58,13 +60,13 @@ public void setup() {
initialMake = Build.MANUFACTURER;
initialModel = Build.MODEL;
initialBuildId = Build.ID;
initialDevice = Build.DEVICE;
}

@After
public void tearDown() {
Util.setSdkVersionInt(initialSdkVersion);
setMakeAndModel(initialMake, initialModel);
setBuildId(initialBuildId);
resetBuildInfo(initialMake, initialModel, initialBuildId, initialDevice);
}

@Test
Expand All @@ -77,7 +79,7 @@ public void testReturnsRetrievedFrameForResource() throws IOException {
Resource<Bitmap> result =
Preconditions.checkNotNull(decoder.decode(resource, 100, 100, options));

verify(initializer).initialize(retriever, resource);
verify(initializer).initializeRetriever(retriever, resource);
assertEquals(expected, result.get());
}

Expand Down Expand Up @@ -194,6 +196,45 @@ public void decodeFrame_withTargetSizeOriginalHeightOnly_onApi27_doesNotThrow()
.isSameInstanceAs(expected);
}

@Test
public void decodeFrame_notArcDeviceButWebm_doesNotInitializeMediaExtractor() throws IOException {
setDevice("notArc");
when(retriever.extractMetadata(MediaMetadataRetriever.METADATA_KEY_MIMETYPE))
.thenReturn("video/webm");
when(retriever.getFrameAtTime(-1, MediaMetadataRetriever.OPTION_CLOSEST_SYNC))
.thenReturn(Bitmap.createBitmap(100, 100, Bitmap.Config.ARGB_8888));

decoder.decode(resource, Target.SIZE_ORIGINAL, Target.SIZE_ORIGINAL, options).get();

verify(initializer, never()).initializeExtractor(any(), any());
}

@Test
public void decodeFrame_arcDeviceButNotWebm_doesNotInitializeMediaExtractor() throws IOException {
setDevice("arc_cheets");
when(retriever.extractMetadata(MediaMetadataRetriever.METADATA_KEY_MIMETYPE))
.thenReturn("video/mp4");
when(retriever.getFrameAtTime(-1, MediaMetadataRetriever.OPTION_CLOSEST_SYNC))
.thenReturn(Bitmap.createBitmap(100, 100, Bitmap.Config.ARGB_8888));

decoder.decode(resource, Target.SIZE_ORIGINAL, Target.SIZE_ORIGINAL, options).get();

verify(initializer, never()).initializeExtractor(any(), any());
}

@Test
public void decodeFrame_arcDeviceAndWebm_initializesMediaExtractor() throws IOException {
setDevice("arc_cheets");
when(retriever.extractMetadata(MediaMetadataRetriever.METADATA_KEY_MIMETYPE))
.thenReturn("video/webm");
when(retriever.getFrameAtTime(-1, MediaMetadataRetriever.OPTION_CLOSEST_SYNC))
.thenReturn(Bitmap.createBitmap(100, 100, Bitmap.Config.ARGB_8888));

decoder.decode(resource, Target.SIZE_ORIGINAL, Target.SIZE_ORIGINAL, options).get();

verify(initializer).initializeExtractor(any(), any());
}

@Test
@Config(sdk = VERSION_CODES.M)
public void isHdr180RotationFixRequired_androidM_returnsFalse() {
Expand All @@ -218,12 +259,14 @@ public void isHdr180RotationFixRequired_androidS_returnsTrue() {
assertThat(VideoDecoder.isHdr180RotationFixRequired()).isTrue();
}

private void setMakeAndModel(String make, String model) {
private void resetBuildInfo(String make, String model, String buildId, String device) {
ReflectionHelpers.setStaticField(Build.class, "MANUFACTURER", make);
ReflectionHelpers.setStaticField(Build.class, "MODEL", model);
ReflectionHelpers.setStaticField(Build.class, "ID", buildId);
setDevice(device);
}

private void setBuildId(String buildId) {
ReflectionHelpers.setStaticField(Build.class, "ID", buildId);
private void setDevice(String device) {
ReflectionHelpers.setStaticField(Build.class, "DEVICE", device);
}
}

0 comments on commit 4bfda58

Please sign in to comment.