From 9c20da3fa019f0fea5a11b6d99ba6258c374ea73 Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Tue, 14 Jan 2025 14:18:07 +0100 Subject: [PATCH 01/15] Initial vlm feature --- lib/interop/generated_bindings.dart | 163 +++++++ lib/interop/openvino_bindings.dart | 7 + lib/interop/vlm_inference.dart | 123 ++++++ lib/pages/models/inference.dart | 3 + lib/pages/vlm/live_inference_pane.dart | 82 ++++ lib/pages/vlm/performance_metrics_pane.dart | 67 +++ .../vlm/providers/vlm_inference_provider.dart | 239 ++++++++++ lib/pages/vlm/vlm_page.dart | 133 ++++++ lib/pages/vlm/widgets/device_selector.dart | 70 +++ lib/pages/vlm/widgets/horizontal_rule.dart | 29 ++ lib/pages/vlm/widgets/image_grid.dart | 128 ++++++ lib/pages/vlm/widgets/model_properties.dart | 75 ++++ lib/pages/vlm/widgets/toolbar_text_input.dart | 119 +++++ lib/pages/vlm/widgets/vertical_rule.dart | 29 ++ lib/pages/vlm/widgets/vlm_chat_area.dart | 412 ++++++++++++++++++ lib/pages/vlm/widgets/vlm_metrics_grid.dart | 39 ++ lib/project.dart | 7 +- lib/widgets/controls/drop_area.dart | 95 ++-- openvino_bindings/src/BUILD | 2 + openvino_bindings/src/bindings.cc | 84 ++++ openvino_bindings/src/bindings.h | 25 ++ openvino_bindings/src/utils/BUILD | 8 + openvino_bindings/src/utils/vlm_metrics.h | 20 + openvino_bindings/src/vlm/BUILD | 35 ++ openvino_bindings/src/vlm/load_image.cpp | 82 ++++ openvino_bindings/src/vlm/load_image.hpp | 15 + openvino_bindings/src/vlm/vlm_inference.cc | 137 ++++++ openvino_bindings/src/vlm/vlm_inference.h | 69 +++ .../src/vlm/vlm_inference_test.cc | 16 + 29 files changed, 2272 insertions(+), 41 deletions(-) create mode 100644 lib/interop/vlm_inference.dart create mode 100644 lib/pages/vlm/live_inference_pane.dart create mode 100644 lib/pages/vlm/performance_metrics_pane.dart create mode 100644 lib/pages/vlm/providers/vlm_inference_provider.dart create mode 100644 lib/pages/vlm/vlm_page.dart create mode 100644 lib/pages/vlm/widgets/device_selector.dart create mode 100644 lib/pages/vlm/widgets/horizontal_rule.dart create mode 100644 lib/pages/vlm/widgets/image_grid.dart create mode 100644 lib/pages/vlm/widgets/model_properties.dart create mode 100644 lib/pages/vlm/widgets/toolbar_text_input.dart create mode 100644 lib/pages/vlm/widgets/vertical_rule.dart create mode 100644 lib/pages/vlm/widgets/vlm_chat_area.dart create mode 100644 lib/pages/vlm/widgets/vlm_metrics_grid.dart create mode 100644 openvino_bindings/src/utils/vlm_metrics.h create mode 100644 openvino_bindings/src/vlm/BUILD create mode 100644 openvino_bindings/src/vlm/load_image.cpp create mode 100644 openvino_bindings/src/vlm/load_image.hpp create mode 100644 openvino_bindings/src/vlm/vlm_inference.cc create mode 100644 openvino_bindings/src/vlm/vlm_inference.h create mode 100644 openvino_bindings/src/vlm/vlm_inference_test.cc diff --git a/lib/interop/generated_bindings.dart b/lib/interop/generated_bindings.dart index 08f9547d..07a52835 100644 --- a/lib/interop/generated_bindings.dart +++ b/lib/interop/generated_bindings.dart @@ -571,6 +571,128 @@ class OpenVINO { late final _ttiInferenceClose = _ttiInferenceClosePtr .asFunction Function(CLLMInference)>(); + ffi.Pointer vlmInferenceOpen( + ffi.Pointer model_path, + ffi.Pointer device, + ) { + return _vlmInferenceOpen( + model_path, + device, + ); + } + + late final _vlmInferenceOpenPtr = _lookup< + ffi.NativeFunction< + ffi.Pointer Function(ffi.Pointer, + ffi.Pointer)>>('vlmInferenceOpen'); + late final _vlmInferenceOpen = _vlmInferenceOpenPtr.asFunction< + ffi.Pointer Function( + ffi.Pointer, ffi.Pointer)>(); + + ffi.Pointer vlmInferenceSetListener( + CVLMInference instance, + VLMInferenceCallbackFunction callback, + ) { + return _vlmInferenceSetListener( + instance, + callback, + ); + } + + late final _vlmInferenceSetListenerPtr = _lookup< + ffi.NativeFunction< + ffi.Pointer Function(CVLMInference, + VLMInferenceCallbackFunction)>>('vlmInferenceSetListener'); + late final _vlmInferenceSetListener = _vlmInferenceSetListenerPtr.asFunction< + ffi.Pointer Function( + CVLMInference, VLMInferenceCallbackFunction)>(); + + ffi.Pointer vlmInferencePrompt( + CVLMInference instance, + ffi.Pointer message, + int max_new_tokens, + ) { + return _vlmInferencePrompt( + instance, + message, + max_new_tokens, + ); + } + + late final _vlmInferencePromptPtr = _lookup< + ffi.NativeFunction< + ffi.Pointer Function(CVLMInference, + ffi.Pointer, ffi.Int)>>('vlmInferencePrompt'); + late final _vlmInferencePrompt = _vlmInferencePromptPtr.asFunction< + ffi.Pointer Function( + CVLMInference, ffi.Pointer, int)>(); + + ffi.Pointer vlmInferenceSetImagePaths( + CVLMInference instance, + ffi.Pointer> paths, + int length, + ) { + return _vlmInferenceSetImagePaths( + instance, + paths, + length, + ); + } + + late final _vlmInferenceSetImagePathsPtr = _lookup< + ffi.NativeFunction< + ffi.Pointer Function( + CVLMInference, + ffi.Pointer>, + ffi.Int)>>('vlmInferenceSetImagePaths'); + late final _vlmInferenceSetImagePaths = + _vlmInferenceSetImagePathsPtr.asFunction< + ffi.Pointer Function( + CVLMInference, ffi.Pointer>, int)>(); + + ffi.Pointer vlmInferenceHasModelIndex( + CVLMInference instance, + ) { + return _vlmInferenceHasModelIndex( + instance, + ); + } + + late final _vlmInferenceHasModelIndexPtr = _lookup< + ffi + .NativeFunction Function(CVLMInference)>>( + 'vlmInferenceHasModelIndex'); + late final _vlmInferenceHasModelIndex = _vlmInferenceHasModelIndexPtr + .asFunction Function(CVLMInference)>(); + + ffi.Pointer vlmInferenceStop( + CVLMInference instance, + ) { + return _vlmInferenceStop( + instance, + ); + } + + late final _vlmInferenceStopPtr = + _lookup Function(CVLMInference)>>( + 'vlmInferenceStop'); + late final _vlmInferenceStop = _vlmInferenceStopPtr + .asFunction Function(CVLMInference)>(); + + ffi.Pointer vlmInferenceClose( + CVLMInference instance, + ) { + return _vlmInferenceClose( + instance, + ); + } + + late final _vlmInferenceClosePtr = + _lookup Function(CVLMInference)>>( + 'vlmInferenceClose'); + late final _vlmInferenceClose = _vlmInferenceClosePtr + .asFunction Function(CVLMInference)>(); + ffi.Pointer graphRunnerOpen( ffi.Pointer graph, ) { @@ -861,6 +983,20 @@ final class StringWithMetrics extends ffi.Struct { external TTIMetrics metrics; } +final class VLMMetrics extends ffi.Struct { + @ffi.Float() + external double load_time; + + @ffi.Float() + external double generate_time; +} + +final class VLMStringWithMetrics extends ffi.Struct { + external ffi.Pointer string; + + external VLMMetrics metrics; +} + final class Device extends ffi.Struct { external ffi.Pointer id; @@ -966,6 +1102,15 @@ final class StatusOrTTIInference extends ffi.Struct { external CLLMInference value; } +final class StatusOrVLMInference extends ffi.Struct { + @ffi.Int() + external int status; + + external ffi.Pointer message; + + external CLLMInference value; +} + final class StatusOrModelResponse extends ffi.Struct { @ffi.Int() external int status; @@ -1004,6 +1149,17 @@ final class StatusOrTTIModelResponse extends ffi.Struct { external ffi.Pointer value; } +final class StatusOrVLMModelResponse extends ffi.Struct { + @ffi.Int() + external int status; + + external ffi.Pointer message; + + external VLMMetrics metrics; + + external ffi.Pointer value; +} + final class StatusOrDevices extends ffi.Struct { @ffi.Int() external int status; @@ -1029,3 +1185,10 @@ typedef LLMInferenceCallbackFunctionFunction = ffi.Void Function( typedef DartLLMInferenceCallbackFunctionFunction = void Function( ffi.Pointer); typedef CTTIInference = ffi.Pointer; +typedef CVLMInference = ffi.Pointer; +typedef VLMInferenceCallbackFunction + = ffi.Pointer>; +typedef VLMInferenceCallbackFunctionFunction = ffi.Void Function( + ffi.Pointer); +typedef DartVLMInferenceCallbackFunctionFunction = void Function( + ffi.Pointer); diff --git a/lib/interop/openvino_bindings.dart b/lib/interop/openvino_bindings.dart index a0363b71..d9733135 100644 --- a/lib/interop/openvino_bindings.dart +++ b/lib/interop/openvino_bindings.dart @@ -49,6 +49,13 @@ class TTIModelResponse { const TTIModelResponse(this.content, this.metrics); } +class VLMModelResponse { + final String content; + final VLMMetrics metrics; + + const VLMModelResponse(this.content, this.metrics); +} + String getLibraryPath() { if (Platform.isWindows) { diff --git a/lib/interop/vlm_inference.dart b/lib/interop/vlm_inference.dart new file mode 100644 index 00000000..7a72e194 --- /dev/null +++ b/lib/interop/vlm_inference.dart @@ -0,0 +1,123 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'dart:ffi'; +import 'dart:isolate'; + +import 'package:ffi/ffi.dart'; +import 'package:inference/interop/openvino_bindings.dart'; + +final vlmOV = getBindings(); + +class VLMInference { + NativeCallable? nativeListener; + final Pointer instance; + late bool chatEnabled; + + VLMInference(this.instance) { + chatEnabled = true; + } + + static Future init(String modelPath, String device) async { + final result = await Isolate.run(() { + final modelPathPtr = modelPath.toNativeUtf8(); + final devicePtr = device.toNativeUtf8(); + final status = vlmOV.vlmInferenceOpen(modelPathPtr, devicePtr); + calloc.free(modelPathPtr); + calloc.free(devicePtr); + + return status; + }); + + print("${result.ref.status}, ${result.ref.message}"); + if (StatusEnum.fromValue(result.ref.status) != StatusEnum.OkStatus) { + throw "VLMInference open error: ${result.ref.status} ${result.ref.message.toDartString()}"; + } + + return VLMInference(result); + } + + Future setListener(void Function(String) callback) async{ + int instanceAddress = instance.ref.value.address; + void localCallback(Pointer ptr) { + if (StatusEnum.fromValue(ptr.ref.status) != StatusEnum.OkStatus) { + // TODO(RHeckerIntel): instead of throw, call an onError callback. + throw "VLM Callback error: ${ptr.ref.status} ${ptr.ref.message.toDartString()}"; + } + callback(ptr.ref.value.toDartString()); + vlmOV.freeStatusOrString(ptr); + } + nativeListener?.close(); + nativeListener = NativeCallable.listener(localCallback); + final status = vlmOV.vlmInferenceSetListener(Pointer.fromAddress(instanceAddress), nativeListener!.nativeFunction); + if (StatusEnum.fromValue(status.ref.status) != StatusEnum.OkStatus) { + // TODO(RHeckerIntel): instead of throw, call an onError callback. + throw "VLM setListener error: ${status.ref.status} ${status.ref.message.toDartString()}"; + } + vlmOV.freeStatus(status); + } + + + Future prompt( + String message, int maxNewTokens) async { + int instanceAddress = instance.ref.value.address; + final result = await Isolate.run(() { + final messagePtr = message.toNativeUtf8(); + final status = vlmOV.vlmInferencePrompt( + Pointer.fromAddress(instanceAddress), + messagePtr, + maxNewTokens); + calloc.free(messagePtr); + return status; + }); + + if (StatusEnum.fromValue(result.ref.status) != StatusEnum.OkStatus) { + var msg = result.ref.message; + var status = result.ref.status; + var dStr = msg.toDartString(); + + throw "VLMInference prompt error: $status $dStr"; + } + + return VLMModelResponse( + result.ref.value.toDartString(), result.ref.metrics); + } + + + void setImagePaths(List paths) { + // Convert Dart strings to C strings + final cStrings = paths.map((str) => str.toNativeUtf8()).toList(); + + // Create a pointer to the array of C strings + final pointerToCStrings = malloc>(cStrings.length); + for (var i = 0; i < cStrings.length; i++) { + pointerToCStrings[i] = cStrings[i]; + } + + final status = vlmOV.vlmInferenceSetImagePaths(instance.ref.value, pointerToCStrings, cStrings.length); + + if (StatusEnum.fromValue(status.ref.status) != StatusEnum.OkStatus) { + throw "Close error: ${status.ref.status} ${status.ref.message.toDartString()}"; + } + vlmOV.freeStatus(status); + } + + void forceStop() { + final status = vlmOV.vlmInferenceStop(instance.ref.value); + + if (StatusEnum.fromValue(status.ref.status) != StatusEnum.OkStatus) { + throw "VLM Force Stop error: ${status.ref.status} ${status.ref.message.toDartString()}"; + } + } + + + void close() { + final status = vlmOV.vlmInferenceClose(instance.ref.value); + + if (StatusEnum.fromValue(status.ref.status) != StatusEnum.OkStatus) { + throw "Close error: ${status.ref.status} ${status.ref.message.toDartString()}"; + } + vlmOV.freeStatus(status); + } +} diff --git a/lib/pages/models/inference.dart b/lib/pages/models/inference.dart index 01847cab..00859504 100644 --- a/lib/pages/models/inference.dart +++ b/lib/pages/models/inference.dart @@ -7,6 +7,7 @@ import 'package:inference/pages/computer_vision/computer_vision.dart'; import 'package:inference/pages/text_generation/text_generation.dart'; import 'package:inference/pages/text_to_image/text_to_image_page.dart'; import 'package:inference/pages/transcription/transcription.dart'; +import 'package:inference/pages/vlm/vlm_page.dart'; import 'package:inference/project.dart'; class InferencePage extends StatelessWidget { @@ -24,6 +25,8 @@ class InferencePage extends StatelessWidget { return TranscriptionPage(project); case ProjectType.textToImage: return TextToImagePage(project); + case ProjectType.vlm: + return VLMPage(project); } } diff --git a/lib/pages/vlm/live_inference_pane.dart b/lib/pages/vlm/live_inference_pane.dart new file mode 100644 index 00000000..bd565d1b --- /dev/null +++ b/lib/pages/vlm/live_inference_pane.dart @@ -0,0 +1,82 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:inference/widgets/grid_container.dart'; +import 'package:inference/pages/vlm/providers/vlm_inference_provider.dart'; +import 'package:inference/pages/vlm/widgets/model_properties.dart'; +import 'package:inference/pages/vlm/widgets/toolbar_text_input.dart'; +import 'package:inference/pages/vlm/widgets/vlm_chat_area.dart'; +import 'package:inference/pages/vlm/widgets/vertical_rule.dart'; +import 'package:inference/theme_fluent.dart'; +import 'package:provider/provider.dart'; +import 'package:inference/widgets/device_selector.dart'; + +class VLMLiveInferencePane extends StatefulWidget { + const VLMLiveInferencePane({super.key}); + + @override + State createState() => _PlaygroundState(); +} + +class _PlaygroundState extends State { + VLMInferenceProvider provider() => + Provider.of(context, listen: false); + + + @override + Widget build(BuildContext context) { + final theme = FluentTheme.of(context); + + const vlmChatArea = VLMChatArea(); + + return Consumer(builder: (context, inference, child) { + return Row( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Expanded( + child: Column( + children: [ + SizedBox( + height: 64, + child: GridContainer( + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 16), + child: Row( + children: [ + const DeviceSelector(), + const Padding( + padding: EdgeInsets.symmetric(vertical: 16), + child: VerticalRule()), + ToolbarTextInput( + marginLeft: 0, + labelText: "Max new tokens", + suffix: "", + initialValue: provider().maxTokens, + roundPowerOfTwo: true, + onChanged: (value) { + provider().maxTokens = value; + }), + ], + ), + ), + ), + ), + Expanded( + child: GridContainer( + color: backgroundColor.of(theme), + child: Builder(builder: (context) { + return vlmChatArea; + }), + ), + ) + ], + ), + ), + const ModelProperties(), + ], + ); + }); + } +} diff --git a/lib/pages/vlm/performance_metrics_pane.dart b/lib/pages/vlm/performance_metrics_pane.dart new file mode 100644 index 00000000..c0f4b2f0 --- /dev/null +++ b/lib/pages/vlm/performance_metrics_pane.dart @@ -0,0 +1,67 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:inference/pages/vlm/providers/vlm_inference_provider.dart'; +import 'package:inference/pages/vlm/widgets/vlm_metrics_grid.dart'; +import 'package:provider/provider.dart'; + +class VLMPerformanceMetricsPane extends StatefulWidget { + const VLMPerformanceMetricsPane({super.key}); + + @override + State createState() => _VLMPerformanceMetricsPaneState(); +} + +class _VLMPerformanceMetricsPaneState extends State { + + @override + void initState() { + super.initState(); + final provider = Provider.of(context, listen: false); + if (provider.metrics == null) { + provider.loaded.future.then((_) { + provider.message("Generate OpenVINO logo"); + }); + } + } + + @override + Widget build(BuildContext context) { + return Consumer(builder: (context, inference, child) { + if (inference.metrics == null) { + return Center( + child: Column( + mainAxisAlignment: MainAxisAlignment.center, + children: [ + Image.asset('images/intel-loading.gif', width: 100), + const Text("Running benchmark prompt...") + ], + ) + ); + } + + final metrics = inference.metrics!; + + return Container( + decoration: const BoxDecoration( + shape: BoxShape.rectangle, + borderRadius: BorderRadius.all(Radius.circular(8)), + ), + child: Padding( + padding: const EdgeInsets.all(30.0), + child: Column( + mainAxisAlignment: MainAxisAlignment.start, + crossAxisAlignment: CrossAxisAlignment.center, + children: [ + VLMMetricsGrid(metrics: metrics), + ], + ), + ), + ); + }); + } +} + + diff --git a/lib/pages/vlm/providers/vlm_inference_provider.dart b/lib/pages/vlm/providers/vlm_inference_provider.dart new file mode 100644 index 00000000..b3338912 --- /dev/null +++ b/lib/pages/vlm/providers/vlm_inference_provider.dart @@ -0,0 +1,239 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'dart:async'; +import 'dart:isolate'; +import 'dart:typed_data'; +import 'dart:ui' as ui; +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:inference/interop/generated_bindings.dart'; +import 'package:inference/interop/vlm_inference.dart'; +import 'package:inference/project.dart'; + +enum Speaker { assistant, user } + + +class Message { + final Speaker speaker; + final String message; + final VLMMetrics? metrics; + final bool allowedCopy; // Don't allow loading images to be copied + + const Message(this.speaker, this.message, this.metrics, this.allowedCopy); +} + +class VLMInferenceProvider extends ChangeNotifier { + Completer loaded = Completer(); + + Project? _project; + String? _device; + + Project? get project => _project; + + String? get device => _device; + + VLMMetrics? get metrics => _messages.lastOrNull?.metrics; + + int _maxTokens = 100; + + int get maxTokens => _maxTokens; + + set maxTokens(int v) { + _maxTokens = v; + notifyListeners(); + } + + VLMInference? _inference; + final stopWatch = Stopwatch(); + int n = 0; + + VLMInferenceProvider(Project? project, String? device) { + _project = project; + _device = device; + + if (project != null && device != null) { + print("instantiating project: ${project.name}"); + print(project.storagePath); + print(device); + } + } + + Future init() async { + + _inference = await VLMInference.init(project!.storagePath, device!) + ..setListener(onMessage); + + loaded.complete(); + if (hasListeners) { + notifyListeners(); + } + } + + void onMessage(String word) { + stopWatch.stop(); + if (n == 0) { // dont count first token since it's slow. + stopWatch.reset(); + } + + double timeElapsed = stopWatch.elapsedMilliseconds.toDouble(); + double averageElapsed = (n == 0 ? 0.0 : timeElapsed / n); + if (n == 0) { + _response = word; + } else { + _response = _response! + word; + } + _speed = averageElapsed; + if (hasListeners) { + notifyListeners(); + } + stopWatch.start(); + n++; + } + + + bool sameProps(Project? project, String? device) { + return _project == project && _device == device; + } + + bool get initialized => loaded.isCompleted; + final List _messages = []; + + double? _speed; + + double? get speed => _speed; + + set speed(double? speed) { + _speed = speed; + notifyListeners(); + } + + String? _response; + + String? get response => _response; + + set response(String? response) { + _response = response; + notifyListeners(); + } + + String get task { + return "Image Generation"; + } + + Message? get interimResponse { + if (_response == null) { + return null; + } + + return Message(Speaker.assistant, response!, null, false); + } + + List get messages { + if (interimResponse == null) { + return _messages; + } + return [..._messages, interimResponse!]; + } + + Future createImage(Uint8List bytes) async { + return await decodeImageFromList(bytes); + } + + Future message(String message) async { + _response = "..."; + + _messages.add(Message(Speaker.user, message, null, false)); + notifyListeners(); + + final response = await _inference!.prompt(message, maxTokens); + + if (_messages.isNotEmpty) { + _messages.add(Message(Speaker.assistant, response.content, response.metrics, true)); + } + _response = null; + + n = 0; + if (hasListeners) { + notifyListeners(); + } + } + + void setImagePaths(List paths) { + _inference?.setImagePaths(paths); + } + + + void close() { + _messages.clear(); + _inference?.close(); + _response = null; + if (_inference != null) { + _inference!.close(); + } + } + + void forceStop() { + _inference?.forceStop(); + if (_response != '...') { + _messages.add(Message(Speaker.assistant, _response!, null, true)); + } + _response = null; + if (hasListeners) { + notifyListeners(); + } + } + + void reset() { + //_inference?.close(); + _inference?.forceStop(); + // _inference?.clearHistory(); + _messages.clear(); + _response = null; + notifyListeners(); + } + + + Future _closeInferenceInIsolate(dynamic inference) async { + final receivePort = ReceivePort(); + + // Spawn an isolate and pass the SendPort and inference + await Isolate.spawn((List args) { + final SendPort sendPort = args[0]; + final dynamic inference = args[1]; + try { + inference?.close(); // Perform the blocking operation + } catch (e) { + print("Error closing inference: $e"); + } finally { + sendPort.send(null); // Notify that the operation is complete + } + }, [receivePort.sendPort, inference]); + + // Wait for the isolate to complete + await receivePort.first; + } + + Future _waitForLoadCompletion() async { + if (!loaded.isCompleted) { + print("Still loading model, await disposal"); + await loaded.future; + } + } + + @override + void dispose() async { + // Wait for model to finish loading + await _waitForLoadCompletion(); + + if (_inference != null) { + print("Closing inference"); + await _closeInferenceInIsolate(_inference!); + print("Closing inference done"); + } else { + close(); + } + + super.dispose(); // Always call super.dispose() + } +} diff --git a/lib/pages/vlm/vlm_page.dart b/lib/pages/vlm/vlm_page.dart new file mode 100644 index 00000000..3076aa2c --- /dev/null +++ b/lib/pages/vlm/vlm_page.dart @@ -0,0 +1,133 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:flutter_svg/svg.dart'; +import 'package:inference/pages/vlm/live_inference_pane.dart'; +import 'package:inference/pages/vlm/providers/vlm_inference_provider.dart'; +import 'package:inference/pages/vlm/performance_metrics_pane.dart'; +import 'package:inference/project.dart'; +import 'package:inference/providers/preference_provider.dart'; +import 'package:inference/utils.dart'; +import 'package:inference/widgets/controls/close_model_button.dart'; +import 'package:provider/provider.dart'; + +class VLMPage extends StatefulWidget { + final Project project; + const VLMPage(this.project, {super.key}); + + @override + State createState() => _VLMPageState(); +} + +class _VLMPageState extends State { + + + int selected = 0; + @override + Widget build(BuildContext context) { + final theme = FluentTheme.of(context); + final updatedTheme = theme.copyWith( + navigationPaneTheme: theme.navigationPaneTheme.merge(NavigationPaneThemeData( + backgroundColor: theme.scaffoldBackgroundColor, + )) + ); + final textColor = theme.typography.body?.color ?? Colors.black; + + const inferencePane = VLMLiveInferencePane(); + const metricsPane = VLMPerformanceMetricsPane(); + return ChangeNotifierProxyProvider( + lazy: false, + create: (_) { + final device = Provider.of(context, listen: false).device; + return VLMInferenceProvider(widget.project, device)..init(); + }, + update: (_, preferences, imageInferenceProvider) { + if (imageInferenceProvider != null && imageInferenceProvider.sameProps(widget.project, preferences.device)) { + return imageInferenceProvider; + } + return VLMInferenceProvider(widget.project, preferences.device)..init(); + }, + + child: Stack( + children: [ + FluentTheme( + data: updatedTheme, + child: NavigationView( + pane: NavigationPane( + size: const NavigationPaneSize(topHeight: 64), + header: Row( + children: [ + Padding( + padding: const EdgeInsets.only(left: 12.0), + child: ClipRRect( + borderRadius: BorderRadius.circular(4.0), + child: Container( + width: 40, + height: 40, + decoration: BoxDecoration( + image: DecorationImage( + image: widget.project.thumbnailImage(), + fit: BoxFit.cover), + ), + ), + ), + ), + Padding( + padding: const EdgeInsets.symmetric(horizontal: 16), + child: Text(widget.project.name, + style: const TextStyle(fontSize: 20, fontWeight: FontWeight.bold), + ), + ), + ], + ), + //customPane: CustomNavigationPane(), + selected: selected, + onChanged: (i) => setState(() {selected = i;}), + displayMode: PaneDisplayMode.top, + items: [ + PaneItem( + icon: SvgPicture.asset("images/playground.svg", + colorFilter: ColorFilter.mode(textColor, BlendMode.srcIn), + width: 15, + ), + title: const Text("Live Inference"), + body: inferencePane, + ), + PaneItem( + icon: SvgPicture.asset("images/stats.svg", + colorFilter: ColorFilter.mode(textColor, BlendMode.srcIn), + width: 15, + ), + title: const Text("Performance metrics"), + body: metricsPane, + ), + ], + ) + ), + ), + SizedBox( + height: 64, + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 25), + child: Row( + mainAxisAlignment: MainAxisAlignment.end, + children: [ + Padding( + padding: const EdgeInsets.all(4), + child: FilledButton( + child: const Text("Export model"), + onPressed: () => downloadProject(widget.project), + ), + ), + const CloseModelButton(), + ] + ), + ), + ) + ], + ), + ); + } +} diff --git a/lib/pages/vlm/widgets/device_selector.dart b/lib/pages/vlm/widgets/device_selector.dart new file mode 100644 index 00000000..1d96babc --- /dev/null +++ b/lib/pages/vlm/widgets/device_selector.dart @@ -0,0 +1,70 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +// ignore_for_file: unused_import + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:inference/providers/preference_provider.dart'; +import 'package:provider/provider.dart'; +import 'package:collection/collection.dart'; + +class DeviceSelector extends StatefulWidget { + const DeviceSelector({super.key}); + + @override + State createState() => _DeviceSelectorState(); +} + +class _DeviceSelectorState extends State { + String? selectedDevice; + + @override + void initState() { + super.initState(); + selectedDevice = Provider.of(context, listen: false).device; + } + + @override + Widget build(BuildContext context) { + return Consumer(builder: (context, preferences, child) { + return Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + const Padding( + padding: EdgeInsets.only(bottom: 16), + child: Text("Device", + style: TextStyle( + fontSize: 16, + fontWeight: FontWeight.bold, + ), + ), + ), + Column( + crossAxisAlignment: CrossAxisAlignment.stretch, + children: [ + ComboBox( + value: selectedDevice, + items: PreferenceProvider.availableDevices.map>((e) { + return ComboBoxItem( + value: e.id, + child: Text(e.name), + ); + }).toList(), + onChanged: (v) { + setState(() { + selectedDevice = v; + if (v != null) { + preferences.device = v; + } + }); + }, + ), + ], + ), + ], + ); + } + ); + } +} diff --git a/lib/pages/vlm/widgets/horizontal_rule.dart b/lib/pages/vlm/widgets/horizontal_rule.dart new file mode 100644 index 00000000..e513822b --- /dev/null +++ b/lib/pages/vlm/widgets/horizontal_rule.dart @@ -0,0 +1,29 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:inference/theme_fluent.dart'; + +class HorizontalRule extends StatelessWidget { + const HorizontalRule({super.key}); + + @override + Widget build(BuildContext context) { + final theme = FluentTheme.of(context); + + return Padding( + padding: const EdgeInsets.symmetric(vertical: 20), + child: Container( + decoration: BoxDecoration( + border: Border( + bottom: BorderSide( + color: borderColor.of(theme), + width: 1, + ) + ) + ), + ), + ); + } +} diff --git a/lib/pages/vlm/widgets/image_grid.dart b/lib/pages/vlm/widgets/image_grid.dart new file mode 100644 index 00000000..03965695 --- /dev/null +++ b/lib/pages/vlm/widgets/image_grid.dart @@ -0,0 +1,128 @@ +import 'dart:io'; + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:inference/widgets/controls/drop_area.dart'; + +class ImageGrid extends StatefulWidget { + final List initialGalleryData; + final void Function(List) onFileListChange; + + const ImageGrid({ + super.key, + required this.initialGalleryData, + required this.onFileListChange, + }); + + @override + _ImageGridState createState() => _ImageGridState(); +} + +class _ImageGridState extends State { + late List galleryData; + Map hoverStates = {}; + + @override + void initState() { + super.initState(); + galleryData = List.from(widget.initialGalleryData); + } + + void onDrop(String path) { + if (!galleryData.contains(path)) { + setState(() { + galleryData.add(path); + }); + widget.onFileListChange(galleryData); + } + } + + void removeImage(int index) { + setState(() { + hoverStates.remove(index); + galleryData.removeAt(index); + + widget.onFileListChange(galleryData); + }); + } + + @override + Widget build(BuildContext context) { + var width = MediaQuery.of(context).size.width; + var height = MediaQuery.of(context).size.height; + + return DropArea( + showChild: galleryData.isNotEmpty, + onUpload: onDrop, + type: "image", + extensions: const ["jpg", "jpeg", "bmp", "png", "tif", "tiff"], + child: GridView.count( + primary: false, + padding: const EdgeInsets.all(20), + crossAxisSpacing: 10, + mainAxisSpacing: 10, + crossAxisCount: 7, + children: List.generate(galleryData.length, (index) { + String path = galleryData[index]; + bool isLocalFile = File(path).existsSync(); // Check if the file exists locally + + bool isHovered = hoverStates[index] ?? false; // Get hover state + return MouseRegion( + onEnter: (_) { + setState(() { + hoverStates[index] = true; + }); + }, + onHover: (_) { + setState(() { + hoverStates[index] = true; + }); + }, + onExit: (_) { + setState(() { + hoverStates[index] = false; + }); + }, + child: Stack( + children: [ + Container( + width: width * 0.3, + height: height * 0.3, + decoration: BoxDecoration( + borderRadius: BorderRadius.circular(10), + color: Colors.black, + image: DecorationImage( + image: isLocalFile + ? FileImage(File(path)) // Load local file + : NetworkImage(path) as ImageProvider, // Load from network + fit: BoxFit.cover, + ), + ), + ), + if (isHovered) + Positioned( + top: 5, + right: 5, + child: GestureDetector( + onTap: () => removeImage(index), + child: Container( + padding: const EdgeInsets.all(8), + decoration: BoxDecoration( + color: Colors.black.withAlpha(200), + shape: BoxShape.circle, + ), + child: const Icon( + FluentIcons.cancel, + size: 12, + color: Colors.white, + ), + ), + ), + ), + ], + ), + ); + }), + ), + ); + } +} \ No newline at end of file diff --git a/lib/pages/vlm/widgets/model_properties.dart b/lib/pages/vlm/widgets/model_properties.dart new file mode 100644 index 00000000..ffda0246 --- /dev/null +++ b/lib/pages/vlm/widgets/model_properties.dart @@ -0,0 +1,75 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:inference/pages/vlm/providers/vlm_inference_provider.dart'; +import 'package:inference/utils.dart'; +import 'package:inference/widgets/grid_container.dart'; +import 'package:inference/widgets/model_propery.dart'; +import 'package:intl/intl.dart'; +import 'package:provider/provider.dart'; + +class ModelProperties extends StatelessWidget { + const ModelProperties({super.key}); + + @override + Widget build(BuildContext context) { + return Consumer(builder: (context, inference, child) { + Locale locale = Localizations.localeOf(context); + final formatter = NumberFormat.percentPattern(locale.languageCode); + + return SizedBox( + width: 280, + child: GridContainer( + padding: const EdgeInsets.symmetric(vertical: 18, horizontal: 24), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + const Text("Model parameters", style: TextStyle( + fontSize: 20, + )), + Container( + padding: const EdgeInsets.only(top: 16), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + ModelProperty( + title: "Model name", + value: inference.project!.name, + ), + ModelProperty( + title: "Task", + value: inference.project!.taskName(), + ), + ModelProperty( + title: "Architecture", + value: inference.project!.architecture, + ), + ModelProperty( + title: "Size", + value: inference.project!.size?.readableFileSize() ?? "", + ), + Builder( + builder: (context) { + if (inference.project!.tasks.first.performance == null) { + return Container(); + } + return ModelProperty( + title: "Accuracy", + value: formatter.format(inference.project!.tasks.first.performance!.score) + ); + } + ), + ], + ), + ) + ], + ) + ), + ); + } + ); + } +} + diff --git a/lib/pages/vlm/widgets/toolbar_text_input.dart b/lib/pages/vlm/widgets/toolbar_text_input.dart new file mode 100644 index 00000000..3113bf9d --- /dev/null +++ b/lib/pages/vlm/widgets/toolbar_text_input.dart @@ -0,0 +1,119 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'dart:math'; + +import 'package:fluent_ui/fluent_ui.dart'; + +class ToolbarTextInput extends StatefulWidget { + final String labelText; + final String suffix; + final int marginLeft; + final int initialValue; + final bool? roundPowerOfTwo; + final void Function(int)? onChanged; + + const ToolbarTextInput({ + super.key, + required this.labelText, + required this.suffix, + required this.marginLeft, + required this.initialValue, + this.roundPowerOfTwo, + this.onChanged, + }); + + @override + State createState() => _ToolbarTextInputState(); +} + +class _ToolbarTextInputState extends State { + final TextEditingController _controller = TextEditingController(); + final FocusNode _focusNode = FocusNode(); + + @override + void initState() { + super.initState(); + _controller.text = widget.initialValue.toString(); // Set the initial text + if (widget.roundPowerOfTwo ?? false) { + _focusNode.addListener(_onFocusChange); // Listen for focus changes + } + } + + void _onFocusChange() { + if (!_focusNode.hasFocus) { + // When the TextBox loses focus, round and update + final inputValue = int.tryParse(_controller.text.replaceAll(RegExp(r'[^0-9]'), '')) ?? 0; + // final rounded = _nearestPowerOfTwo(inputValue); + // + // _controller.text = rounded.toString(); + widget.onChanged!(inputValue); + + } + } + + /// Calculate the nearest power of 2 for a given number + int _nearestPowerOfTwo(int value) { + if (value <= 0) return 1; // Smallest power of 2 is 1 + int lowerPower = pow(2, (log(value) / log(2)).floor()).toInt(); + int higherPower = pow(2, (log(value) / log(2)).ceil()).toInt(); + return (value - lowerPower < higherPower - value) ? lowerPower : higherPower; + } + + + void _onTextChanged(String value) { + // Keep only digits in the input + final newValue = value.replaceAll(RegExp(r'[^0-9]'), ''); + if (value != newValue) { + // Update the controller text and cursor position + _controller.text = newValue; + _controller.selection = TextSelection.collapsed(offset: newValue.length); + } + + if (widget.onChanged != null) { + if (newValue.isNotEmpty) { + // Parse the integer and call the callback + widget.onChanged!(int.parse(newValue)); + } else { + // Optionally handle empty input + widget.onChanged!(0); // You can choose to pass null or handle differently + } + } + } + + + @override + Widget build(BuildContext context) { + return Row( + crossAxisAlignment: CrossAxisAlignment.center, + children: [ + Padding( + padding: EdgeInsets.only(left: 10 + widget.marginLeft.toDouble(), right: 10), + child: Text( + widget.labelText, + style: const TextStyle( + fontSize: 14, + fontWeight: FontWeight.normal, + ), + ), + ), + SizedBox( + width: 85, + height: 30, + child: TextBox( + controller: _controller, + focusNode: _focusNode, + maxLines: 1, + keyboardType: TextInputType.number, // Ensure numeric keyboard + suffix: Padding( + padding: const EdgeInsets.only(right: 8.0), + child: Text(widget.suffix), + ), + onChanged: _onTextChanged, // Custom handler for integer validation + ), + ), + ], + ); + } +} diff --git a/lib/pages/vlm/widgets/vertical_rule.dart b/lib/pages/vlm/widgets/vertical_rule.dart new file mode 100644 index 00000000..2072adbb --- /dev/null +++ b/lib/pages/vlm/widgets/vertical_rule.dart @@ -0,0 +1,29 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:inference/theme_fluent.dart'; + +class VerticalRule extends StatelessWidget { + const VerticalRule({super.key}); + + @override + Widget build(BuildContext context) { + final theme = FluentTheme.of(context); + + return Padding( + padding: const EdgeInsets.symmetric(horizontal: 20), + child: Container( + decoration: BoxDecoration( + border: Border( + left: BorderSide( + color: borderColor.of(theme), + width: 1, + ) + ) + ), + ), + ); + } +} diff --git a/lib/pages/vlm/widgets/vlm_chat_area.dart b/lib/pages/vlm/widgets/vlm_chat_area.dart new file mode 100644 index 00000000..dd4f305d --- /dev/null +++ b/lib/pages/vlm/widgets/vlm_chat_area.dart @@ -0,0 +1,412 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:flutter_svg/svg.dart'; +import 'package:inference/interop/openvino_bindings.dart'; +import 'package:inference/pages/vlm/providers/vlm_inference_provider.dart'; +import 'package:inference/pages/vlm/widgets/horizontal_rule.dart'; +import 'package:inference/pages/vlm/widgets/vlm_metrics_grid.dart'; +import 'package:inference/theme_fluent.dart'; +import 'package:provider/provider.dart'; +import 'package:super_clipboard/super_clipboard.dart'; + +import 'image_grid.dart'; + +class VLMChatArea extends StatefulWidget { + const VLMChatArea({super.key}); + + @override + State createState() => VLMChatAreaState(); +} + +class VLMChatAreaState extends State { + final _controller = TextEditingController(); + final _scrollController = ScrollController(); + bool attachedToBottom = true; + + void handleFileListChange(List paths) { + final vlm = provider(); + if (!vlm.initialized) { + return; + } + vlm.setImagePaths(paths); + } + + void jumpToBottom({offset = 0}) { + if (_scrollController.hasClients) { + _scrollController + .jumpTo(_scrollController.position.maxScrollExtent + offset); + } + } + + void message(String message) async { + if (message.isEmpty) { + return; + } + final vlm = provider(); + if (!vlm.initialized) { + return; + } + + if (vlm.response != null) { + return; + } + _controller.text = ""; + jumpToBottom(offset: 110); //move to bottom including both + vlm.message(message); + } + + VLMInferenceProvider provider() => + Provider.of(context, listen: false); + + @override + void initState() { + super.initState(); + _scrollController.addListener(() { + setState(() { + attachedToBottom = _scrollController.position.pixels + 0.001 >= + _scrollController.position.maxScrollExtent; + }); + }); + } + + @override + void dispose() { + super.dispose(); + _controller.dispose(); + _scrollController.dispose(); + } + + @override + Widget build(BuildContext context) { + return Consumer( + builder: (context, inference, child) { + WidgetsBinding.instance.addPostFrameCallback((_) { + if (attachedToBottom) { + jumpToBottom(); + } + }); + + final theme = FluentTheme.of(context); + final textColor = theme.typography.body?.color ?? Colors.black; + + return Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Builder(builder: (context) { + if (!inference.initialized) { + return Expanded( + child: Center( + child: Column( + mainAxisAlignment: MainAxisAlignment.center, + children: [ + Image.asset('images/intel-loading.gif', width: 100), + const Text("Loading model...") + ], + )), + ); + } + return Expanded( + child: Container( + decoration: const BoxDecoration( + shape: BoxShape.rectangle, + borderRadius: BorderRadius.all(Radius.circular(8)), + ), + child: Column( + children: [ + SizedBox( + height: 220, + child: ImageGrid( + initialGalleryData: const [], + onFileListChange: handleFileListChange, + )), + const HorizontalRule(), + Expanded( + child: Builder(builder: (context) { + if (inference.messages.isEmpty) { + return Center( + child: Text( + "Type a message to ${inference.project?.name ?? "assistant"}")); + } + return Stack( + alignment: Alignment.topCenter, + children: [ + SingleChildScrollView( + controller: _scrollController, + child: Padding( + padding: const EdgeInsets.all(20), + child: Column( + + // mainAxisAlignment: MainAxisAlignment.start, + crossAxisAlignment: + CrossAxisAlignment.stretch, + children: inference.messages.map((message) { + switch (message.speaker) { + case Speaker.user: + return UserInputMessage(message); + case Speaker.assistant: + return GeneratedResponseMessage( + message, + inference.project! + .thumbnailImage(), + inference.project!.name); + } + }).toList()), + ), + ), + Positioned( + bottom: 10, + child: Builder(builder: (context) { + if (attachedToBottom) { + return Container(); + } + return Center( + child: Padding( + padding: const EdgeInsets.only(top: 2.0), + child: SizedBox( + width: 200, + height: 40, + // Adjusted height to match Fluent UI's button dimensions + child: FilledButton( + child: const Text("Jump to bottom"), + onPressed: () { + jumpToBottom(); + setState(() { + attachedToBottom = true; + }); + }, + ), + ), + ), + ); + }), + ), + ], + ); + }), + ), + + // SizedBox( + // height: 30, + // child: Builder( + // builder: (context) { + // if (inference.interimResponse == null){ + // return Container(); + // } + // return Center( + // child: OutlinedButton.icon( + // onPressed: () => inference.forceStop(), + // icon: const Icon(Icons.stop), + // label: const Text("Stop responding") + // ), + // ); + // } + // ), + // ), + Padding( + padding: const EdgeInsets.only( + left: 45, right: 45, top: 10, bottom: 25), + child: SizedBox( + height: 40, + child: Row( + crossAxisAlignment: CrossAxisAlignment.center, + children: [ + Padding( + padding: const EdgeInsets.only(right: 8), + child: IconButton( + icon: SvgPicture.asset( + "images/clear.svg", + width: 20, + colorFilter: ColorFilter.mode(textColor, BlendMode.srcIn), + ), + onPressed: () => inference.reset(), + ), + ), + Expanded( + child: TextBox( + maxLines: null, + keyboardType: TextInputType.text, + placeholder: "Ask me anything...", + controller: _controller, + onSubmitted: message, + style: const TextStyle( + fontSize: 14, + ), + suffix: IconButton( + icon: Icon( + FluentIcons.send, + color: + (inference.interimResponse == null + ? textColor + : textColor.withOpacity(0.2)), + ), + onPressed: () => + inference.interimResponse != null ? null : + message(_controller.text), + ), + ), + ) + ], + ), + ), + ), + ], + ), + ), + ); + }), + ], + ); + }); + } +} + +class UserInputMessage extends StatelessWidget { + final Message message; + + const UserInputMessage(this.message, {super.key}); + + @override + Widget build(BuildContext context) { + return Padding( + padding: const EdgeInsets.only(bottom: 20), + child: Column( + crossAxisAlignment: CrossAxisAlignment.end, + children: [ + Container( + constraints: const BoxConstraints(maxWidth: 500), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Padding( + padding: const EdgeInsets.only(right: 30.0), + child: MessageWidget( + message: message.message, + innerPadding: 8, + isSender: true), + ), + ])) + ], + ), + ); + } +} + +class GeneratedResponseMessage extends StatelessWidget { + final Message message; + final ImageProvider icon; + final String name; + + const GeneratedResponseMessage(this.message, this.icon, this.name, {super.key}); + + @override + Widget build(BuildContext context) { + return Padding( + padding: const EdgeInsets.only(bottom: 20), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Container( + constraints: const BoxConstraints(maxWidth: 500), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Padding( + padding: const EdgeInsets.only(right: 30.0), + child: MessageWidget( + message: message.message, + innerPadding: 8, + isSender: false), + ), + ])) + ], + ), + ); + } +} + + +void showMetricsDialog(BuildContext context, VLMMetrics metrics) async { + await showDialog( + context: context, + barrierDismissible: true, + builder: (context) => ContentDialog( + constraints: const BoxConstraints(maxWidth: double.infinity), + content: VLMMetricsGrid(metrics: metrics), + ), + ); +} + +class RoundedPicture extends StatelessWidget { + final String name; + final ImageProvider icon; // Icon widget provided + + const RoundedPicture({super.key, required this.name, required this.icon}); + + @override + Widget build(BuildContext context) { + return ClipOval( + child: Container( + width: 40, + height: 40, + decoration: BoxDecoration( + image: DecorationImage( + image: icon, // Adjust this to fit your `name` field + fit: BoxFit.cover, + ), + ), + ), + ); + } +} + +class MessageWidget extends StatelessWidget { + final String? message; + final VLMMetrics? metrics; // If not set, no copy-paste options. + final double innerPadding; + final bool isSender; + + const MessageWidget( + {super.key, + this.message, + this.metrics, + required this.innerPadding, + required this.isSender}); + + @override + Widget build(BuildContext context) { + final theme = FluentTheme.of(context); + final textColor = theme.typography.body?.color ?? Colors.black; + + return Container( + decoration: BoxDecoration( + color: + isSender ? userMessageColor.of(theme) : modelMessageColor.of(theme), + borderRadius: const BorderRadius.only( + topLeft: Radius.circular(4.0), + topRight: Radius.circular(4.0), + bottomLeft: Radius.circular(4.0), + bottomRight: Radius.circular(4.0), + ), + ), + padding: EdgeInsets.all(innerPadding), + child: Column(children: [ + message != null + ? SelectableText( + message!, + style: TextStyle( + color: textColor, + fontSize: 14, + fontWeight: FontWeight.w400, + ), + ) + : const SizedBox.shrink(), + ]), + ); + } +} + + + diff --git a/lib/pages/vlm/widgets/vlm_metrics_grid.dart b/lib/pages/vlm/widgets/vlm_metrics_grid.dart new file mode 100644 index 00000000..d406cb4a --- /dev/null +++ b/lib/pages/vlm/widgets/vlm_metrics_grid.dart @@ -0,0 +1,39 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:inference/interop/openvino_bindings.dart'; +import 'package:inference/widgets/metrics_card.dart'; +import 'package:intl/intl.dart'; + + +class VLMMetricsGrid extends StatelessWidget { + final VLMMetrics metrics; + + const VLMMetricsGrid({super.key, required this.metrics}); + + @override + Widget build(BuildContext context) { + Locale locale = Localizations.localeOf(context); + final nf = NumberFormat.decimalPatternDigits( + locale: locale.languageCode, decimalDigits: 0); + + return Row( + mainAxisAlignment: MainAxisAlignment.spaceBetween, + mainAxisSize: MainAxisSize.min, // Wrap content + children: [ + MetricsCard( + header: "Time to load model", + value: nf.format(metrics.load_time), + unit: "ms", + ), + MetricsCard( + header: "Time to generate image", + value: nf.format(metrics.generate_time), + unit: "ms", + ) + ], + ); + } +} diff --git a/lib/project.dart b/lib/project.dart index a0415317..edc303b1 100644 --- a/lib/project.dart +++ b/lib/project.dart @@ -113,7 +113,7 @@ class Task { } } -enum ProjectType { image, text, textToImage, speech } +enum ProjectType { image, text, textToImage, speech, vlm } ProjectType parseProjectType(String name) { if (name == "image") { return ProjectType.image; @@ -124,6 +124,9 @@ ProjectType parseProjectType(String name) { if (name == "textToImage" || name == "text-to-image"){ return ProjectType.textToImage; } + if (name == "vlm"){ + return ProjectType.vlm; + } if (name == "speech") { return ProjectType.speech; } @@ -137,6 +140,8 @@ String projectTypeToString(ProjectType type) { return "text"; case ProjectType.textToImage: return "textToImage"; + case ProjectType.vlm: + return "vlm"; case ProjectType.image: return "image"; case ProjectType.speech: diff --git a/lib/widgets/controls/drop_area.dart b/lib/widgets/controls/drop_area.dart index 16e895fd..3afc2bd9 100644 --- a/lib/widgets/controls/drop_area.dart +++ b/lib/widgets/controls/drop_area.dart @@ -29,6 +29,7 @@ class DropArea extends StatefulWidget { class _DropAreaState extends State { bool _showReleaseMessage = false; + void handleDrop(DropDoneDetails details) { if (details.files.isNotEmpty) { widget.onUpload(details.files[0].path); @@ -56,48 +57,62 @@ class _DropAreaState extends State { final theme = FluentTheme.of(context); return DropTarget( - onDragDone: (details) => handleDrop(details), - onDragExited: (val) => hideReleaseMessage(), - onDragEntered: (val) => showReleaseMessage(), - - child: Builder( - builder: (context) { - if (widget.showChild && !_showReleaseMessage) { - return widget.child ?? Container(); - } - - final String text = _showReleaseMessage - ? "Release to drop media" - : "Drag and drop ${widget.type} here for testing"; - - - return Center( - child: SizedBox( - height: 310, - child: Column( - crossAxisAlignment: CrossAxisAlignment.center, - mainAxisAlignment: MainAxisAlignment.spaceBetween, + onDragDone: handleDrop, + onDragExited: (_) => hideReleaseMessage(), + onDragEntered: (_) => showReleaseMessage(), + + child: + Padding( + padding: const EdgeInsets.all(20), + child: Center( + child: SizedBox( + height: 310, + child: Builder( + builder: (context) { + if (widget.showChild && !_showReleaseMessage) { + // If we have a child and aren't showing the drop message, display it. + return widget.child ?? const SizedBox.shrink(); + } + + final theme = FluentTheme.of(context); + final String text = _showReleaseMessage + ? "Release to drop media" + : "Drag and drop ${widget.type} here for testing"; + + return Column( + mainAxisAlignment: MainAxisAlignment.spaceAround, + spacing: 10, children: [ - Text(text, style: const TextStyle(fontSize: 26, fontWeight: FontWeight.w500)), - (theme.brightness.isDark - ? SvgPicture.asset('images/drop.svg') - : SvgPicture.asset('images/drop_light.svg') + // Top text + Text( + text, + style: const TextStyle(fontSize: 26, fontWeight: FontWeight.w500), + textAlign: TextAlign.center, + ), + + // SVG scales to available space in the column + Expanded( + // FittedBox automatically scales its child to fit the parent's constraints + child: FittedBox( + fit: BoxFit.contain, + child: theme.brightness.isDark + ? SvgPicture.asset('images/drop.svg') + : SvgPicture.asset('images/drop_light.svg'), + ), ), - Builder( - builder: (context) { - if (widget.extensions == null) { - return Container(); - } - return Text(widget.extensions!.join(", ")); - } - ) + + // Optional file extension text at the bottom + if (widget.extensions != null) + Text(widget.extensions!.join(", ")) + else + const SizedBox.shrink(), ], - ), - ), - ); - } - ) + ); + }, + ), + ), + ), + ) ); - } -} +} \ No newline at end of file diff --git a/openvino_bindings/src/BUILD b/openvino_bindings/src/BUILD index 756d22dc..a39c3cfa 100644 --- a/openvino_bindings/src/BUILD +++ b/openvino_bindings/src/BUILD @@ -12,6 +12,7 @@ cc_library( "//src/audio:speech_to_text", "//src/mediapipe:graph_runner", "//src/tti:tti_inference", + "//src/vlm:vlm_inference", ], ) @@ -23,6 +24,7 @@ cc_library( ":bindings_deps", "//src/utils:metrics", "//src/utils:tti_metrics", + "//src/utils:vlm_metrics", ], copts = ["-fPIC"], alwayslink=1, diff --git a/openvino_bindings/src/bindings.cc b/openvino_bindings/src/bindings.cc index 28d9d0de..65a0c9ab 100644 --- a/openvino_bindings/src/bindings.cc +++ b/openvino_bindings/src/bindings.cc @@ -17,6 +17,7 @@ #include "src/mediapipe/serialization/serialization_calculators.h" #include "src/llm/llm_inference.h" #include "src/tti/tti_inference.h" +#include "src/vlm/vlm_inference.h" #include "src/utils/errors.h" #include "src/utils/utils.h" #include "src/utils/status.h" @@ -297,6 +298,85 @@ Status* ttiInferenceClose(CTTIInference instance) { return new Status{OkStatus}; } +StatusOrVLMInference* vlmInferenceOpen(const char* model_path, const char* device) { + try { + auto instance = new VLMInference(model_path, device); + return new StatusOrVLMInference{OkStatus, "", instance}; + } catch (...) { + auto except = handle_exceptions(); + printf(except->message); + return new StatusOrVLMInference{except->status, except->message}; + } +} + +Status* vlmInferenceSetListener(CVLMInference instance, VLMInferenceCallbackFunction callback) { + try { + auto lambda_callback = [callback](const std::string& word) { + callback(new StatusOrString{OkStatus, "", strdup(word.c_str())}); + }; + reinterpret_cast(instance)->set_streamer(lambda_callback); + return new Status{OkStatus, ""}; + } catch (...) { + return handle_exceptions(); + } +} + +StatusOrVLMModelResponse* vlmInferencePrompt(CVLMInference instance, const char* message, int max_new_tokens) { + try { + auto inference = reinterpret_cast(instance); + auto result = inference->prompt(message, max_new_tokens); + auto text = result.string; + auto metrics = result.metrics; + return new StatusOrVLMModelResponse{OkStatus, {}, metrics, text}; + } catch (...) { + auto except = handle_exceptions(); + return new StatusOrVLMModelResponse{except->status, except->message, {}, {}}; + } +} + +Status* vlmInferenceSetImagePaths(CVLMInference instance, const char** paths, int length) { + try { + auto inference = reinterpret_cast(instance); + + std::vector stringPaths; + stringPaths.reserve(length); + for (int i = 0; i < length; ++i) { + stringPaths.emplace_back(paths[i]); + } + + inference->setImagePaths(stringPaths); + return new Status{OkStatus}; + } catch (...) { + return new Status{ErrorStatus}; + } +} + +StatusOrBool* vlmInferenceHasModelIndex(CVLMInference instance) { + try { + bool has_chat_template = reinterpret_cast(instance)->has_model_index(); + return new StatusOrBool{OkStatus, "", has_chat_template}; + } catch (...) { + auto except = handle_exceptions(); + return new StatusOrBool{except->status, except->message}; + } +} + +Status* vlmInferenceForceStop(CVLMInference instance) { + try { + reinterpret_cast(instance)->force_stop(); + return new Status{OkStatus, ""}; + } catch (...) { + return handle_exceptions(); + } +} + +Status* vlmInferenceClose(CVLMInference instance) { + auto inference = reinterpret_cast(instance); + inference->force_stop(); + delete inference; + return new Status{OkStatus}; +} + StatusOrGraphRunner* graphRunnerOpen(const char* graph) { try { @@ -427,12 +507,16 @@ Status* handle_exceptions() { } catch(ov::Exception e) { std::string message = "OV Exception: \n"; message += e.what(); + std::cout << message << std::endl; return new Status{OpenVINOError, strdup(message.c_str())}; } catch (api_error e) { + std::cout << e.what() << std::endl; return new Status{e.status, strdup(e.additional_info.c_str())}; } catch(const std::exception& ex) { + std::cout << ex.what() << std::endl; return new Status{ErrorStatus, ex.what()}; } catch (...) { + std::cout << "Unknown exception" << std::endl; return new Status{ErrorStatus, "Unknown exception"}; } } diff --git a/openvino_bindings/src/bindings.h b/openvino_bindings/src/bindings.h index 6f177e53..a284b574 100644 --- a/openvino_bindings/src/bindings.h +++ b/openvino_bindings/src/bindings.h @@ -22,12 +22,14 @@ #include "src/utils/status.h" #include "src/utils/metrics.h" #include "utils/tti_metrics.h" +#include "utils/vlm_metrics.h" typedef void* CImageInference; typedef void* CGraphRunner; typedef void* CSpeechToText; typedef void* CLLMInference; typedef void* CTTIInference; +typedef void* CVLMInference; typedef struct { const char* id; @@ -93,6 +95,12 @@ typedef struct { CLLMInference value; } StatusOrTTIInference; +typedef struct { + enum StatusEnum status; + const char* message; + CLLMInference value; +} StatusOrVLMInference; + typedef struct { enum StatusEnum status; const char* message; @@ -116,6 +124,13 @@ typedef struct { const char* value; } StatusOrTTIModelResponse; +typedef struct { + enum StatusEnum status; + const char* message; + VLMMetrics metrics; + const char* value; +} StatusOrVLMModelResponse; + typedef struct { enum StatusEnum status; const char* message; @@ -125,6 +140,7 @@ typedef struct { typedef void (*ImageInferenceCallbackFunction)(StatusOrString*); typedef void (*LLMInferenceCallbackFunction)(StatusOrString*); +typedef void (*VLMInferenceCallbackFunction)(StatusOrString*); EXPORT void freeStatus(Status *status); EXPORT void freeStatusOrString(StatusOrString *status); @@ -159,6 +175,15 @@ EXPORT StatusOrTTIModelResponse* ttiInferencePrompt(CTTIInference instance, cons EXPORT StatusOrBool* ttiInferenceHasModelIndex(CTTIInference instance); EXPORT Status* ttiInferenceClose(CLLMInference instance); + +EXPORT StatusOrVLMInference* vlmInferenceOpen(const char* model_path, const char* device); +EXPORT Status* vlmInferenceSetListener(CVLMInference instance, VLMInferenceCallbackFunction callback); +EXPORT StatusOrVLMModelResponse* vlmInferencePrompt(CVLMInference instance, const char* message, int max_new_tokens); +EXPORT Status* vlmInferenceSetImagePaths(CVLMInference instance, const char** paths, int length); +EXPORT StatusOrBool* vlmInferenceHasModelIndex(CVLMInference instance); +EXPORT Status* vlmInferenceStop(CVLMInference instance); +EXPORT Status* vlmInferenceClose(CVLMInference instance); + EXPORT StatusOrGraphRunner* graphRunnerOpen(const char* graph); EXPORT Status* graphRunnerQueueImage(CGraphRunner instance, const char* name, int timestamp, unsigned char* image_data, const size_t data_length); EXPORT Status* graphRunnerQueueSerializationOutput(CGraphRunner instance, const char* name, int timestamp, bool json, bool csv, bool overlay); diff --git a/openvino_bindings/src/utils/BUILD b/openvino_bindings/src/utils/BUILD index 8936bd5a..d12fac22 100644 --- a/openvino_bindings/src/utils/BUILD +++ b/openvino_bindings/src/utils/BUILD @@ -31,6 +31,13 @@ cc_library( ], ) +cc_library( + name = "vlm_metrics", + hdrs = [ + "vlm_metrics.h", + ], +) + cc_library( name = "utils", srcs = [ @@ -42,6 +49,7 @@ cc_library( deps = [ ":metrics", ":tti_metrics", + ":vlm_metrics", "//third_party:openvino", ], ) diff --git a/openvino_bindings/src/utils/vlm_metrics.h b/openvino_bindings/src/utils/vlm_metrics.h new file mode 100644 index 00000000..2d7323f1 --- /dev/null +++ b/openvino_bindings/src/utils/vlm_metrics.h @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2024 Intel Corporation + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef VLM_METRICS_H +#define VLM_METRICS_H + +typedef struct { + float load_time; + float generate_time; +} VLMMetrics; + +typedef struct { + const char* string; + VLMMetrics metrics; +} VLMStringWithMetrics; + +#endif //VLM_METRICS_H diff --git a/openvino_bindings/src/vlm/BUILD b/openvino_bindings/src/vlm/BUILD new file mode 100644 index 00000000..f31d2267 --- /dev/null +++ b/openvino_bindings/src/vlm/BUILD @@ -0,0 +1,35 @@ +cc_library( + name = "vlm_inference", + srcs = [ + "load_image.cpp", + "vlm_inference.cc", + ], + hdrs = [ + "load_image.hpp", + "vlm_inference.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//src/image:serialization", + "//src/utils", + "//src/utils:errors", + "//src/utils:vlm_metrics", + "//third_party:opencv", + "//third_party:openvino", + "@nlohmann_json//:json", + ], +) + +cc_test( + name = "vlm_inference_test", + srcs = [ + "vlm_inference_test.cc", + ], + data = [ + "//data:models", + ], + deps = [ + ":vlm_inference", + "@gtest//:gtest_main", + ], +) diff --git a/openvino_bindings/src/vlm/load_image.cpp b/openvino_bindings/src/vlm/load_image.cpp new file mode 100644 index 00000000..b466614e --- /dev/null +++ b/openvino_bindings/src/vlm/load_image.cpp @@ -0,0 +1,82 @@ + +// Copyright (C) 2023-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#define STB_IMAGE_IMPLEMENTATION +#include "load_image.hpp" +#include + +namespace fs = std::filesystem; + +#include +#include + +std::vector utils::load_images(const std::vector& input_paths) { + std::vector images; + images.reserve(input_paths.size()); + + for (const std::string& dir_entry : input_paths) { + std::filesystem::path image_path(dir_entry); + if (!exists(image_path)) { + std::cerr << "Warning: File does not exist - " << dir_entry << std::endl; + continue; // Skip this file + } + images.push_back(load_image(image_path)); + } + return images; +} + +ov::Tensor utils::load_image(const std::filesystem::path& image_path) { + constexpr int desired_channels = 3; + + // Load the image using OpenCV + std::ifstream file(image_path, std::ios::binary); + if (!file) { + throw std::runtime_error{"Cannot access file."}; + } + + std::vector buffer((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + cv::Mat cv_image = cv::imdecode(buffer, cv::IMREAD_COLOR); + + if (cv_image.empty()) { + throw std::runtime_error{"Failed to load the image."}; + } + + // Ensure the image is converted to the desired number of channels + if (cv_image.channels() != desired_channels) { + throw std::runtime_error{"The loaded image does not have the desired number of channels."}; + } + + int width = cv_image.cols; + int height = cv_image.rows; + + struct SharedImageAllocator { + unsigned char* image; + int channels, height, width; + + void* allocate(size_t bytes, size_t) const { + if (image && static_cast(channels * height * width) == bytes) { + return image; + } + throw std::runtime_error{"Unexpected number of bytes was requested to allocate."}; + } + + void deallocate(void*, size_t bytes, size_t) { + if (static_cast(channels * height * width) != bytes) { + throw std::runtime_error{"Unexpected number of bytes was requested to deallocate."}; + } + image = nullptr; // Prevent dangling pointer + } + + bool is_equal(const SharedImageAllocator& other) const noexcept { + return this == &other; + } + }; + + // Wrap OpenCV image data into the custom allocator + return ov::Tensor( + ov::element::u8, + ov::Shape{1, static_cast(height), static_cast(width), static_cast(desired_channels)}, + SharedImageAllocator{cv_image.data, desired_channels, height, width} + ); +} diff --git a/openvino_bindings/src/vlm/load_image.hpp b/openvino_bindings/src/vlm/load_image.hpp new file mode 100644 index 00000000..f43b3373 --- /dev/null +++ b/openvino_bindings/src/vlm/load_image.hpp @@ -0,0 +1,15 @@ + +// Copyright (C) 2023-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +namespace utils +{ + ov::Tensor load_image(const std::filesystem::path& image_path); + std::vector load_images(const std::vector & input_paths); +} \ No newline at end of file diff --git a/openvino_bindings/src/vlm/vlm_inference.cc b/openvino_bindings/src/vlm/vlm_inference.cc new file mode 100644 index 00000000..385c6aee --- /dev/null +++ b/openvino_bindings/src/vlm/vlm_inference.cc @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2024 Intel Corporation + * + * SPDX-License-Identifier: Apache-2.0 + */ + + +#include +#include + +#include "vlm_inference.h" +#include "load_image.hpp" +#include + +#include "src/image/json_serialization.h" + +bool print_subword(std::string&& subword) +{ + return !(std::cout << subword << std::flush); +} + +std::string join_texts(const std::vector& texts) +{ + std::ostringstream oss; + for (size_t i = 0; i < texts.size(); ++i) + { + oss << texts[i]; + if (i < texts.size() - 1) + { + oss << " "; // Add a space between words but not after the last one + } + } + return oss.str(); +} + +void VLMInference::set_streamer(const std::function callback) { + streamer = [callback, this](std::string word) { + if (_stop) { + _done = true; + streamer_lock.unlock(); + cond.notify_all(); + return true; + } + callback(word.c_str()); + return false; + }; +} + + +VLMStringWithMetrics VLMInference::prompt(std::string message, int max_new_tokens) +{ + _stop = false; + + std::lock_guard guard(pipe_mutex); + + const auto t1 = std::chrono::steady_clock::now(); + + ov::genai::GenerationConfig generation_config; + generation_config.max_new_tokens = 100; + + if (streamer) + { + streamer_lock.lock(); + } + _done = false; + + const ov::genai::DecodedResults results = update_images && !imagePaths.empty() + ? ov_pipe->generate(message, + ov::genai::images( + utils::load_images(imagePaths) + ), + ov::genai::generation_config(generation_config), + ov::genai::streamer(streamer)) + : ov_pipe->generate(message, + ov::genai::generation_config(generation_config), + ov::genai::streamer(streamer)); + + + update_images = false; // Do not reload images on next message, except when setImagePaths is called before. + + if (streamer) + { + streamer_lock.unlock(); + cond.notify_all(); + } + + _done = true; + + + auto texts = results.texts; + + // Make Metrics + const auto t2 = std::chrono::steady_clock::now(); + + const auto generate_time = std::chrono::duration_cast(t2 - t1).count(); + + const auto load_time_f = static_cast(load_time); + const auto generate_time_f = static_cast(generate_time); + const auto metrics = VLMMetrics{ + !std::isnan(load_time_f) ? load_time_f : 0.0f, + !std::isnan(generate_time_f) ? generate_time_f : 0.0f, + }; + + // Return + auto res = VLMStringWithMetrics{strdup(join_texts(texts).c_str()), metrics}; + return res; +} + +void VLMInference::setImagePaths(std::vector paths) +{ + imagePaths = paths; + update_images = true; +} + +void VLMInference::force_stop() +{ + // This lock comes free after generation is complete + // During generation, it's not safe to dispose class as OV may still write to memory + std::lock_guard guard(pipe_mutex); + ov_pipe->finish_chat(); + + // Stop streamer + _stop = true; + std::unique_lock lock(streamer_lock); + while(!_done) { + cond.wait(lock); + } + +} + + +bool VLMInference::has_model_index() const +{ + std::ifstream ifs(model_path + "/model_index.json"); + auto r = nlohmann::json::parse(ifs); + return r.find("chat_template") != r.end(); +} diff --git a/openvino_bindings/src/vlm/vlm_inference.h b/openvino_bindings/src/vlm/vlm_inference.h new file mode 100644 index 00000000..da5f06b2 --- /dev/null +++ b/openvino_bindings/src/vlm/vlm_inference.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2024 Intel Corporation + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef VLM_INFERENCE_H_ +#define VLM_INFERENCE_H_ + +#include + +#include "src/utils/vlm_metrics.h" +#include "openvino/genai/visual_language/pipeline.hpp" + +class VLMInference +{ + long load_time = 9999; + std::unique_ptr ov_pipe; + std::mutex pipe_mutex; + std::function streamer; + +public: + VLMInference(std::string model_path, std::string device): + // Use a lambda to initialize the 'pipe' and measure the construction time in one step + ov_pipe(nullptr), model_path(model_path) + { + auto start_time = std::chrono::steady_clock::now(); + + ov::AnyMap enable_compile_cache; + if (device == "GPU") { + // Cache compiled models on disk for GPU to save time on the + // next run. It's not beneficial for CPU. + enable_compile_cache.insert({ov::cache_dir("vlm_cache")}); + } + + ov_pipe = std::make_unique(model_path, device, enable_compile_cache); + ov_pipe->start_chat(); + + auto end_time = std::chrono::steady_clock::now(); + + std::filesystem::path bgr_path = std::filesystem::path(model_path) / "channel_info.json"; + this->flip_bgr = std::filesystem::exists(bgr_path); + + // Calculate load time + this->load_time = std::chrono::duration_cast(end_time - start_time).count(); + } + + + VLMStringWithMetrics prompt(std::string message, int max_new_tokens); + void set_streamer(std::function callback); + void setImagePaths(std::vector paths); + void force_stop(); + bool has_model_index() const; + +private: + std::string model_path; + + bool _stop = false; + std::mutex streamer_lock; + std::condition_variable cond; + bool _done = true; + + bool update_images = true; + std::vector imagePaths; + bool flip_bgr = false; + +}; + +#endif // VLM_INFERENCE_H_ diff --git a/openvino_bindings/src/vlm/vlm_inference_test.cc b/openvino_bindings/src/vlm/vlm_inference_test.cc new file mode 100644 index 00000000..a926208c --- /dev/null +++ b/openvino_bindings/src/vlm/vlm_inference_test.cc @@ -0,0 +1,16 @@ +/* + * Copyright (c) 2024 Intel Corporation + * + * SPDX-License-Identifier: Apache-2.0 + */ + + +#include "gtest/gtest.h" +#include "tti_inference.h" + +TEST(VLMInference, Sanity) { + std::string model_path = "data/TinyLlama-1.1B-Chat-v1.0-int4-ov"; + LLMInference inference(model_path, "CPU"); + std::string output = inference.prompt("What is the color of the sun?", 1.0f, 1.0f); + EXPECT_STREQ(output.c_str(), "The color of the sun is a beautiful and awe-inspiring yellow-amber color. It is a natural, radiant, and beautiful color that is associated with warmth, light, and lightning. The sun is often depicted as a radiant, yellow-amber ball of light that shines down on the earth, illuminating the world and inspiring wonder and awe in all who see it."); +} From 3a5cffa71a1cfd02f094b784df3885a5cdd791c8 Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Tue, 14 Jan 2025 14:29:42 +0100 Subject: [PATCH 02/15] Bug fix --- lib/widgets/controls/drop_area.dart | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/widgets/controls/drop_area.dart b/lib/widgets/controls/drop_area.dart index 3afc2bd9..51f48cdb 100644 --- a/lib/widgets/controls/drop_area.dart +++ b/lib/widgets/controls/drop_area.dart @@ -81,7 +81,7 @@ class _DropAreaState extends State { return Column( mainAxisAlignment: MainAxisAlignment.spaceAround, - spacing: 10, + // spacing: 10, children: [ // Top text Text( From 63188cb96f810aea9099e1a9f6dda93938c9daf4 Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Tue, 21 Jan 2025 14:38:31 +0100 Subject: [PATCH 03/15] Cleanup and synchronize UI for TTI --- .../text_to_image/live_inference_pane.dart | 96 ---- .../performance_metrics_pane.dart | 36 +- lib/pages/text_to_image/playground.dart | 301 ++++++++++ .../text_to_image/text_to_image_page.dart | 7 +- .../widgets/assistant_message.dart | 226 ++++++++ .../widgets/device_selector.dart | 70 --- .../widgets/horizontal_rule.dart | 29 - .../widgets/model_properties.dart | 2 +- .../text_to_image/widgets/tti_chat_area.dart | 543 ------------------ .../widgets/tti_metrics_grid.dart | 39 -- .../text_to_image/widgets/user_message.dart | 39 ++ .../text_to_image_inference_provider.dart | 19 +- 12 files changed, 614 insertions(+), 793 deletions(-) delete mode 100644 lib/pages/text_to_image/live_inference_pane.dart create mode 100644 lib/pages/text_to_image/playground.dart create mode 100644 lib/pages/text_to_image/widgets/assistant_message.dart delete mode 100644 lib/pages/text_to_image/widgets/device_selector.dart delete mode 100644 lib/pages/text_to_image/widgets/horizontal_rule.dart delete mode 100644 lib/pages/text_to_image/widgets/tti_chat_area.dart delete mode 100644 lib/pages/text_to_image/widgets/tti_metrics_grid.dart create mode 100644 lib/pages/text_to_image/widgets/user_message.dart rename lib/{pages/text_to_image => }/providers/text_to_image_inference_provider.dart (89%) diff --git a/lib/pages/text_to_image/live_inference_pane.dart b/lib/pages/text_to_image/live_inference_pane.dart deleted file mode 100644 index 07b232d2..00000000 --- a/lib/pages/text_to_image/live_inference_pane.dart +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) 2024 Intel Corporation -// -// SPDX-License-Identifier: Apache-2.0 - -import 'package:fluent_ui/fluent_ui.dart'; -import 'package:inference/widgets/grid_container.dart'; -import 'package:inference/pages/text_to_image/providers/text_to_image_inference_provider.dart'; -import 'package:inference/pages/text_to_image/widgets/model_properties.dart'; -import 'package:inference/pages/text_to_image/widgets/toolbar_text_input.dart'; -import 'package:inference/pages/text_to_image/widgets/tti_chat_area.dart'; -import 'package:inference/pages/text_to_image/widgets/vertical_rule.dart'; -import 'package:inference/theme_fluent.dart'; -import 'package:provider/provider.dart'; -import 'package:inference/widgets/device_selector.dart'; - -class TTILiveInferencePane extends StatefulWidget { - const TTILiveInferencePane({super.key}); - - @override - State createState() => _PlaygroundState(); -} - -class _PlaygroundState extends State { - TextToImageInferenceProvider provider() => - Provider.of(context, listen: false); - - @override - Widget build(BuildContext context) { - final theme = FluentTheme.of(context); - - return Consumer( - builder: (context, inference, child) { - - return Row( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Expanded( - child: Column( - children: [ - SizedBox( - height: 64, - child: GridContainer( - child: Padding( - padding: const EdgeInsets.symmetric(horizontal: 16), - child: Row( - children: [ - const DeviceSelector(), - const Padding( - padding: EdgeInsets.symmetric(vertical: 16), - child: VerticalRule()), - ToolbarTextInput( - marginLeft: 0, - labelText: "Width", - suffix: "px", - initialValue: provider().width, - roundPowerOfTwo: true, - onChanged: (value) { provider().width = value; }), - ToolbarTextInput( - marginLeft: 20, - labelText: "Height", - suffix: "px", - initialValue: provider().height, - roundPowerOfTwo: true, - onChanged: (value) { provider().height = value; }), - ToolbarTextInput( - marginLeft: 20, - labelText: "Rounds", - suffix: "", - initialValue: provider().rounds, - onChanged: (value) { provider().rounds = value; }), - ], - ), - ), - ), - ), - Expanded( - child: GridContainer( - color: backgroundColor.of(theme), - child: Builder(builder: (context) { - return const TTIChatArea(); - }), - ), - ) - ], - ), - ), - const ModelProperties(), - ], - ); - }); - } -} - - - - diff --git a/lib/pages/text_to_image/performance_metrics_pane.dart b/lib/pages/text_to_image/performance_metrics_pane.dart index b517b697..75620014 100644 --- a/lib/pages/text_to_image/performance_metrics_pane.dart +++ b/lib/pages/text_to_image/performance_metrics_pane.dart @@ -3,8 +3,10 @@ // SPDX-License-Identifier: Apache-2.0 import 'package:fluent_ui/fluent_ui.dart'; -import 'package:inference/pages/text_to_image/providers/text_to_image_inference_provider.dart'; -import 'package:inference/pages/text_to_image/widgets/tti_metrics_grid.dart'; +import 'package:inference/interop/openvino_bindings.dart'; +import 'package:inference/providers/text_to_image_inference_provider.dart'; +import 'package:inference/widgets/metrics_card.dart'; +import 'package:intl/intl.dart'; import 'package:provider/provider.dart'; class TTIPerformanceMetricsPane extends StatefulWidget { @@ -65,3 +67,33 @@ class _TTIPerformanceMetricsPaneState extends State { } + +class TTIMetricsGrid extends StatelessWidget { + final TTIMetrics metrics; + + const TTIMetricsGrid({super.key, required this.metrics}); + + @override + Widget build(BuildContext context) { + Locale locale = Localizations.localeOf(context); + final nf = NumberFormat.decimalPatternDigits( + locale: locale.languageCode, decimalDigits: 0); + + return Row( + mainAxisAlignment: MainAxisAlignment.spaceBetween, + mainAxisSize: MainAxisSize.min, // Wrap content + children: [ + MetricsCard( + header: "Time to load model", + value: nf.format(metrics.load_time), + unit: "ms", + ), + MetricsCard( + header: "Time to generate image", + value: nf.format(metrics.generate_time), + unit: "ms", + ) + ], + ); + } +} diff --git a/lib/pages/text_to_image/playground.dart b/lib/pages/text_to_image/playground.dart new file mode 100644 index 00000000..87ed78a4 --- /dev/null +++ b/lib/pages/text_to_image/playground.dart @@ -0,0 +1,301 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'dart:io'; + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:flutter/services.dart'; +import 'package:inference/pages/text_to_image/widgets/assistant_message.dart'; +import 'package:inference/pages/text_to_image/widgets/toolbar_text_input.dart'; +import 'package:inference/pages/text_to_image/widgets/user_message.dart'; +import 'package:inference/pages/text_to_image/widgets/model_properties.dart'; +import 'package:inference/widgets/grid_container.dart'; +import 'package:inference/project.dart'; +import 'package:inference/providers/text_to_image_inference_provider.dart'; +import 'package:inference/theme_fluent.dart'; +import 'package:inference/widgets/device_selector.dart'; +import 'package:intl/intl.dart'; +import 'package:provider/provider.dart'; + +class TTIPlayground extends StatefulWidget { + final Project project; + + const TTIPlayground({required this.project, super.key}); + + + @override + _TTIPlaygroundState createState() => _TTIPlaygroundState(); +} + +class SubmitMessageIntent extends Intent {} + +class _TTIPlaygroundState extends State { + final _textController = TextEditingController(); + final _scrollController = ScrollController(); + bool attachedToBottom = true; + + void jumpToBottom({ offset = 0 }) { + if (_scrollController.hasClients) { + _scrollController.jumpTo(_scrollController.position.maxScrollExtent + offset); + } + } + + void message(String message) { + if (message.isEmpty) return; + final provider = Provider.of(context, listen: false); + if (!provider.initialized || provider.response != null) return; + _textController.text = ''; + jumpToBottom(offset: 110); //move to bottom including both + provider.message(message).catchError((e) async { + if (mounted) { + await displayInfoBar(context, builder: (context, close) => InfoBar( + title: const Text("An error occurred processing the message"), + content: Text(e.toString()), + severity: InfoBarSeverity.error, + action: IconButton( + icon: const Icon(FluentIcons.clear), + onPressed: close, + ), + )); + } + }); + } + + @override + void initState() { + super.initState(); + _scrollController.addListener(() { + setState(() { + attachedToBottom = _scrollController.position.pixels + 0.001 >= _scrollController.position.maxScrollExtent; + }); + }); + } + + @override + void dispose() { + _textController.dispose(); + _scrollController.dispose(); + super.dispose(); + } + + @override + void didChangeDependencies() { + super.didChangeDependencies(); + if (attachedToBottom) { + jumpToBottom(); + } + } + + @override + Widget build(BuildContext context) { + Locale locale = Localizations.localeOf(context); + final nf = NumberFormat.decimalPatternDigits( + locale: locale.languageCode, decimalDigits: 2); + final theme = FluentTheme.of(context); + + return Row( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Consumer(builder: (context, provider, child) => + Expanded(child: Column( + children: [ + SizedBox( + height: 64, + child: GridContainer( + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 16), + child: Row( + children: [ + const DeviceSelector(), + const Divider(size: 24,direction: Axis.vertical,), + const SizedBox(width: 24,), + ToolbarTextInput( + marginLeft: 0, + labelText: "Width", + suffix: "px", + initialValue: provider.width, + roundPowerOfTwo: true, + onChanged: (value) { provider.width = value; }), + ToolbarTextInput( + marginLeft: 20, + labelText: "Height", + suffix: "px", + initialValue: provider.height, + roundPowerOfTwo: true, + onChanged: (value) { provider.height = value; }), + ToolbarTextInput( + marginLeft: 20, + labelText: "Rounds", + suffix: "", + initialValue: provider.rounds, + onChanged: (value) { provider.rounds = value; }), + ], + ) + ), + ), + ), + Expanded(child: DecoratedBox( + decoration: BoxDecoration( + color: theme.brightness.isDark ? backgroundColor.dark : theme.scaffoldBackgroundColor + ), + child: GridContainer(child: SizedBox( + width: double.infinity, + child: Builder(builder: (context) { + if (!provider.initialized) { + return const Column( + mainAxisAlignment: MainAxisAlignment.center, + children: [ + SizedBox(width: 64,height: 64, child: ProgressRing()), + Padding( + padding: EdgeInsets.only(top: 18), + child: Text("Loading model..."), + ) + ], + ); + } + return Column( + children: [ + Expanded( + child: Builder(builder: (context) { + if (provider.messages.isEmpty) { + return Center( + child: Text("Start chatting with ${provider.project?.name ?? "the model"}!"), + ); + } + return Stack( + alignment: Alignment.bottomCenter, + children: [ + SingleChildScrollView( + controller: _scrollController, + child: Padding(padding: const EdgeInsets.symmetric(horizontal: 64, vertical: 20), child: SelectionArea( + child: SelectionArea( + child: Column( + children: provider.messages.map((message) { switch (message.speaker) { + case Speaker.user: return Padding( + padding: const EdgeInsets.only(left: 42), + child: UserInputMessage(message), + ); + case Speaker.system: return Text('System: ${message.message}'); + case Speaker.assistant: return GeneratedImageMessage( + message, + provider.project! + .thumbnailImage(), + provider.project!.name); + }}).toList(), + ), + ), + ),), + ), + Positioned( + bottom: 10, + child: Builder(builder: (context) => attachedToBottom + ? const SizedBox() + : Padding( + padding: const EdgeInsets.only(top:2), + child: FilledButton(child: const Row( + children: [ + Icon(FluentIcons.chevron_down, size: 12), + SizedBox(width: 4), + Text('Scroll to bottom'), + ], + ), onPressed: () { + jumpToBottom(); + setState(() { + attachedToBottom = true; + }); + }), + ) + ), + ) + ], + ); + }), + ), + Padding( + padding: const EdgeInsets.symmetric(horizontal: 64, vertical: 24), + child: Column( + children: [ + Row( + crossAxisAlignment: CrossAxisAlignment.end, + mainAxisAlignment: MainAxisAlignment.center, + children: [ + Padding( + padding: const EdgeInsets.only(bottom: 20), + child: Tooltip( + message: "Create new thread", + child: Button( + onPressed: provider.interimResponse == null ? () => provider.reset() : null, + child: const Icon(FluentIcons.rocket, size: 18), + ), + ), + ), + Expanded( + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 8), + child: Shortcuts( + shortcuts: { + LogicalKeySet(LogicalKeyboardKey.meta, LogicalKeyboardKey.enter): SubmitMessageIntent(), + LogicalKeySet(LogicalKeyboardKey.control, LogicalKeyboardKey.enter): SubmitMessageIntent(), + }, + child: Actions( + actions: >{ + SubmitMessageIntent: CallbackAction( + onInvoke: (SubmitMessageIntent intent) => message(_textController.text), + ), + }, + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + TextBox( + placeholder: "Type a message...", + keyboardType: TextInputType.multiline, + controller: _textController, + maxLines: null, + expands: true, + onSubmitted: message, + autofocus: true, + ), + Padding( + padding: const EdgeInsets.only(top: 6, left: 10), + child: Text( + 'Press ${Platform.isMacOS ? '⌘' : 'Ctrl'} + Enter to submit, Enter for newline', + style: TextStyle(fontSize: 11, color: subtleTextColor.of(theme)), + ), + ), + ], + ), + ), + ), + ), + ), + Padding( + padding: const EdgeInsets.only(bottom: 20), + child: Builder(builder: (context) { + final isRunning = provider.interimResponse != null; + return Tooltip( + message: "Send message", + child: Button( + onPressed: isRunning ? null : () => message(_textController.text), + child: const Icon(FluentIcons.send, size: 18), + ), + ); + }), + ) + ] + ), + ], + ), + ) + ], + ); + }), + )), + )), + ], + ))), + const ModelProperties(), + ], + ); + } +} diff --git a/lib/pages/text_to_image/text_to_image_page.dart b/lib/pages/text_to_image/text_to_image_page.dart index 3f7773af..a06fb12b 100644 --- a/lib/pages/text_to_image/text_to_image_page.dart +++ b/lib/pages/text_to_image/text_to_image_page.dart @@ -4,8 +4,8 @@ import 'package:fluent_ui/fluent_ui.dart'; import 'package:flutter_svg/svg.dart'; -import 'package:inference/pages/text_to_image/live_inference_pane.dart'; -import 'package:inference/pages/text_to_image/providers/text_to_image_inference_provider.dart'; +import 'package:inference/pages/text_to_image/playground.dart'; +import 'package:inference/providers/text_to_image_inference_provider.dart'; import 'package:inference/pages/text_to_image/performance_metrics_pane.dart'; import 'package:inference/project.dart'; import 'package:inference/providers/preference_provider.dart'; @@ -35,7 +35,6 @@ class _TextToImagePageState extends State { ); final textColor = theme.typography.body?.color ?? Colors.black; - const inferencePane = TTILiveInferencePane(); const metricsPane = TTIPerformanceMetricsPane(); return ChangeNotifierProxyProvider( lazy: false, @@ -93,7 +92,7 @@ class _TextToImagePageState extends State { width: 15, ), title: const Text("Live Inference"), - body: inferencePane, + body: TTIPlayground(project: widget.project), ), PaneItem( icon: SvgPicture.asset("images/stats.svg", diff --git a/lib/pages/text_to_image/widgets/assistant_message.dart b/lib/pages/text_to_image/widgets/assistant_message.dart new file mode 100644 index 00000000..9207b7d3 --- /dev/null +++ b/lib/pages/text_to_image/widgets/assistant_message.dart @@ -0,0 +1,226 @@ +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:inference/providers/text_to_image_inference_provider.dart'; +import 'package:inference/theme_fluent.dart'; +import 'package:intl/intl.dart'; +import 'package:provider/provider.dart'; +import 'package:super_clipboard/super_clipboard.dart'; + +class GeneratedImageMessage extends StatefulWidget { + final ImageMessage message; + final ImageProvider icon; + final String name; + + const GeneratedImageMessage(this.message, this.icon, this.name, {super.key}); + + @override + State createState() => _GeneratedImageMessageState(); +} + +class _GeneratedImageMessageState extends State { + bool _hovering = false; + + @override + Widget build(BuildContext context) { + Locale locale = Localizations.localeOf(context); + final nf = NumberFormat.decimalPatternDigits( + locale: locale.languageCode, decimalDigits: 0); + final theme = FluentTheme.of(context); + final backgroundColor = theme.brightness.isDark + ? theme.scaffoldBackgroundColor + : const Color(0xFFF5F5F5); + + final image = SmartImageWidget(message: widget.message); + + void showContentDialog(BuildContext context) async { + if (image.hasClipboard()) { + showDialog( + context: context, + barrierDismissible: true, + builder: (context) => ContentDialog( + constraints: const BoxConstraints(maxWidth: double.infinity), + content: image, + ), + ); + } + } + + return Consumer( + builder: (context, inferenceProvider, child) => Align( + child: Padding( + padding: const EdgeInsets.only(bottom: 8), + child: Row( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Padding( + padding: const EdgeInsets.only(right: 10, top: 20), + child: ClipRRect( + borderRadius: BorderRadius.circular(16.0), + child: Container( + width: 32, + height: 32, + decoration: BoxDecoration( + image: DecorationImage( + image: inferenceProvider.project!.thumbnailImage(), + fit: BoxFit.cover, + ), + ), + ), + ), + ), + Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Padding( + padding: const EdgeInsets.only(bottom: 4), + child: SelectionContainer.disabled( + child: Row( + children: [ + Text( + inferenceProvider.project!.name, + style: TextStyle( + color: subtleTextColor.of(theme), + ), + ), + if (widget.message.time != null) + Text( + DateFormat(' | yyyy-MM-dd HH:mm:ss') + .format(widget.message.time!), + style: TextStyle( + color: subtleTextColor.of(theme), + ), + ), + ], + ), + ), + ), + MouseRegion( + onEnter: (_) => setState(() { + _hovering = true; + }), + onExit: (_) => setState(() { + _hovering = false; + }), + child: Column( + crossAxisAlignment: CrossAxisAlignment.end, + children: [ + Container( + constraints: BoxConstraints( + maxWidth: + MediaQuery.of(context).size.width - 502), + decoration: BoxDecoration( + color: backgroundColor, + borderRadius: BorderRadius.circular(4), + ), + child: Padding( + padding: const EdgeInsets.all(12.0), + child: Column(children: [ + Column(children: [ + // Changes cursor to hand on hover when there's a dialog that can be opened + MouseRegion( + cursor: image.hasClipboard() + ? SystemMouseCursors.click + : SystemMouseCursors.basic, + child: GestureDetector( + onTap: () => + showContentDialog(context), + // Opens dialog on click + child: ConstrainedBox( + constraints: + const BoxConstraints( + maxHeight: 256.0, + maxWidth: 256.0, + ), + child: image))), + ]) + , + ]), + ), + ), + if (_hovering && image.hasClipboard()) + Padding( + padding: const EdgeInsets.only(top: 4), + child: SelectionContainer.disabled( + child: Row( + children: [ + if (widget.message.metrics != null) + Padding( + padding: const EdgeInsets.only(right: 8), + child: Tooltip( + message: 'Generation time', + child: Text( + '${nf.format(widget.message.metrics!.generate_time)}ms', + style: TextStyle( + fontSize: 12, + color: subtleTextColor.of(theme), + ), + ), + ), + ), + IconButton( + icon: const Icon(FluentIcons.copy), + onPressed: () async { + await displayInfoBar( + context, + builder: (context, close) => InfoBar( + title: + const Text('Copied to clipboard'), + severity: InfoBarSeverity.info, + action: IconButton( + icon: const Icon(FluentIcons.clear), + onPressed: close, + ), + ), + ); + image.copyToClipboard(); + }, + ), + ], + ), + ), + ) + else + const SizedBox(height: 34) + ], + ), + ), + ], + ), + ], + ), + ), + ), + ); + } +} + +class SmartImageWidget extends StatelessWidget { + final ImageMessage message; + + const SmartImageWidget({super.key, required this.message}); + + void copyToClipboard() { + final clipboard = SystemClipboard.instance; + if (clipboard == null || message.imageContent == null) { + return; // Clipboard API is not supported on this platform. + } + final item = DataWriterItem(); + + item.add(Formats.jpeg(message.imageContent!.imageData)); + clipboard.write([item]); + } + + bool hasClipboard() { + return message.allowedCopy; + } + + @override + Widget build(BuildContext context) { + return Image.memory( + message.imageContent!.imageData, + width: message.imageContent!.width.toDouble(), + height: message.imageContent!.height.toDouble(), + fit: message.imageContent!.boxFit, + ); + } +} + diff --git a/lib/pages/text_to_image/widgets/device_selector.dart b/lib/pages/text_to_image/widgets/device_selector.dart deleted file mode 100644 index 1d96babc..00000000 --- a/lib/pages/text_to_image/widgets/device_selector.dart +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2024 Intel Corporation -// -// SPDX-License-Identifier: Apache-2.0 - -// ignore_for_file: unused_import - -import 'package:fluent_ui/fluent_ui.dart'; -import 'package:inference/providers/preference_provider.dart'; -import 'package:provider/provider.dart'; -import 'package:collection/collection.dart'; - -class DeviceSelector extends StatefulWidget { - const DeviceSelector({super.key}); - - @override - State createState() => _DeviceSelectorState(); -} - -class _DeviceSelectorState extends State { - String? selectedDevice; - - @override - void initState() { - super.initState(); - selectedDevice = Provider.of(context, listen: false).device; - } - - @override - Widget build(BuildContext context) { - return Consumer(builder: (context, preferences, child) { - return Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - const Padding( - padding: EdgeInsets.only(bottom: 16), - child: Text("Device", - style: TextStyle( - fontSize: 16, - fontWeight: FontWeight.bold, - ), - ), - ), - Column( - crossAxisAlignment: CrossAxisAlignment.stretch, - children: [ - ComboBox( - value: selectedDevice, - items: PreferenceProvider.availableDevices.map>((e) { - return ComboBoxItem( - value: e.id, - child: Text(e.name), - ); - }).toList(), - onChanged: (v) { - setState(() { - selectedDevice = v; - if (v != null) { - preferences.device = v; - } - }); - }, - ), - ], - ), - ], - ); - } - ); - } -} diff --git a/lib/pages/text_to_image/widgets/horizontal_rule.dart b/lib/pages/text_to_image/widgets/horizontal_rule.dart deleted file mode 100644 index e513822b..00000000 --- a/lib/pages/text_to_image/widgets/horizontal_rule.dart +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2024 Intel Corporation -// -// SPDX-License-Identifier: Apache-2.0 - -import 'package:fluent_ui/fluent_ui.dart'; -import 'package:inference/theme_fluent.dart'; - -class HorizontalRule extends StatelessWidget { - const HorizontalRule({super.key}); - - @override - Widget build(BuildContext context) { - final theme = FluentTheme.of(context); - - return Padding( - padding: const EdgeInsets.symmetric(vertical: 20), - child: Container( - decoration: BoxDecoration( - border: Border( - bottom: BorderSide( - color: borderColor.of(theme), - width: 1, - ) - ) - ), - ), - ); - } -} diff --git a/lib/pages/text_to_image/widgets/model_properties.dart b/lib/pages/text_to_image/widgets/model_properties.dart index 2a04e003..00b2cccc 100644 --- a/lib/pages/text_to_image/widgets/model_properties.dart +++ b/lib/pages/text_to_image/widgets/model_properties.dart @@ -3,7 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 import 'package:fluent_ui/fluent_ui.dart'; -import 'package:inference/pages/text_to_image/providers/text_to_image_inference_provider.dart'; +import 'package:inference/providers/text_to_image_inference_provider.dart'; import 'package:inference/utils.dart'; import 'package:inference/widgets/grid_container.dart'; import 'package:inference/widgets/model_propery.dart'; diff --git a/lib/pages/text_to_image/widgets/tti_chat_area.dart b/lib/pages/text_to_image/widgets/tti_chat_area.dart deleted file mode 100644 index f10190f9..00000000 --- a/lib/pages/text_to_image/widgets/tti_chat_area.dart +++ /dev/null @@ -1,543 +0,0 @@ -// Copyright (c) 2024 Intel Corporation -// -// SPDX-License-Identifier: Apache-2.0 - -import 'package:fluent_ui/fluent_ui.dart'; -import 'package:flutter_svg/svg.dart'; -import 'package:inference/interop/openvino_bindings.dart'; -import 'package:inference/pages/text_to_image/providers/text_to_image_inference_provider.dart'; -import 'package:inference/pages/text_to_image/widgets/tti_metrics_grid.dart'; -import 'package:inference/theme_fluent.dart'; -import 'package:provider/provider.dart'; -import 'package:super_clipboard/super_clipboard.dart'; - -class TTIChatArea extends StatefulWidget { - const TTIChatArea({super.key}); - - @override - State createState() => _PlaygroundState(); -} - -class _PlaygroundState extends State { - final _controller = TextEditingController(); - final _scrollController = ScrollController(); - bool attachedToBottom = true; - - void jumpToBottom({offset = 0}) { - if (_scrollController.hasClients) { - _scrollController - .jumpTo(_scrollController.position.maxScrollExtent + offset); - } - } - - void message(String message) async { - if (message.isEmpty) { - return; - } - final tti = provider(); - if (!tti.initialized) { - return; - } - - if (tti.response != null) { - return; - } - _controller.text = ""; - jumpToBottom(offset: 110); //move to bottom including both - tti.message(message); - } - - TextToImageInferenceProvider provider() => - Provider.of(context, listen: false); - - @override - void initState() { - super.initState(); - _scrollController.addListener(() { - setState(() { - attachedToBottom = _scrollController.position.pixels + 0.001 >= - _scrollController.position.maxScrollExtent; - }); - }); - } - - @override - void dispose() { - super.dispose(); - _controller.dispose(); - _scrollController.dispose(); - } - - @override - Widget build(BuildContext context) { - return Consumer( - builder: (context, inference, child) { - WidgetsBinding.instance.addPostFrameCallback((_) { - if (attachedToBottom) { - jumpToBottom(); - } - }); - - final theme = FluentTheme.of(context); - final textColor = theme.typography.body?.color ?? Colors.black; - - return Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Builder(builder: (context) { - if (!inference.initialized) { - return Expanded( - child: Center( - child: Column( - mainAxisAlignment: MainAxisAlignment.center, - children: [ - Image.asset('images/intel-loading.gif', width: 100), - const Text("Loading model...") - ], - )), - ); - } - return Expanded( - child: Container( - decoration: const BoxDecoration( - shape: BoxShape.rectangle, - borderRadius: BorderRadius.all(Radius.circular(8)), - ), - child: Column( - children: [ - Expanded( - child: Builder(builder: (context) { - if (inference.messages.isEmpty) { - return Center( - child: Text( - "Type a message to ${inference.project?.name ?? "assistant"}")); - } - return Stack( - alignment: Alignment.topCenter, - children: [ - SingleChildScrollView( - controller: _scrollController, - child: Padding( - padding: const EdgeInsets.all(20), - child: Column( - - // mainAxisAlignment: MainAxisAlignment.start, - crossAxisAlignment: - CrossAxisAlignment.stretch, - children: inference.messages.map((message) { - switch (message.speaker) { - case Speaker.user: - return UserInputMessage(message); - case Speaker.assistant: - return GeneratedImageMessage( - message, - inference.project! - .thumbnailImage(), - inference.project!.name); - } - }).toList()), - ), - ), - Positioned( - bottom: 10, - child: Builder(builder: (context) { - if (attachedToBottom) { - return Container(); - } - return Center( - child: Padding( - padding: const EdgeInsets.only(top: 2.0), - child: SizedBox( - width: 200, - height: 40, - // Adjusted height to match Fluent UI's button dimensions - child: FilledButton( - child: const Text("Jump to bottom"), - onPressed: () { - jumpToBottom(); - setState(() { - attachedToBottom = true; - }); - }, - ), - ), - ), - ); - }), - ), - ], - ); - }), - ), - - // SizedBox( - // height: 30, - // child: Builder( - // builder: (context) { - // if (inference.interimResponse == null){ - // return Container(); - // } - // return Center( - // child: OutlinedButton.icon( - // onPressed: () => inference.forceStop(), - // icon: const Icon(Icons.stop), - // label: const Text("Stop responding") - // ), - // ); - // } - // ), - // ), - Padding( - padding: const EdgeInsets.only( - left: 45, right: 45, top: 10, bottom: 25), - child: SizedBox( - height: 40, - child: Row( - crossAxisAlignment: CrossAxisAlignment.center, - children: [ - Padding( - padding: const EdgeInsets.only(right: 8), - child: IconButton( - icon: SvgPicture.asset( - "images/clear.svg", - width: 20, - colorFilter: ColorFilter.mode(textColor, BlendMode.srcIn), - ), - onPressed: () => inference.reset(), - ), - ), - Expanded( - child: TextBox( - maxLines: null, - keyboardType: TextInputType.text, - placeholder: "Ask me anything...", - controller: _controller, - onSubmitted: message, - style: const TextStyle( - fontSize: 14, - ), - suffix: IconButton( - icon: Icon( - FluentIcons.send, - color: - (inference.interimResponse == null - ? textColor - : textColor.withOpacity(0.2)), - ), - onPressed: () => - inference.interimResponse != null ? null : - message(_controller.text), - ), - ), - ) - ], - ), - ), - ), - ], - ), - ), - ); - }), - ], - ); - }); - } -} - -class UserInputMessage extends StatelessWidget { - final Message message; - - const UserInputMessage(this.message, {super.key}); - - @override - Widget build(BuildContext context) { - return Padding( - padding: const EdgeInsets.only(bottom: 20), - child: Column( - crossAxisAlignment: CrossAxisAlignment.end, - children: [ - Container( - constraints: const BoxConstraints(maxWidth: 500), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Padding( - padding: const EdgeInsets.only(right: 30.0), - child: MessageWidget( - message: message.message, - innerPadding: 8, - isSender: true), - ), - ])) - ], - ), - ); - } -} - -class GeneratedImageMessage extends StatelessWidget { - final Message message; - final ImageProvider icon; - final String name; - - const GeneratedImageMessage(this.message, this.icon, this.name, {super.key}); - - @override - Widget build(BuildContext context) { - final theme = FluentTheme.of(context); - - final smartImage = SmartImage(message: message); - - return Padding( - padding: const EdgeInsets.only(bottom: 20), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - mainAxisSize: MainAxisSize.min, // Wrap content - children: [ - Table( - columnWidths: const { - 0: FixedColumnWidth(40), // Fixed width for the first column - 1: IntrinsicColumnWidth(), // Flexible width for the second column - 2: FixedColumnWidth(40), // Fixed width for the first column - }, - children: [ - TableRow(children: [ - Container(), - Padding( - padding: - const EdgeInsets.only(left: 10, top: 8.0, bottom: 8.0), - child: Text( - name, - style: TextStyle( - color: subtleTextColor.of(theme), - fontSize: 12, - ), - textAlign: TextAlign.start, - ), - ), - const SizedBox.shrink(), - ]), - TableRow(children: [ - RoundedPicture( - name: name, - icon: icon, - ), - if (message.imageContent?.imageData != null) - Padding( - padding: const EdgeInsets.only(left: 10, bottom: 20), - child: Align( - alignment: Alignment.centerLeft, // Align left - child: MessageWidget( - image: smartImage, - metrics: message.metrics, - innerPadding: 24.0, - isSender: false, - ))), - ImageOptionsWidget( - image: smartImage, - metrics: message.metrics, - ), - ]) - ], - ), - ], - ), - ); - } -} - -void showMetricsDialog(BuildContext context, TTIMetrics metrics) async { - await showDialog( - context: context, - barrierDismissible: true, - builder: (context) => ContentDialog( - constraints: const BoxConstraints(maxWidth: double.infinity), - content: TTIMetricsGrid(metrics: metrics), - ), - ); -} - -class RoundedPicture extends StatelessWidget { - final String name; - final ImageProvider icon; // Icon widget provided - - const RoundedPicture({super.key, required this.name, required this.icon}); - - @override - Widget build(BuildContext context) { - return ClipOval( - child: Container( - width: 40, - height: 40, - decoration: BoxDecoration( - image: DecorationImage( - image: icon, // Adjust this to fit your `name` field - fit: BoxFit.cover, - ), - ), - ), - ); - } -} - -class MessageWidget extends StatelessWidget { - final String? message; - final SmartImage? image; - final TTIMetrics? metrics; // If not set, no copy-paste options. - final double innerPadding; - final bool isSender; - - const MessageWidget( - {super.key, - this.message, - this.image, - this.metrics, - required this.innerPadding, - required this.isSender}); - - @override - Widget build(BuildContext context) { - final theme = FluentTheme.of(context); - final textColor = theme.typography.body?.color ?? Colors.black; - - void showContentDialog(BuildContext context) async { - if (image?.hasClipboard() ?? false) { - showDialog( - context: context, - barrierDismissible: true, - builder: (context) => ContentDialog( - constraints: const BoxConstraints(maxWidth: double.infinity), - content: image, - ), - ); - } - } - - return Container( - decoration: BoxDecoration( - color: - isSender ? userMessageColor.of(theme) : modelMessageColor.of(theme), - borderRadius: const BorderRadius.only( - topLeft: Radius.circular(4.0), - topRight: Radius.circular(4.0), - bottomLeft: Radius.circular(4.0), - bottomRight: Radius.circular(4.0), - ), - ), - padding: EdgeInsets.all(innerPadding), - child: Column(children: [ - image != null - ? Column(children: [ - // Changes cursor to hand on hover when there's a dialog that can be opened - MouseRegion( - cursor: image?.hasClipboard() ?? false - ? SystemMouseCursors.click - : SystemMouseCursors.basic, - child: GestureDetector( - onTap: () => showContentDialog(context), - // Opens dialog on click - child: ConstrainedBox( - constraints: const BoxConstraints( - maxHeight: 256.0, - maxWidth: 256.0, - ), - child: image!))), - ]) - : const SizedBox.shrink(), - message != null - ? SelectableText( - message!, - style: TextStyle( - color: textColor, - fontSize: 14, - fontWeight: FontWeight.w400, - ), - ) - : const SizedBox.shrink(), - ]), - ); - } -} - -class ImageOptionsWidget extends StatelessWidget { - final SmartImage? image; - final TTIMetrics? metrics; - - const ImageOptionsWidget({super.key, this.image, this.metrics}); - - @override - Widget build(BuildContext context) { - bool hasClipboard = image?.hasClipboard() ?? false; - bool hasMetrics = metrics != null; - final textColor = - FluentTheme.of(context).typography.body?.color ?? Colors.black; - - return Column( - mainAxisSize: MainAxisSize.min, // Wrap content - children: [ - Opacity( - opacity: hasClipboard ? 1.0 : 0.25, - child: IconButton( - icon: SvgPicture.asset( - "images/copy.svg", - colorFilter: ColorFilter.mode(textColor, BlendMode.srcIn), - width: 14, - height: 14, - ), - onPressed: hasClipboard - ? () { - image?.copyToClipboard(); - } - : null, - )), - Opacity( - opacity: hasMetrics ? 1.0 : 0.25, - child: IconButton( - icon: SvgPicture.asset( - "images/stats.svg", - colorFilter: ColorFilter.mode(textColor, BlendMode.srcIn), - width: 14, - height: 14, - ), - // tooltip: "Show stats", - onPressed: () { - metrics != null ? showMetricsDialog(context, metrics!) : null; - }, - ), - ), - ], - ); - } -} - -class SmartImage extends StatelessWidget { - final Message message; - - const SmartImage({super.key, required this.message}); - - void copyToClipboard() { - final clipboard = SystemClipboard.instance; - if (clipboard == null || message.imageContent == null) { - return; // Clipboard API is not supported on this platform. - } - final item = DataWriterItem(); - - item.add(Formats.jpeg(message.imageContent!.imageData)); - clipboard.write([item]); - } - - bool hasClipboard() { - return message.allowedCopy; - } - - @override - Widget build(BuildContext context) { - return Image.memory( - message.imageContent!.imageData, - width: message.imageContent!.width.toDouble(), - height: message.imageContent!.height.toDouble(), - fit: message.imageContent!.boxFit, - ); - } -} diff --git a/lib/pages/text_to_image/widgets/tti_metrics_grid.dart b/lib/pages/text_to_image/widgets/tti_metrics_grid.dart deleted file mode 100644 index e2e7be0b..00000000 --- a/lib/pages/text_to_image/widgets/tti_metrics_grid.dart +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) 2024 Intel Corporation -// -// SPDX-License-Identifier: Apache-2.0 - -import 'package:fluent_ui/fluent_ui.dart'; -import 'package:inference/interop/openvino_bindings.dart'; -import 'package:inference/widgets/metrics_card.dart'; -import 'package:intl/intl.dart'; - - -class TTIMetricsGrid extends StatelessWidget { - final TTIMetrics metrics; - - const TTIMetricsGrid({super.key, required this.metrics}); - - @override - Widget build(BuildContext context) { - Locale locale = Localizations.localeOf(context); - final nf = NumberFormat.decimalPatternDigits( - locale: locale.languageCode, decimalDigits: 0); - - return Row( - mainAxisAlignment: MainAxisAlignment.spaceBetween, - mainAxisSize: MainAxisSize.min, // Wrap content - children: [ - MetricsCard( - header: "Time to load model", - value: nf.format(metrics.load_time), - unit: "ms", - ), - MetricsCard( - header: "Time to generate image", - value: nf.format(metrics.generate_time), - unit: "ms", - ) - ], - ); - } -} diff --git a/lib/pages/text_to_image/widgets/user_message.dart b/lib/pages/text_to_image/widgets/user_message.dart new file mode 100644 index 00000000..47e19423 --- /dev/null +++ b/lib/pages/text_to_image/widgets/user_message.dart @@ -0,0 +1,39 @@ +import 'package:flutter_markdown/flutter_markdown.dart'; +import 'package:inference/providers/text_to_image_inference_provider.dart'; +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:inference/theme_fluent.dart'; +import 'package:markdown/markdown.dart' as md; + +class UserInputMessage extends StatelessWidget { + final ImageMessage message; + const UserInputMessage(this.message, {super.key}); + + @override + Widget build(BuildContext context) { + final theme = FluentTheme.of(context); + return Align( + alignment: Alignment.centerRight, + child: Padding(padding: const EdgeInsets.only(bottom: 20), child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Container( + decoration: BoxDecoration( + color: cosmosBackground.of(theme), + borderRadius: BorderRadius.circular(4), + ), + child: Padding( + padding: const EdgeInsets.all(12.0), + child: MarkdownBody( + data: message.message, + extensionSet: md.ExtensionSet( + md.ExtensionSet.gitHubFlavored.blockSyntaxes, + [md.EmojiSyntax(), ...md.ExtensionSet.gitHubFlavored.inlineSyntaxes], + ), + ), + ), + ) + ], + ),), + ); + } +} \ No newline at end of file diff --git a/lib/pages/text_to_image/providers/text_to_image_inference_provider.dart b/lib/providers/text_to_image_inference_provider.dart similarity index 89% rename from lib/pages/text_to_image/providers/text_to_image_inference_provider.dart rename to lib/providers/text_to_image_inference_provider.dart index 2d11bcd3..f231c803 100644 --- a/lib/pages/text_to_image/providers/text_to_image_inference_provider.dart +++ b/lib/providers/text_to_image_inference_provider.dart @@ -13,7 +13,7 @@ import 'package:inference/interop/generated_bindings.dart'; import 'package:inference/interop/tti_inference.dart'; import 'package:inference/project.dart'; -enum Speaker { assistant, user } +enum Speaker { assistant, system, user } class ImageContent { final Uint8List imageData; @@ -24,14 +24,15 @@ class ImageContent { } -class Message { +class ImageMessage { final Speaker speaker; final String message; final ImageContent? imageContent; final TTIMetrics? metrics; + final DateTime? time; final bool allowedCopy; // Don't allow loading images to be copied - const Message(this.speaker, this.message, this.imageContent, this.metrics, this.allowedCopy); + const ImageMessage(this.speaker, this.message, this.imageContent, this.metrics, this.time, this.allowedCopy); } class TextToImageInferenceProvider extends ChangeNotifier { @@ -120,7 +121,7 @@ class TextToImageInferenceProvider extends ChangeNotifier { } bool get initialized => loaded.isCompleted; - final List _messages = []; + final List _messages = []; double? _speed; @@ -144,16 +145,16 @@ class TextToImageInferenceProvider extends ChangeNotifier { return "Image Generation"; } - Message? get interimResponse { + ImageMessage? get interimResponse { if (_response == null) { return null; } final imageContent = ImageContent(_imageBytes ?? Uint8List(0), _loadWidth, _loadHeight, BoxFit.contain); - return Message(Speaker.assistant, response!, imageContent, null, false); + return ImageMessage(Speaker.assistant, response!, imageContent, null, DateTime.now(), false); } - List get messages { + List get messages { if (interimResponse == null) { return _messages; } @@ -167,7 +168,7 @@ class TextToImageInferenceProvider extends ChangeNotifier { Future message(String message) async { _response = "Generating image..."; - _messages.add(Message(Speaker.user, message, null, null, false)); + _messages.add(ImageMessage(Speaker.user, message, null, null, DateTime.now(), false)); notifyListeners(); _loadWidth = width; @@ -178,7 +179,7 @@ class TextToImageInferenceProvider extends ChangeNotifier { final imageContent = ImageContent(imageData, _loadWidth, _loadHeight, BoxFit.contain); if (_messages.isNotEmpty) { - _messages.add(Message(Speaker.assistant, "Generated image", imageContent, response.metrics, true)); + _messages.add(ImageMessage(Speaker.assistant, "Generated image", imageContent, response.metrics, DateTime.now(), true)); } _response = null; From 86e78d6684fb5040c47ed45bfa726f0414544821 Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Tue, 21 Jan 2025 14:47:17 +0100 Subject: [PATCH 04/15] Sync icons --- lib/pages/text_generation/text_generation.dart | 5 ++++- lib/pages/transcription/transcription.dart | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/lib/pages/text_generation/text_generation.dart b/lib/pages/text_generation/text_generation.dart index 20ce6bcc..ef46961e 100644 --- a/lib/pages/text_generation/text_generation.dart +++ b/lib/pages/text_generation/text_generation.dart @@ -98,7 +98,10 @@ class _TextGenerationPageState extends State { displayMode: PaneDisplayMode.top, items: [ PaneItem( - icon: const Icon(FluentIcons.game), + icon: SvgPicture.asset("images/playground.svg", + colorFilter: ColorFilter.mode(textColor, BlendMode.srcIn), + width: 15, + ), title: const Text("Playground"), body: Playground(project: widget.project), ), diff --git a/lib/pages/transcription/transcription.dart b/lib/pages/transcription/transcription.dart index 3bf284fc..8b1c25bd 100644 --- a/lib/pages/transcription/transcription.dart +++ b/lib/pages/transcription/transcription.dart @@ -85,7 +85,10 @@ class _TranscriptionPageState extends State { displayMode: PaneDisplayMode.top, items: [ PaneItem( - icon: const Icon(FluentIcons.game), + icon: SvgPicture.asset("images/playground.svg", + colorFilter: ColorFilter.mode(textColor, BlendMode.srcIn), + width: 15, + ), title: const Text("Playground"), body: Playground(project: widget.project), ), From 048671e5a204dc698a5782d9d61ba76e432993aa Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Tue, 21 Jan 2025 14:47:53 +0100 Subject: [PATCH 05/15] Add external link buttons to vlm and tti --- .../text_to_image/widgets/model_properties.dart | 12 ++++++++++++ lib/pages/vlm/widgets/model_properties.dart | 14 +++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/lib/pages/text_to_image/widgets/model_properties.dart b/lib/pages/text_to_image/widgets/model_properties.dart index 00b2cccc..93fa7757 100644 --- a/lib/pages/text_to_image/widgets/model_properties.dart +++ b/lib/pages/text_to_image/widgets/model_properties.dart @@ -6,9 +6,11 @@ import 'package:fluent_ui/fluent_ui.dart'; import 'package:inference/providers/text_to_image_inference_provider.dart'; import 'package:inference/utils.dart'; import 'package:inference/widgets/grid_container.dart'; +import 'package:inference/widgets/horizontal_rule.dart'; import 'package:inference/widgets/model_propery.dart'; import 'package:intl/intl.dart'; import 'package:provider/provider.dart'; +import 'package:url_launcher/url_launcher.dart'; class ModelProperties extends StatelessWidget { const ModelProperties({super.key}); @@ -61,6 +63,16 @@ class ModelProperties extends StatelessWidget { ); } ), + if (inference.project!.isPublic) Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + const HorizontalRule(), + const Text('External links', style: TextStyle(fontSize: 16, fontWeight: FontWeight.w600)), + HyperlinkButton( + child: const Text("Model on Hugging Face"), onPressed: () { launchUrl(Uri.parse('https://huggingface.co/${inference.project!.modelId}')); } + ), + ], + ), ], ), ) diff --git a/lib/pages/vlm/widgets/model_properties.dart b/lib/pages/vlm/widgets/model_properties.dart index ffda0246..275fb11b 100644 --- a/lib/pages/vlm/widgets/model_properties.dart +++ b/lib/pages/vlm/widgets/model_properties.dart @@ -3,12 +3,14 @@ // SPDX-License-Identifier: Apache-2.0 import 'package:fluent_ui/fluent_ui.dart'; -import 'package:inference/pages/vlm/providers/vlm_inference_provider.dart'; +import 'package:inference/providers/vlm_inference_provider.dart'; import 'package:inference/utils.dart'; import 'package:inference/widgets/grid_container.dart'; +import 'package:inference/widgets/horizontal_rule.dart'; import 'package:inference/widgets/model_propery.dart'; import 'package:intl/intl.dart'; import 'package:provider/provider.dart'; +import 'package:url_launcher/url_launcher.dart'; class ModelProperties extends StatelessWidget { const ModelProperties({super.key}); @@ -61,6 +63,16 @@ class ModelProperties extends StatelessWidget { ); } ), + if (inference.project!.isPublic) Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + const HorizontalRule(), + const Text('External links', style: TextStyle(fontSize: 16, fontWeight: FontWeight.w600)), + HyperlinkButton( + child: const Text("Model on Hugging Face"), onPressed: () { launchUrl(Uri.parse('https://huggingface.co/${inference.project!.modelId}')); } + ), + ], + ), ], ), ) From 15c3792ca6255e125944cdf9e4aba81a87dc074c Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Tue, 21 Jan 2025 16:29:03 +0100 Subject: [PATCH 06/15] Sync ui between chat, tti and vlm --- lib/interop/generated_bindings.dart | 21 + lib/pages/text_to_image/playground.dart | 2 +- .../widgets/toolbar_text_input.dart | 119 ----- lib/pages/vlm/live_inference_pane.dart | 318 +++++++++++--- lib/pages/vlm/performance_metrics_pane.dart | 2 +- lib/pages/vlm/vlm_page.dart | 5 +- lib/pages/vlm/widgets/assistant_message.dart | 182 ++++++++ lib/pages/vlm/widgets/device_selector.dart | 70 --- lib/pages/vlm/widgets/horizontal_rule.dart | 29 -- lib/pages/vlm/widgets/image_grid.dart | 151 ++++--- lib/pages/vlm/widgets/user_message.dart | 43 ++ lib/pages/vlm/widgets/vertical_rule.dart | 29 -- lib/pages/vlm/widgets/vlm_chat_area.dart | 412 ------------------ .../providers/vlm_inference_provider.dart | 54 +-- lib/widgets/controls/drop_area.dart | 7 +- .../vlm => }/widgets/toolbar_text_input.dart | 0 .../widgets/vertical_rule.dart | 0 openvino_bindings/src/utils/utils.cc | 15 + openvino_bindings/src/utils/utils.h | 2 + openvino_bindings/src/utils/vlm_metrics.h | 7 + openvino_bindings/src/vlm/load_image.cpp | 9 +- openvino_bindings/src/vlm/load_image.hpp | 8 +- openvino_bindings/src/vlm/vlm_inference.cc | 7 +- openvino_bindings/src/vlm/vlm_inference.h | 2 +- 24 files changed, 672 insertions(+), 822 deletions(-) delete mode 100644 lib/pages/text_to_image/widgets/toolbar_text_input.dart create mode 100644 lib/pages/vlm/widgets/assistant_message.dart delete mode 100644 lib/pages/vlm/widgets/device_selector.dart delete mode 100644 lib/pages/vlm/widgets/horizontal_rule.dart create mode 100644 lib/pages/vlm/widgets/user_message.dart delete mode 100644 lib/pages/vlm/widgets/vertical_rule.dart delete mode 100644 lib/pages/vlm/widgets/vlm_chat_area.dart rename lib/{pages/vlm => }/providers/vlm_inference_provider.dart (80%) rename lib/{pages/vlm => }/widgets/toolbar_text_input.dart (100%) rename lib/{pages/text_to_image => }/widgets/vertical_rule.dart (100%) diff --git a/lib/interop/generated_bindings.dart b/lib/interop/generated_bindings.dart index 07a52835..53476871 100644 --- a/lib/interop/generated_bindings.dart +++ b/lib/interop/generated_bindings.dart @@ -989,6 +989,27 @@ final class VLMMetrics extends ffi.Struct { @ffi.Float() external double generate_time; + + @ffi.Float() + external double tokenization_time; + + @ffi.Float() + external double detokenization_time; + + @ffi.Float() + external double ttft; + + @ffi.Float() + external double tpot; + + @ffi.Float() + external double throughput; + + @ffi.Int() + external int number_of_generated_tokens; + + @ffi.Int() + external int number_of_input_tokens; } final class VLMStringWithMetrics extends ffi.Struct { diff --git a/lib/pages/text_to_image/playground.dart b/lib/pages/text_to_image/playground.dart index 87ed78a4..a35700e7 100644 --- a/lib/pages/text_to_image/playground.dart +++ b/lib/pages/text_to_image/playground.dart @@ -7,7 +7,6 @@ import 'dart:io'; import 'package:fluent_ui/fluent_ui.dart'; import 'package:flutter/services.dart'; import 'package:inference/pages/text_to_image/widgets/assistant_message.dart'; -import 'package:inference/pages/text_to_image/widgets/toolbar_text_input.dart'; import 'package:inference/pages/text_to_image/widgets/user_message.dart'; import 'package:inference/pages/text_to_image/widgets/model_properties.dart'; import 'package:inference/widgets/grid_container.dart'; @@ -15,6 +14,7 @@ import 'package:inference/project.dart'; import 'package:inference/providers/text_to_image_inference_provider.dart'; import 'package:inference/theme_fluent.dart'; import 'package:inference/widgets/device_selector.dart'; +import 'package:inference/widgets/toolbar_text_input.dart'; import 'package:intl/intl.dart'; import 'package:provider/provider.dart'; diff --git a/lib/pages/text_to_image/widgets/toolbar_text_input.dart b/lib/pages/text_to_image/widgets/toolbar_text_input.dart deleted file mode 100644 index 3946abb3..00000000 --- a/lib/pages/text_to_image/widgets/toolbar_text_input.dart +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) 2024 Intel Corporation -// -// SPDX-License-Identifier: Apache-2.0 - -import 'dart:math'; - -import 'package:fluent_ui/fluent_ui.dart'; - -class ToolbarTextInput extends StatefulWidget { - final String labelText; - final String suffix; - final int marginLeft; - final int initialValue; - final bool? roundPowerOfTwo; - final void Function(int)? onChanged; - - const ToolbarTextInput({ - super.key, - required this.labelText, - required this.suffix, - required this.marginLeft, - required this.initialValue, - this.roundPowerOfTwo, - this.onChanged, - }); - - @override - State createState() => _ToolbarTextInputState(); -} - -class _ToolbarTextInputState extends State { - final TextEditingController _controller = TextEditingController(); - final FocusNode _focusNode = FocusNode(); - - @override - void initState() { - super.initState(); - _controller.text = widget.initialValue.toString(); // Set the initial text - if (widget.roundPowerOfTwo ?? false) { - _focusNode.addListener(_onFocusChange); // Listen for focus changes - } - } - - void _onFocusChange() { - if (!_focusNode.hasFocus) { - // When the TextBox loses focus, round and update - final inputValue = int.tryParse(_controller.text.replaceAll(RegExp(r'[^0-9]'), '')) ?? 0; - final rounded = _nearestPowerOfTwo(inputValue); - - _controller.text = rounded.toString(); - widget.onChanged!(rounded); - - } - } - - /// Calculate the nearest power of 2 for a given number - int _nearestPowerOfTwo(int value) { - if (value <= 0) return 1; // Smallest power of 2 is 1 - int lowerPower = pow(2, (log(value) / log(2)).floor()).toInt(); - int higherPower = pow(2, (log(value) / log(2)).ceil()).toInt(); - return (value - lowerPower < higherPower - value) ? lowerPower : higherPower; - } - - - void _onTextChanged(String value) { - // Keep only digits in the input - final newValue = value.replaceAll(RegExp(r'[^0-9]'), ''); - if (value != newValue) { - // Update the controller text and cursor position - _controller.text = newValue; - _controller.selection = TextSelection.collapsed(offset: newValue.length); - } - - if (widget.onChanged != null) { - if (newValue.isNotEmpty) { - // Parse the integer and call the callback - widget.onChanged!(int.parse(newValue)); - } else { - // Optionally handle empty input - widget.onChanged!(0); // You can choose to pass null or handle differently - } - } - } - - - @override - Widget build(BuildContext context) { - return Row( - crossAxisAlignment: CrossAxisAlignment.center, - children: [ - Padding( - padding: EdgeInsets.only(left: 10 + widget.marginLeft.toDouble(), right: 10), - child: Text( - widget.labelText, - style: const TextStyle( - fontSize: 14, - fontWeight: FontWeight.normal, - ), - ), - ), - SizedBox( - width: 85, - height: 30, - child: TextBox( - controller: _controller, - focusNode: _focusNode, - maxLines: 1, - keyboardType: TextInputType.number, // Ensure numeric keyboard - suffix: Padding( - padding: const EdgeInsets.only(right: 8.0), - child: Text(widget.suffix), - ), - onChanged: _onTextChanged, // Custom handler for integer validation - ), - ), - ], - ); - } -} diff --git a/lib/pages/vlm/live_inference_pane.dart b/lib/pages/vlm/live_inference_pane.dart index bd565d1b..0ba5df0a 100644 --- a/lib/pages/vlm/live_inference_pane.dart +++ b/lib/pages/vlm/live_inference_pane.dart @@ -2,81 +2,299 @@ // // SPDX-License-Identifier: Apache-2.0 +import 'dart:io'; + import 'package:fluent_ui/fluent_ui.dart'; -import 'package:inference/widgets/grid_container.dart'; -import 'package:inference/pages/vlm/providers/vlm_inference_provider.dart'; +import 'package:flutter/services.dart'; +import 'package:inference/pages/vlm/widgets/assistant_message.dart'; +import 'package:inference/pages/vlm/widgets/image_grid.dart'; import 'package:inference/pages/vlm/widgets/model_properties.dart'; -import 'package:inference/pages/vlm/widgets/toolbar_text_input.dart'; -import 'package:inference/pages/vlm/widgets/vlm_chat_area.dart'; -import 'package:inference/pages/vlm/widgets/vertical_rule.dart'; +import 'package:inference/widgets/toolbar_text_input.dart'; +import 'package:inference/pages/vlm/widgets/user_message.dart'; +import 'package:inference/project.dart'; +import 'package:inference/providers/vlm_inference_provider.dart'; import 'package:inference/theme_fluent.dart'; -import 'package:provider/provider.dart'; import 'package:inference/widgets/device_selector.dart'; +import 'package:inference/widgets/grid_container.dart'; +import 'package:inference/widgets/horizontal_rule.dart'; +import 'package:provider/provider.dart'; + +class VLMPlayground extends StatefulWidget { + final Project project; + + const VLMPlayground({required this.project, super.key}); -class VLMLiveInferencePane extends StatefulWidget { - const VLMLiveInferencePane({super.key}); @override - State createState() => _PlaygroundState(); + _VLMPlaygroundState createState() => _VLMPlaygroundState(); } -class _PlaygroundState extends State { - VLMInferenceProvider provider() => - Provider.of(context, listen: false); +class SubmitMessageIntent extends Intent {} + +class _VLMPlaygroundState extends State { + final _textController = TextEditingController(); + final _scrollController = ScrollController(); + bool attachedToBottom = true; + + void handleFileListChange(List paths) { + final provider = Provider.of(context, listen: false); + if (!provider.initialized) { + return; + } + provider.setImagePaths(paths); + } + void jumpToBottom({ offset = 0 }) { + if (_scrollController.hasClients) { + _scrollController.jumpTo(_scrollController.position.maxScrollExtent + offset); + } + } + + void message(String message) { + if (message.isEmpty) return; + final provider = Provider.of(context, listen: false); + if (!provider.initialized || provider.response != null) return; + _textController.text = ''; + jumpToBottom(offset: 110); //move to bottom including both + provider.message(message).catchError((e) async { + // provider.resetInterimResponse(); // Allow user to type again + if (mounted) { + await displayInfoBar(context, builder: (context, close) => InfoBar( + title: const Text("An error occurred processing the message"), + content: Text(e.toString()), + severity: InfoBarSeverity.error, + action: IconButton( + icon: const Icon(FluentIcons.clear), + onPressed: close, + ), + )); + } + }); + } + + @override + void initState() { + super.initState(); + _scrollController.addListener(() { + setState(() { + attachedToBottom = _scrollController.position.pixels + 0.001 >= _scrollController.position.maxScrollExtent; + }); + }); + } + + @override + void dispose() { + _textController.dispose(); + _scrollController.dispose(); + super.dispose(); + } + + @override + void didChangeDependencies() { + super.didChangeDependencies(); + if (attachedToBottom) { + jumpToBottom(); + } + } @override Widget build(BuildContext context) { final theme = FluentTheme.of(context); - const vlmChatArea = VLMChatArea(); - - return Consumer(builder: (context, inference, child) { - return Row( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Expanded( - child: Column( + return Row( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Consumer(builder: (context, provider, child) => + Expanded(child: Column( children: [ SizedBox( height: 64, child: GridContainer( child: Padding( - padding: const EdgeInsets.symmetric(horizontal: 16), - child: Row( - children: [ - const DeviceSelector(), - const Padding( - padding: EdgeInsets.symmetric(vertical: 16), - child: VerticalRule()), - ToolbarTextInput( - marginLeft: 0, - labelText: "Max new tokens", - suffix: "", - initialValue: provider().maxTokens, - roundPowerOfTwo: true, - onChanged: (value) { - provider().maxTokens = value; - }), - ], - ), + padding: const EdgeInsets.symmetric(horizontal: 16), + child: Row( + children: [ + const DeviceSelector(), + const Divider(size: 24,direction: Axis.vertical,), + const SizedBox(width: 24,), + ToolbarTextInput( + marginLeft: 0, + labelText: "Max new tokens", + suffix: "", + initialValue: provider.maxTokens, + roundPowerOfTwo: true, + onChanged: (value) { + provider.maxTokens = value; + }), + ], + ) ), ), ), - Expanded( - child: GridContainer( - color: backgroundColor.of(theme), + Expanded(child: DecoratedBox( + decoration: BoxDecoration( + color: theme.brightness.isDark ? backgroundColor.dark : theme.scaffoldBackgroundColor + ), + child: GridContainer(child: SizedBox( + width: double.infinity, child: Builder(builder: (context) { - return vlmChatArea; + if (!provider.initialized) { + return const Column( + mainAxisAlignment: MainAxisAlignment.center, + children: [ + SizedBox(width: 64,height: 64, child: ProgressRing()), + Padding( + padding: EdgeInsets.only(top: 18), + child: Text("Loading model..."), + ) + ], + ); + } + return Column( + children: [ + SizedBox( + height: 220, + child: ImageGrid( + initialGalleryData: provider.getImagePaths(), + onFileListChange: handleFileListChange, + )), + const Divider(size: double.infinity,direction: Axis.horizontal,), + Expanded( + child: Builder(builder: (context) { + if (provider.messages.isEmpty) { + return Center( + child: Text("Start chatting with ${provider.project?.name ?? "the model"}!"), + ); + } + return Stack( + alignment: Alignment.bottomCenter, + children: [ + SingleChildScrollView( + controller: _scrollController, + child: Padding(padding: const EdgeInsets.symmetric(horizontal: 64, vertical: 20), child: SelectionArea( + child: SelectionArea( + child: Column( + children: provider.messages.map((message) { switch (message.speaker) { + case Speaker.user: return Padding( + padding: const EdgeInsets.only(left: 42), + child: UserMessage(message), + ); + case Speaker.system: return Text('System: ${message.message}'); + case Speaker.assistant: return AssistantMessage(message); + }}).toList(), + ), + ), + ),), + ), + Positioned( + bottom: 10, + child: Builder(builder: (context) => attachedToBottom + ? const SizedBox() + : Padding( + padding: const EdgeInsets.only(top:2), + child: FilledButton(child: const Row( + children: [ + Icon(FluentIcons.chevron_down, size: 12), + SizedBox(width: 4), + Text('Scroll to bottom'), + ], + ), onPressed: () { + jumpToBottom(); + setState(() { + attachedToBottom = true; + }); + }), + ) + ), + ) + ], + ); + }), + ), + Padding( + padding: const EdgeInsets.symmetric(horizontal: 64, vertical: 24), + child: Column( + children: [ + Row( + crossAxisAlignment: CrossAxisAlignment.end, + mainAxisAlignment: MainAxisAlignment.center, + children: [ + Padding( + padding: const EdgeInsets.only(bottom: 20), + child: Tooltip( + message: "Create new thread", + child: Button( + onPressed: provider.interimResponse == null ? () => provider.reset() : null, + child: const Icon(FluentIcons.rocket, size: 18), + ), + ), + ), + Expanded( + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 8), + child: Shortcuts( + shortcuts: { + LogicalKeySet(LogicalKeyboardKey.meta, LogicalKeyboardKey.enter): SubmitMessageIntent(), + LogicalKeySet(LogicalKeyboardKey.control, LogicalKeyboardKey.enter): SubmitMessageIntent(), + }, + child: Actions( + actions: >{ + SubmitMessageIntent: CallbackAction( + onInvoke: (SubmitMessageIntent intent) => message(_textController.text), + ), + }, + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + TextBox( + placeholder: "Type a message...", + keyboardType: TextInputType.multiline, + controller: _textController, + maxLines: null, + expands: true, + onSubmitted: message, + autofocus: true, + ), + Padding( + padding: const EdgeInsets.only(top: 6, left: 10), + child: Text( + 'Press ${Platform.isMacOS ? '⌘' : 'Ctrl'} + Enter to submit, Enter for newline', + style: TextStyle(fontSize: 11, color: subtleTextColor.of(theme)), + ), + ), + ], + ), + ), + ), + ), + ), + Padding( + padding: const EdgeInsets.only(bottom: 20), + child: Builder(builder: (context) { + final isRunning = provider.interimResponse != null; + return Tooltip( + message: "Send message", + child: Button( + onPressed: isRunning ? null : () => message(_textController.text), + child: const Icon(FluentIcons.send, size: 18), + ), + ); + }), + ) + ] + ), + ], + ), + ) + ], + ); }), - ), - ) + )), + )), ], - ), - ), - const ModelProperties(), - ], - ); - }); + ))), + const ModelProperties(), + ], + ); } } diff --git a/lib/pages/vlm/performance_metrics_pane.dart b/lib/pages/vlm/performance_metrics_pane.dart index c0f4b2f0..e6a8323f 100644 --- a/lib/pages/vlm/performance_metrics_pane.dart +++ b/lib/pages/vlm/performance_metrics_pane.dart @@ -3,7 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 import 'package:fluent_ui/fluent_ui.dart'; -import 'package:inference/pages/vlm/providers/vlm_inference_provider.dart'; +import 'package:inference/providers/vlm_inference_provider.dart'; import 'package:inference/pages/vlm/widgets/vlm_metrics_grid.dart'; import 'package:provider/provider.dart'; diff --git a/lib/pages/vlm/vlm_page.dart b/lib/pages/vlm/vlm_page.dart index 3076aa2c..65cc4ed6 100644 --- a/lib/pages/vlm/vlm_page.dart +++ b/lib/pages/vlm/vlm_page.dart @@ -5,7 +5,7 @@ import 'package:fluent_ui/fluent_ui.dart'; import 'package:flutter_svg/svg.dart'; import 'package:inference/pages/vlm/live_inference_pane.dart'; -import 'package:inference/pages/vlm/providers/vlm_inference_provider.dart'; +import 'package:inference/providers/vlm_inference_provider.dart'; import 'package:inference/pages/vlm/performance_metrics_pane.dart'; import 'package:inference/project.dart'; import 'package:inference/providers/preference_provider.dart'; @@ -35,7 +35,6 @@ class _VLMPageState extends State { ); final textColor = theme.typography.body?.color ?? Colors.black; - const inferencePane = VLMLiveInferencePane(); const metricsPane = VLMPerformanceMetricsPane(); return ChangeNotifierProxyProvider( lazy: false, @@ -93,7 +92,7 @@ class _VLMPageState extends State { width: 15, ), title: const Text("Live Inference"), - body: inferencePane, + body: VLMPlayground(project: widget.project), ), PaneItem( icon: SvgPicture.asset("images/stats.svg", diff --git a/lib/pages/vlm/widgets/assistant_message.dart b/lib/pages/vlm/widgets/assistant_message.dart new file mode 100644 index 00000000..9b9264b2 --- /dev/null +++ b/lib/pages/vlm/widgets/assistant_message.dart @@ -0,0 +1,182 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:flutter/services.dart'; +import 'package:flutter_markdown/flutter_markdown.dart'; +import 'package:markdown/markdown.dart' as md; +import 'package:inference/providers/vlm_inference_provider.dart'; +import 'package:inference/theme_fluent.dart'; +import 'package:intl/intl.dart'; +import 'package:provider/provider.dart'; + +class AssistantMessage extends StatefulWidget { + final Message message; + const AssistantMessage(this.message, {super.key}); + + @override + _AssistantMessageState createState() => _AssistantMessageState(); +} + +class _AssistantMessageState extends State { + bool _hovering = false; + + @override + Widget build(BuildContext context) { + Locale locale = Localizations.localeOf(context); + final nf = NumberFormat.decimalPatternDigits( + locale: locale.languageCode, decimalDigits: 0); + final theme = FluentTheme.of(context); + final backgroundColor = theme.brightness.isDark + ? theme.scaffoldBackgroundColor + : const Color(0xFFF5F5F5); + + return Consumer(builder: (context, inferenceProvider, child) => + Align( + child: Padding( + padding: const EdgeInsets.only(bottom: 8), + child: Row( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Padding( + padding: const EdgeInsets.only(right: 10, top: 20), + child: ClipRRect( + borderRadius: BorderRadius.circular(16.0), + child: Container( + width: 32, + height: 32, + decoration: BoxDecoration( + image: DecorationImage( + image: inferenceProvider.project!.thumbnailImage(), + fit: BoxFit.cover, + ), + ), + ), + ), + ), + Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Padding( + padding: const EdgeInsets.only(bottom: 4), + child: SelectionContainer.disabled( + child: Row( + children: [ + Text( + inferenceProvider.project!.name, + style: TextStyle( + color: subtleTextColor.of(theme), + ), + ), + if (widget.message.time != null) Text( + DateFormat(' | yyyy-MM-dd HH:mm:ss').format(widget.message.time!), + style: TextStyle( + color: subtleTextColor.of(theme), + ), + ), + ], + ), + ), + ), + MouseRegion( + onEnter: (_) => setState(() { _hovering = true; }), + onExit: (_) => setState(() { _hovering = false; }), + child: Column( + crossAxisAlignment: CrossAxisAlignment.end, + children: [ + Container( + constraints: BoxConstraints(maxWidth: MediaQuery.of(context).size.width - 502), + decoration: BoxDecoration( + color: backgroundColor, + borderRadius: BorderRadius.circular(4), + ), + child: Padding( + padding: const EdgeInsets.all(12.0), + child: MarkdownBody( + data: widget.message.message, + extensionSet: md.ExtensionSet( + md.ExtensionSet.gitHubFlavored.blockSyntaxes, + [md.EmojiSyntax(), ...md.ExtensionSet.gitHubFlavored.inlineSyntaxes], + ), + ), + ), + ), + if (_hovering) + Padding( + padding: const EdgeInsets.only(top: 4), + child: SelectionContainer.disabled( + child: Row( + children: [ + // Commented out, as VLM pipeline in OV doesn't return the correct values in return value. + // if (widget.message.metrics != null) Padding( + // padding: const EdgeInsets.only(right: 8), + // child: Tooltip( + // message: 'Time to first token', + // child: Text( + // 'TTF: ${nf.format(widget.message.metrics!.ttft)}ms', + // style: TextStyle( + // fontSize: 12, + // color: subtleTextColor.of(theme), + // ), + // ), + // ), + // ), + // if (widget.message.metrics != null) Padding( + // padding: const EdgeInsets.only(right: 8), + // child: Tooltip( + // message: 'Time per output token', + // child: Text( + // 'TPOT: ${nf.format(widget.message.metrics!.tpot)}ms', + // style: TextStyle( + // fontSize: 12, + // color: subtleTextColor.of(theme), + // ), + // ), + // ), + // ), + if (widget.message.metrics != null) Padding( + padding: const EdgeInsets.only(right: 8), + child: Tooltip( + message: 'Generate total duration', + child: Text( + 'Generate: ${nf.format(widget.message.metrics!.generate_time/1000)}s', + style: TextStyle( + fontSize: 12, + color: subtleTextColor.of(theme), + ), + ), + ), + ), + IconButton( + icon: const Icon(FluentIcons.copy), + onPressed: () async{ + await displayInfoBar(context, builder: (context, close) => + InfoBar( + title: const Text('Copied to clipboard'), + severity: InfoBarSeverity.info, + action: IconButton( + icon: const Icon(FluentIcons.clear), + onPressed: close, + ), + ), + ); + Clipboard.setData(ClipboardData(text: widget.message.message)); + }, + ), + ], + ), + ), + ) else const SizedBox(height: 34) + ], + ), + ), + ], + ), + ], + ), + ), + ), + ); + } +} \ No newline at end of file diff --git a/lib/pages/vlm/widgets/device_selector.dart b/lib/pages/vlm/widgets/device_selector.dart deleted file mode 100644 index 1d96babc..00000000 --- a/lib/pages/vlm/widgets/device_selector.dart +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2024 Intel Corporation -// -// SPDX-License-Identifier: Apache-2.0 - -// ignore_for_file: unused_import - -import 'package:fluent_ui/fluent_ui.dart'; -import 'package:inference/providers/preference_provider.dart'; -import 'package:provider/provider.dart'; -import 'package:collection/collection.dart'; - -class DeviceSelector extends StatefulWidget { - const DeviceSelector({super.key}); - - @override - State createState() => _DeviceSelectorState(); -} - -class _DeviceSelectorState extends State { - String? selectedDevice; - - @override - void initState() { - super.initState(); - selectedDevice = Provider.of(context, listen: false).device; - } - - @override - Widget build(BuildContext context) { - return Consumer(builder: (context, preferences, child) { - return Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - const Padding( - padding: EdgeInsets.only(bottom: 16), - child: Text("Device", - style: TextStyle( - fontSize: 16, - fontWeight: FontWeight.bold, - ), - ), - ), - Column( - crossAxisAlignment: CrossAxisAlignment.stretch, - children: [ - ComboBox( - value: selectedDevice, - items: PreferenceProvider.availableDevices.map>((e) { - return ComboBoxItem( - value: e.id, - child: Text(e.name), - ); - }).toList(), - onChanged: (v) { - setState(() { - selectedDevice = v; - if (v != null) { - preferences.device = v; - } - }); - }, - ), - ], - ), - ], - ); - } - ); - } -} diff --git a/lib/pages/vlm/widgets/horizontal_rule.dart b/lib/pages/vlm/widgets/horizontal_rule.dart deleted file mode 100644 index e513822b..00000000 --- a/lib/pages/vlm/widgets/horizontal_rule.dart +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2024 Intel Corporation -// -// SPDX-License-Identifier: Apache-2.0 - -import 'package:fluent_ui/fluent_ui.dart'; -import 'package:inference/theme_fluent.dart'; - -class HorizontalRule extends StatelessWidget { - const HorizontalRule({super.key}); - - @override - Widget build(BuildContext context) { - final theme = FluentTheme.of(context); - - return Padding( - padding: const EdgeInsets.symmetric(vertical: 20), - child: Container( - decoration: BoxDecoration( - border: Border( - bottom: BorderSide( - color: borderColor.of(theme), - width: 1, - ) - ) - ), - ), - ); - } -} diff --git a/lib/pages/vlm/widgets/image_grid.dart b/lib/pages/vlm/widgets/image_grid.dart index 03965695..da97641d 100644 --- a/lib/pages/vlm/widgets/image_grid.dart +++ b/lib/pages/vlm/widgets/image_grid.dart @@ -1,5 +1,8 @@ -import 'dart:io'; +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 +import 'dart:io'; import 'package:fluent_ui/fluent_ui.dart'; import 'package:inference/widgets/controls/drop_area.dart'; @@ -20,6 +23,8 @@ class ImageGrid extends StatefulWidget { class _ImageGridState extends State { late List galleryData; Map hoverStates = {}; + final ScrollController _scrollController = + ScrollController(); // Define controller @override void initState() { @@ -55,74 +60,90 @@ class _ImageGridState extends State { onUpload: onDrop, type: "image", extensions: const ["jpg", "jpeg", "bmp", "png", "tif", "tiff"], - child: GridView.count( - primary: false, - padding: const EdgeInsets.all(20), - crossAxisSpacing: 10, - mainAxisSpacing: 10, - crossAxisCount: 7, - children: List.generate(galleryData.length, (index) { - String path = galleryData[index]; - bool isLocalFile = File(path).existsSync(); // Check if the file exists locally + child: Align( + alignment: Alignment.centerLeft, + child: Scrollbar( + thumbVisibility: true, + controller: _scrollController, + child: SingleChildScrollView( + scrollDirection: Axis.horizontal, + clipBehavior: Clip.antiAlias, + controller: _scrollController, + child: Padding( + padding: const EdgeInsets.only(bottom: 20), + child: Row( + mainAxisAlignment: MainAxisAlignment.start, + spacing: 20, + children: List.generate(galleryData.length, (index) { + String path = galleryData[index]; + bool isLocalFile = File(path) + .existsSync(); // Check if the file exists locally - bool isHovered = hoverStates[index] ?? false; // Get hover state - return MouseRegion( - onEnter: (_) { - setState(() { - hoverStates[index] = true; - }); - }, - onHover: (_) { - setState(() { - hoverStates[index] = true; - }); - }, - onExit: (_) { - setState(() { - hoverStates[index] = false; - }); - }, - child: Stack( - children: [ - Container( - width: width * 0.3, - height: height * 0.3, - decoration: BoxDecoration( - borderRadius: BorderRadius.circular(10), - color: Colors.black, - image: DecorationImage( - image: isLocalFile - ? FileImage(File(path)) // Load local file - : NetworkImage(path) as ImageProvider, // Load from network - fit: BoxFit.cover, - ), - ), - ), - if (isHovered) - Positioned( - top: 5, - right: 5, - child: GestureDetector( - onTap: () => removeImage(index), - child: Container( - padding: const EdgeInsets.all(8), - decoration: BoxDecoration( - color: Colors.black.withAlpha(200), - shape: BoxShape.circle, + bool isHovered = + hoverStates[index] ?? false; // Get hover state + return MouseRegion( + onEnter: (_) { + setState(() { + hoverStates[index] = true; + }); + }, + onHover: (_) { + setState(() { + hoverStates[index] = true; + }); + }, + onExit: (_) { + setState(() { + hoverStates[index] = false; + }); + }, + child: Stack( + children: [ + AspectRatio( + aspectRatio: 1, + child: Container( + decoration: BoxDecoration( + borderRadius: BorderRadius.circular(10), + color: Colors.black, + image: DecorationImage( + image: isLocalFile + ? FileImage(File(path)) // Load local file + : NetworkImage(path) as ImageProvider, + // Load from network + fit: BoxFit.cover, + ), + ), + ), ), - child: const Icon( - FluentIcons.cancel, - size: 12, - color: Colors.white, - ), - ), + if (isHovered) + Positioned( + top: 5, + right: 5, + child: GestureDetector( + onTap: () => removeImage(index), + child: Container( + padding: const EdgeInsets.all(8), + decoration: BoxDecoration( + color: Colors.black.withAlpha(200), + shape: BoxShape.circle, + ), + child: const Icon( + FluentIcons.cancel, + size: 12, + color: Colors.white, + ), + ), + ), + ), + ], ), - ), - ], + ); + }), + ), ), - ); - }), + ), + ), ), ); } -} \ No newline at end of file +} diff --git a/lib/pages/vlm/widgets/user_message.dart b/lib/pages/vlm/widgets/user_message.dart new file mode 100644 index 00000000..09a8a8f8 --- /dev/null +++ b/lib/pages/vlm/widgets/user_message.dart @@ -0,0 +1,43 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:flutter_markdown/flutter_markdown.dart'; +import 'package:markdown/markdown.dart' as md; +import 'package:inference/providers/vlm_inference_provider.dart'; +import 'package:inference/theme_fluent.dart'; + +class UserMessage extends StatelessWidget { + final Message message; + const UserMessage(this.message, {super.key}); + + @override + Widget build(BuildContext context) { + final theme = FluentTheme.of(context); + return Align( + alignment: Alignment.centerRight, + child: Padding(padding: const EdgeInsets.only(bottom: 20), child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Container( + decoration: BoxDecoration( + color: cosmosBackground.of(theme), + borderRadius: BorderRadius.circular(4), + ), + child: Padding( + padding: const EdgeInsets.all(12.0), + child: MarkdownBody( + data: message.message, + extensionSet: md.ExtensionSet( + md.ExtensionSet.gitHubFlavored.blockSyntaxes, + [md.EmojiSyntax(), ...md.ExtensionSet.gitHubFlavored.inlineSyntaxes], + ), + ), + ), + ) + ], + ),), + ); + } +} \ No newline at end of file diff --git a/lib/pages/vlm/widgets/vertical_rule.dart b/lib/pages/vlm/widgets/vertical_rule.dart deleted file mode 100644 index 2072adbb..00000000 --- a/lib/pages/vlm/widgets/vertical_rule.dart +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2024 Intel Corporation -// -// SPDX-License-Identifier: Apache-2.0 - -import 'package:fluent_ui/fluent_ui.dart'; -import 'package:inference/theme_fluent.dart'; - -class VerticalRule extends StatelessWidget { - const VerticalRule({super.key}); - - @override - Widget build(BuildContext context) { - final theme = FluentTheme.of(context); - - return Padding( - padding: const EdgeInsets.symmetric(horizontal: 20), - child: Container( - decoration: BoxDecoration( - border: Border( - left: BorderSide( - color: borderColor.of(theme), - width: 1, - ) - ) - ), - ), - ); - } -} diff --git a/lib/pages/vlm/widgets/vlm_chat_area.dart b/lib/pages/vlm/widgets/vlm_chat_area.dart deleted file mode 100644 index dd4f305d..00000000 --- a/lib/pages/vlm/widgets/vlm_chat_area.dart +++ /dev/null @@ -1,412 +0,0 @@ -// Copyright (c) 2024 Intel Corporation -// -// SPDX-License-Identifier: Apache-2.0 - -import 'package:fluent_ui/fluent_ui.dart'; -import 'package:flutter_svg/svg.dart'; -import 'package:inference/interop/openvino_bindings.dart'; -import 'package:inference/pages/vlm/providers/vlm_inference_provider.dart'; -import 'package:inference/pages/vlm/widgets/horizontal_rule.dart'; -import 'package:inference/pages/vlm/widgets/vlm_metrics_grid.dart'; -import 'package:inference/theme_fluent.dart'; -import 'package:provider/provider.dart'; -import 'package:super_clipboard/super_clipboard.dart'; - -import 'image_grid.dart'; - -class VLMChatArea extends StatefulWidget { - const VLMChatArea({super.key}); - - @override - State createState() => VLMChatAreaState(); -} - -class VLMChatAreaState extends State { - final _controller = TextEditingController(); - final _scrollController = ScrollController(); - bool attachedToBottom = true; - - void handleFileListChange(List paths) { - final vlm = provider(); - if (!vlm.initialized) { - return; - } - vlm.setImagePaths(paths); - } - - void jumpToBottom({offset = 0}) { - if (_scrollController.hasClients) { - _scrollController - .jumpTo(_scrollController.position.maxScrollExtent + offset); - } - } - - void message(String message) async { - if (message.isEmpty) { - return; - } - final vlm = provider(); - if (!vlm.initialized) { - return; - } - - if (vlm.response != null) { - return; - } - _controller.text = ""; - jumpToBottom(offset: 110); //move to bottom including both - vlm.message(message); - } - - VLMInferenceProvider provider() => - Provider.of(context, listen: false); - - @override - void initState() { - super.initState(); - _scrollController.addListener(() { - setState(() { - attachedToBottom = _scrollController.position.pixels + 0.001 >= - _scrollController.position.maxScrollExtent; - }); - }); - } - - @override - void dispose() { - super.dispose(); - _controller.dispose(); - _scrollController.dispose(); - } - - @override - Widget build(BuildContext context) { - return Consumer( - builder: (context, inference, child) { - WidgetsBinding.instance.addPostFrameCallback((_) { - if (attachedToBottom) { - jumpToBottom(); - } - }); - - final theme = FluentTheme.of(context); - final textColor = theme.typography.body?.color ?? Colors.black; - - return Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Builder(builder: (context) { - if (!inference.initialized) { - return Expanded( - child: Center( - child: Column( - mainAxisAlignment: MainAxisAlignment.center, - children: [ - Image.asset('images/intel-loading.gif', width: 100), - const Text("Loading model...") - ], - )), - ); - } - return Expanded( - child: Container( - decoration: const BoxDecoration( - shape: BoxShape.rectangle, - borderRadius: BorderRadius.all(Radius.circular(8)), - ), - child: Column( - children: [ - SizedBox( - height: 220, - child: ImageGrid( - initialGalleryData: const [], - onFileListChange: handleFileListChange, - )), - const HorizontalRule(), - Expanded( - child: Builder(builder: (context) { - if (inference.messages.isEmpty) { - return Center( - child: Text( - "Type a message to ${inference.project?.name ?? "assistant"}")); - } - return Stack( - alignment: Alignment.topCenter, - children: [ - SingleChildScrollView( - controller: _scrollController, - child: Padding( - padding: const EdgeInsets.all(20), - child: Column( - - // mainAxisAlignment: MainAxisAlignment.start, - crossAxisAlignment: - CrossAxisAlignment.stretch, - children: inference.messages.map((message) { - switch (message.speaker) { - case Speaker.user: - return UserInputMessage(message); - case Speaker.assistant: - return GeneratedResponseMessage( - message, - inference.project! - .thumbnailImage(), - inference.project!.name); - } - }).toList()), - ), - ), - Positioned( - bottom: 10, - child: Builder(builder: (context) { - if (attachedToBottom) { - return Container(); - } - return Center( - child: Padding( - padding: const EdgeInsets.only(top: 2.0), - child: SizedBox( - width: 200, - height: 40, - // Adjusted height to match Fluent UI's button dimensions - child: FilledButton( - child: const Text("Jump to bottom"), - onPressed: () { - jumpToBottom(); - setState(() { - attachedToBottom = true; - }); - }, - ), - ), - ), - ); - }), - ), - ], - ); - }), - ), - - // SizedBox( - // height: 30, - // child: Builder( - // builder: (context) { - // if (inference.interimResponse == null){ - // return Container(); - // } - // return Center( - // child: OutlinedButton.icon( - // onPressed: () => inference.forceStop(), - // icon: const Icon(Icons.stop), - // label: const Text("Stop responding") - // ), - // ); - // } - // ), - // ), - Padding( - padding: const EdgeInsets.only( - left: 45, right: 45, top: 10, bottom: 25), - child: SizedBox( - height: 40, - child: Row( - crossAxisAlignment: CrossAxisAlignment.center, - children: [ - Padding( - padding: const EdgeInsets.only(right: 8), - child: IconButton( - icon: SvgPicture.asset( - "images/clear.svg", - width: 20, - colorFilter: ColorFilter.mode(textColor, BlendMode.srcIn), - ), - onPressed: () => inference.reset(), - ), - ), - Expanded( - child: TextBox( - maxLines: null, - keyboardType: TextInputType.text, - placeholder: "Ask me anything...", - controller: _controller, - onSubmitted: message, - style: const TextStyle( - fontSize: 14, - ), - suffix: IconButton( - icon: Icon( - FluentIcons.send, - color: - (inference.interimResponse == null - ? textColor - : textColor.withOpacity(0.2)), - ), - onPressed: () => - inference.interimResponse != null ? null : - message(_controller.text), - ), - ), - ) - ], - ), - ), - ), - ], - ), - ), - ); - }), - ], - ); - }); - } -} - -class UserInputMessage extends StatelessWidget { - final Message message; - - const UserInputMessage(this.message, {super.key}); - - @override - Widget build(BuildContext context) { - return Padding( - padding: const EdgeInsets.only(bottom: 20), - child: Column( - crossAxisAlignment: CrossAxisAlignment.end, - children: [ - Container( - constraints: const BoxConstraints(maxWidth: 500), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Padding( - padding: const EdgeInsets.only(right: 30.0), - child: MessageWidget( - message: message.message, - innerPadding: 8, - isSender: true), - ), - ])) - ], - ), - ); - } -} - -class GeneratedResponseMessage extends StatelessWidget { - final Message message; - final ImageProvider icon; - final String name; - - const GeneratedResponseMessage(this.message, this.icon, this.name, {super.key}); - - @override - Widget build(BuildContext context) { - return Padding( - padding: const EdgeInsets.only(bottom: 20), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Container( - constraints: const BoxConstraints(maxWidth: 500), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Padding( - padding: const EdgeInsets.only(right: 30.0), - child: MessageWidget( - message: message.message, - innerPadding: 8, - isSender: false), - ), - ])) - ], - ), - ); - } -} - - -void showMetricsDialog(BuildContext context, VLMMetrics metrics) async { - await showDialog( - context: context, - barrierDismissible: true, - builder: (context) => ContentDialog( - constraints: const BoxConstraints(maxWidth: double.infinity), - content: VLMMetricsGrid(metrics: metrics), - ), - ); -} - -class RoundedPicture extends StatelessWidget { - final String name; - final ImageProvider icon; // Icon widget provided - - const RoundedPicture({super.key, required this.name, required this.icon}); - - @override - Widget build(BuildContext context) { - return ClipOval( - child: Container( - width: 40, - height: 40, - decoration: BoxDecoration( - image: DecorationImage( - image: icon, // Adjust this to fit your `name` field - fit: BoxFit.cover, - ), - ), - ), - ); - } -} - -class MessageWidget extends StatelessWidget { - final String? message; - final VLMMetrics? metrics; // If not set, no copy-paste options. - final double innerPadding; - final bool isSender; - - const MessageWidget( - {super.key, - this.message, - this.metrics, - required this.innerPadding, - required this.isSender}); - - @override - Widget build(BuildContext context) { - final theme = FluentTheme.of(context); - final textColor = theme.typography.body?.color ?? Colors.black; - - return Container( - decoration: BoxDecoration( - color: - isSender ? userMessageColor.of(theme) : modelMessageColor.of(theme), - borderRadius: const BorderRadius.only( - topLeft: Radius.circular(4.0), - topRight: Radius.circular(4.0), - bottomLeft: Radius.circular(4.0), - bottomRight: Radius.circular(4.0), - ), - ), - padding: EdgeInsets.all(innerPadding), - child: Column(children: [ - message != null - ? SelectableText( - message!, - style: TextStyle( - color: textColor, - fontSize: 14, - fontWeight: FontWeight.w400, - ), - ) - : const SizedBox.shrink(), - ]), - ); - } -} - - - diff --git a/lib/pages/vlm/providers/vlm_inference_provider.dart b/lib/providers/vlm_inference_provider.dart similarity index 80% rename from lib/pages/vlm/providers/vlm_inference_provider.dart rename to lib/providers/vlm_inference_provider.dart index b3338912..a9bfbdb9 100644 --- a/lib/pages/vlm/providers/vlm_inference_provider.dart +++ b/lib/providers/vlm_inference_provider.dart @@ -11,16 +11,17 @@ import 'package:inference/interop/generated_bindings.dart'; import 'package:inference/interop/vlm_inference.dart'; import 'package:inference/project.dart'; -enum Speaker { assistant, user } +enum Speaker { assistant, system, user } class Message { final Speaker speaker; final String message; final VLMMetrics? metrics; + final DateTime? time; final bool allowedCopy; // Don't allow loading images to be copied - const Message(this.speaker, this.message, this.metrics, this.allowedCopy); + const Message(this.speaker, this.message, this.metrics, this.time, this.allowedCopy); } class VLMInferenceProvider extends ChangeNotifier { @@ -28,6 +29,7 @@ class VLMInferenceProvider extends ChangeNotifier { Project? _project; String? _device; + List _imagePaths = []; Project? get project => _project; @@ -126,7 +128,7 @@ class VLMInferenceProvider extends ChangeNotifier { return null; } - return Message(Speaker.assistant, response!, null, false); + return Message(Speaker.assistant, response!, null, DateTime.now(), false); } List get messages { @@ -143,13 +145,13 @@ class VLMInferenceProvider extends ChangeNotifier { Future message(String message) async { _response = "..."; - _messages.add(Message(Speaker.user, message, null, false)); + _messages.add(Message(Speaker.user, message, null, DateTime.now(), false)); notifyListeners(); final response = await _inference!.prompt(message, maxTokens); if (_messages.isNotEmpty) { - _messages.add(Message(Speaker.assistant, response.content, response.metrics, true)); + _messages.add(Message(Speaker.assistant, response.content, response.metrics, DateTime.now(), true)); } _response = null; @@ -161,8 +163,22 @@ class VLMInferenceProvider extends ChangeNotifier { void setImagePaths(List paths) { _inference?.setImagePaths(paths); + _imagePaths = paths; } + List getImagePaths(){ + return _imagePaths; + } + + void resetInterimResponse(){ + if (_response != '...' && response != null) { + _messages.add(Message(Speaker.assistant, _response!, null, DateTime.now(), true)); + } + _response = null; + if (hasListeners) { + notifyListeners(); + } + } void close() { _messages.clear(); @@ -175,10 +191,7 @@ class VLMInferenceProvider extends ChangeNotifier { void forceStop() { _inference?.forceStop(); - if (_response != '...') { - _messages.add(Message(Speaker.assistant, _response!, null, true)); - } - _response = null; + resetInterimResponse(); if (hasListeners) { notifyListeners(); } @@ -186,7 +199,7 @@ class VLMInferenceProvider extends ChangeNotifier { void reset() { //_inference?.close(); - _inference?.forceStop(); + // _inference?.forceStop(); // _inference?.clearHistory(); _messages.clear(); _response = null; @@ -194,25 +207,6 @@ class VLMInferenceProvider extends ChangeNotifier { } - Future _closeInferenceInIsolate(dynamic inference) async { - final receivePort = ReceivePort(); - - // Spawn an isolate and pass the SendPort and inference - await Isolate.spawn((List args) { - final SendPort sendPort = args[0]; - final dynamic inference = args[1]; - try { - inference?.close(); // Perform the blocking operation - } catch (e) { - print("Error closing inference: $e"); - } finally { - sendPort.send(null); // Notify that the operation is complete - } - }, [receivePort.sendPort, inference]); - - // Wait for the isolate to complete - await receivePort.first; - } Future _waitForLoadCompletion() async { if (!loaded.isCompleted) { @@ -228,7 +222,7 @@ class VLMInferenceProvider extends ChangeNotifier { if (_inference != null) { print("Closing inference"); - await _closeInferenceInIsolate(_inference!); + _inference?.close(); print("Closing inference done"); } else { close(); diff --git a/lib/widgets/controls/drop_area.dart b/lib/widgets/controls/drop_area.dart index 51f48cdb..7c4a7c63 100644 --- a/lib/widgets/controls/drop_area.dart +++ b/lib/widgets/controls/drop_area.dart @@ -31,8 +31,11 @@ class _DropAreaState extends State { bool _showReleaseMessage = false; void handleDrop(DropDoneDetails details) { - if (details.files.isNotEmpty) { - widget.onUpload(details.files[0].path); + for (var file in details.files) { + String extension = file.path.split('.').last.toLowerCase(); // Extract file extension + if (widget.extensions?.contains(extension) ?? false) { + widget.onUpload(file.path); + } } } diff --git a/lib/pages/vlm/widgets/toolbar_text_input.dart b/lib/widgets/toolbar_text_input.dart similarity index 100% rename from lib/pages/vlm/widgets/toolbar_text_input.dart rename to lib/widgets/toolbar_text_input.dart diff --git a/lib/pages/text_to_image/widgets/vertical_rule.dart b/lib/widgets/vertical_rule.dart similarity index 100% rename from lib/pages/text_to_image/widgets/vertical_rule.dart rename to lib/widgets/vertical_rule.dart diff --git a/openvino_bindings/src/utils/utils.cc b/openvino_bindings/src/utils/utils.cc index d5125fb7..27e45a90 100644 --- a/openvino_bindings/src/utils/utils.cc +++ b/openvino_bindings/src/utils/utils.cc @@ -30,3 +30,18 @@ Metrics convertToMetricsStruct(ov::genai::PerfMetrics m) { int(m.num_input_tokens) }; } + +VLMMetrics convertToVLMMetricsStruct(ov::genai::PerfMetrics m) { + return VLMMetrics{ + nan_safe(m.get_load_time()), + nan_safe(m.get_generate_duration().mean), + nan_safe(m.get_tokenization_duration().mean), + nan_safe(m.get_detokenization_duration().mean), + nan_safe(m.get_ttft().mean), + nan_safe(m.get_tpot().mean), + nan_safe(m.get_throughput().mean), + int(m.num_generated_tokens), + int(m.num_input_tokens) + }; +} + diff --git a/openvino_bindings/src/utils/utils.h b/openvino_bindings/src/utils/utils.h index 0816307d..4e702608 100644 --- a/openvino_bindings/src/utils/utils.h +++ b/openvino_bindings/src/utils/utils.h @@ -9,10 +9,12 @@ #include "openvino/genai/perf_metrics.hpp" #include "metrics.h" +#include "vlm_metrics.h" #include float nan_safe(const float& value); Metrics convertToMetricsStruct(ov::genai::PerfMetrics m); +VLMMetrics convertToVLMMetricsStruct(ov::genai::PerfMetrics m); #endif // UTILS_H_ diff --git a/openvino_bindings/src/utils/vlm_metrics.h b/openvino_bindings/src/utils/vlm_metrics.h index 2d7323f1..2b61f345 100644 --- a/openvino_bindings/src/utils/vlm_metrics.h +++ b/openvino_bindings/src/utils/vlm_metrics.h @@ -10,6 +10,13 @@ typedef struct { float load_time; float generate_time; + const float tokenization_time; + const float detokenization_time; + const float ttft; + const float tpot; + const float throughput; + const int number_of_generated_tokens; + const int number_of_input_tokens; } VLMMetrics; typedef struct { diff --git a/openvino_bindings/src/vlm/load_image.cpp b/openvino_bindings/src/vlm/load_image.cpp index b466614e..37aaceae 100644 --- a/openvino_bindings/src/vlm/load_image.cpp +++ b/openvino_bindings/src/vlm/load_image.cpp @@ -1,6 +1,8 @@ - -// Copyright (C) 2023-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright (c) 2024 Intel Corporation + * + * SPDX-License-Identifier: Apache-2.0 + */ #define STB_IMAGE_IMPLEMENTATION #include "load_image.hpp" @@ -37,6 +39,7 @@ ov::Tensor utils::load_image(const std::filesystem::path& image_path) { std::vector buffer((std::istreambuf_iterator(file)), std::istreambuf_iterator()); cv::Mat cv_image = cv::imdecode(buffer, cv::IMREAD_COLOR); + cv::cvtColor(cv_image, cv_image, cv::COLOR_BGR2RGB); if (cv_image.empty()) { throw std::runtime_error{"Failed to load the image."}; diff --git a/openvino_bindings/src/vlm/load_image.hpp b/openvino_bindings/src/vlm/load_image.hpp index f43b3373..10521e01 100644 --- a/openvino_bindings/src/vlm/load_image.hpp +++ b/openvino_bindings/src/vlm/load_image.hpp @@ -1,6 +1,8 @@ - -// Copyright (C) 2023-2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright (c) 2024 Intel Corporation + * + * SPDX-License-Identifier: Apache-2.0 + */ #pragma once diff --git a/openvino_bindings/src/vlm/vlm_inference.cc b/openvino_bindings/src/vlm/vlm_inference.cc index 385c6aee..0d66d116 100644 --- a/openvino_bindings/src/vlm/vlm_inference.cc +++ b/openvino_bindings/src/vlm/vlm_inference.cc @@ -96,10 +96,9 @@ VLMStringWithMetrics VLMInference::prompt(std::string message, int max_new_token const auto load_time_f = static_cast(load_time); const auto generate_time_f = static_cast(generate_time); - const auto metrics = VLMMetrics{ - !std::isnan(load_time_f) ? load_time_f : 0.0f, - !std::isnan(generate_time_f) ? generate_time_f : 0.0f, - }; + auto metrics = convertToVLMMetricsStruct(results.perf_metrics); + metrics.generate_time = generate_time_f; + metrics.load_time = load_time_f; // Return auto res = VLMStringWithMetrics{strdup(join_texts(texts).c_str()), metrics}; diff --git a/openvino_bindings/src/vlm/vlm_inference.h b/openvino_bindings/src/vlm/vlm_inference.h index da5f06b2..dd0185f8 100644 --- a/openvino_bindings/src/vlm/vlm_inference.h +++ b/openvino_bindings/src/vlm/vlm_inference.h @@ -9,7 +9,7 @@ #include -#include "src/utils/vlm_metrics.h" +#include "src/utils/utils.h" #include "openvino/genai/visual_language/pipeline.hpp" class VLMInference From c5da9428eb1a7ab8307177f898ee451273ef4df3 Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Tue, 21 Jan 2025 16:53:59 +0100 Subject: [PATCH 07/15] Bugfix for older flutter --- lib/pages/vlm/widgets/image_grid.dart | 70 ++++++++++++----------- openvino_bindings/src/tti/tti_inference.h | 8 ++- openvino_bindings/src/vlm/vlm_inference.h | 5 +- 3 files changed, 45 insertions(+), 38 deletions(-) diff --git a/lib/pages/vlm/widgets/image_grid.dart b/lib/pages/vlm/widgets/image_grid.dart index da97641d..eb7b7150 100644 --- a/lib/pages/vlm/widgets/image_grid.dart +++ b/lib/pages/vlm/widgets/image_grid.dart @@ -73,7 +73,6 @@ class _ImageGridState extends State { padding: const EdgeInsets.only(bottom: 20), child: Row( mainAxisAlignment: MainAxisAlignment.start, - spacing: 20, children: List.generate(galleryData.length, (index) { String path = galleryData[index]; bool isLocalFile = File(path) @@ -97,45 +96,48 @@ class _ImageGridState extends State { hoverStates[index] = false; }); }, - child: Stack( - children: [ - AspectRatio( - aspectRatio: 1, - child: Container( - decoration: BoxDecoration( - borderRadius: BorderRadius.circular(10), - color: Colors.black, - image: DecorationImage( - image: isLocalFile - ? FileImage(File(path)) // Load local file - : NetworkImage(path) as ImageProvider, - // Load from network - fit: BoxFit.cover, + child: Padding( + padding: const EdgeInsets.only(right: 20), + child: Stack( + children: [ + AspectRatio( + aspectRatio: 1, + child: Container( + decoration: BoxDecoration( + borderRadius: BorderRadius.circular(10), + color: Colors.black, + image: DecorationImage( + image: isLocalFile + ? FileImage(File(path)) // Load local file + : NetworkImage(path) as ImageProvider, + // Load from network + fit: BoxFit.cover, + ), ), ), ), - ), - if (isHovered) - Positioned( - top: 5, - right: 5, - child: GestureDetector( - onTap: () => removeImage(index), - child: Container( - padding: const EdgeInsets.all(8), - decoration: BoxDecoration( - color: Colors.black.withAlpha(200), - shape: BoxShape.circle, - ), - child: const Icon( - FluentIcons.cancel, - size: 12, - color: Colors.white, + if (isHovered) + Positioned( + top: 5, + right: 5, + child: GestureDetector( + onTap: () => removeImage(index), + child: Container( + padding: const EdgeInsets.all(8), + decoration: BoxDecoration( + color: Colors.black.withAlpha(200), + shape: BoxShape.circle, + ), + child: const Icon( + FluentIcons.cancel, + size: 12, + color: Colors.white, + ), ), ), ), - ), - ], + ], + ), ), ); }), diff --git a/openvino_bindings/src/tti/tti_inference.h b/openvino_bindings/src/tti/tti_inference.h index 427942eb..1adce8e5 100644 --- a/openvino_bindings/src/tti/tti_inference.h +++ b/openvino_bindings/src/tti/tti_inference.h @@ -24,7 +24,13 @@ class TTIInference // Use a lambda to initialize the 'pipe' and measure the construction time in one step ov_pipe([&]() { auto start_time = std::chrono::steady_clock::now(); - ov::genai::Text2ImagePipeline temp_pipe(model_path, device); // Construct the pipe + + ov::AnyMap enable_compile_cache; + // if (device == "GPU") { + enable_compile_cache.insert({ov::cache_dir(model_path + "/cache")}); + // } + + ov::genai::Text2ImagePipeline temp_pipe(model_path, device, enable_compile_cache); // Construct the pipe auto end_time = std::chrono::steady_clock::now(); std::filesystem::path bgr_path = std::filesystem::path(model_path) / "channel_info.json"; diff --git a/openvino_bindings/src/vlm/vlm_inference.h b/openvino_bindings/src/vlm/vlm_inference.h index dd0185f8..bed4dd4c 100644 --- a/openvino_bindings/src/vlm/vlm_inference.h +++ b/openvino_bindings/src/vlm/vlm_inference.h @@ -7,8 +7,7 @@ #ifndef VLM_INFERENCE_H_ #define VLM_INFERENCE_H_ -#include - +#include #include "src/utils/utils.h" #include "openvino/genai/visual_language/pipeline.hpp" @@ -30,7 +29,7 @@ class VLMInference if (device == "GPU") { // Cache compiled models on disk for GPU to save time on the // next run. It's not beneficial for CPU. - enable_compile_cache.insert({ov::cache_dir("vlm_cache")}); + enable_compile_cache.insert({ov::cache_dir(model_path + "/cache")}); } ov_pipe = std::make_unique(model_path, device, enable_compile_cache); From 98be817d71120430ccfb6844f9412a116510f9e8 Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Tue, 21 Jan 2025 20:40:49 +0100 Subject: [PATCH 08/15] Fix copyright --- lib/pages/text_to_image/widgets/assistant_message.dart | 4 ++++ lib/pages/text_to_image/widgets/user_message.dart | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/lib/pages/text_to_image/widgets/assistant_message.dart b/lib/pages/text_to_image/widgets/assistant_message.dart index 9207b7d3..5fd867b6 100644 --- a/lib/pages/text_to_image/widgets/assistant_message.dart +++ b/lib/pages/text_to_image/widgets/assistant_message.dart @@ -1,3 +1,7 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + import 'package:fluent_ui/fluent_ui.dart'; import 'package:inference/providers/text_to_image_inference_provider.dart'; import 'package:inference/theme_fluent.dart'; diff --git a/lib/pages/text_to_image/widgets/user_message.dart b/lib/pages/text_to_image/widgets/user_message.dart index 47e19423..21769a29 100644 --- a/lib/pages/text_to_image/widgets/user_message.dart +++ b/lib/pages/text_to_image/widgets/user_message.dart @@ -1,3 +1,7 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + import 'package:flutter_markdown/flutter_markdown.dart'; import 'package:inference/providers/text_to_image_inference_provider.dart'; import 'package:fluent_ui/fluent_ui.dart'; From 2d99e212ab592c3a46b5cbe8b5e415d38ed43440 Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Wed, 22 Jan 2025 10:19:18 +0100 Subject: [PATCH 09/15] Fix linter issues --- lib/pages/download_model/download_model.dart | 3 ++- lib/pages/text_to_image/playground.dart | 4 ---- lib/pages/vlm/live_inference_pane.dart | 1 - lib/pages/vlm/widgets/image_grid.dart | 3 --- lib/project.dart | 8 ++++++++ lib/providers/vlm_inference_provider.dart | 1 - lib/widgets/controls/drop_area.dart | 2 -- lib/widgets/toolbar_text_input.dart | 8 ++++---- 8 files changed, 14 insertions(+), 16 deletions(-) diff --git a/lib/pages/download_model/download_model.dart b/lib/pages/download_model/download_model.dart index 1053a2be..5067f76d 100644 --- a/lib/pages/download_model/download_model.dart +++ b/lib/pages/download_model/download_model.dart @@ -100,6 +100,7 @@ class _DownloadModelPageState extends State { } Future onClose() async { + final navigator = Navigator.of(context); final result = await showDialog(context: context, builder: (BuildContext context) => ContentDialog( title: const Text("Download in progress"), content: const Text("Press 'continue' to keep downloading the model"), @@ -117,7 +118,7 @@ class _DownloadModelPageState extends State { ); if (result == true && context.mounted) { - GoRouter.of(context).pop(); + navigator.pop(); } } diff --git a/lib/pages/text_to_image/playground.dart b/lib/pages/text_to_image/playground.dart index a35700e7..7564f0c1 100644 --- a/lib/pages/text_to_image/playground.dart +++ b/lib/pages/text_to_image/playground.dart @@ -15,7 +15,6 @@ import 'package:inference/providers/text_to_image_inference_provider.dart'; import 'package:inference/theme_fluent.dart'; import 'package:inference/widgets/device_selector.dart'; import 'package:inference/widgets/toolbar_text_input.dart'; -import 'package:intl/intl.dart'; import 'package:provider/provider.dart'; class TTIPlayground extends StatefulWidget { @@ -89,9 +88,6 @@ class _TTIPlaygroundState extends State { @override Widget build(BuildContext context) { - Locale locale = Localizations.localeOf(context); - final nf = NumberFormat.decimalPatternDigits( - locale: locale.languageCode, decimalDigits: 2); final theme = FluentTheme.of(context); return Row( diff --git a/lib/pages/vlm/live_inference_pane.dart b/lib/pages/vlm/live_inference_pane.dart index 0ba5df0a..73724c8b 100644 --- a/lib/pages/vlm/live_inference_pane.dart +++ b/lib/pages/vlm/live_inference_pane.dart @@ -16,7 +16,6 @@ import 'package:inference/providers/vlm_inference_provider.dart'; import 'package:inference/theme_fluent.dart'; import 'package:inference/widgets/device_selector.dart'; import 'package:inference/widgets/grid_container.dart'; -import 'package:inference/widgets/horizontal_rule.dart'; import 'package:provider/provider.dart'; class VLMPlayground extends StatefulWidget { diff --git a/lib/pages/vlm/widgets/image_grid.dart b/lib/pages/vlm/widgets/image_grid.dart index eb7b7150..38d304f7 100644 --- a/lib/pages/vlm/widgets/image_grid.dart +++ b/lib/pages/vlm/widgets/image_grid.dart @@ -52,9 +52,6 @@ class _ImageGridState extends State { @override Widget build(BuildContext context) { - var width = MediaQuery.of(context).size.width; - var height = MediaQuery.of(context).size.height; - return DropArea( showChild: galleryData.isNotEmpty, onUpload: onDrop, diff --git a/lib/project.dart b/lib/project.dart index edc303b1..05563120 100644 --- a/lib/project.dart +++ b/lib/project.dart @@ -316,6 +316,14 @@ class GetiProject extends Project { return !checks.contains(false); } + + @override + int get hashCode{ + return Object.hash( + id, + const ListEquality().hash(tasks.map((m) => m.id).toList()), + ); + } } class PublicProject extends Project { diff --git a/lib/providers/vlm_inference_provider.dart b/lib/providers/vlm_inference_provider.dart index a9bfbdb9..3882d64f 100644 --- a/lib/providers/vlm_inference_provider.dart +++ b/lib/providers/vlm_inference_provider.dart @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import 'dart:async'; -import 'dart:isolate'; import 'dart:typed_data'; import 'dart:ui' as ui; import 'package:fluent_ui/fluent_ui.dart'; diff --git a/lib/widgets/controls/drop_area.dart b/lib/widgets/controls/drop_area.dart index 7c4a7c63..23f568d9 100644 --- a/lib/widgets/controls/drop_area.dart +++ b/lib/widgets/controls/drop_area.dart @@ -57,8 +57,6 @@ class _DropAreaState extends State { @override Widget build(BuildContext context) { - final theme = FluentTheme.of(context); - return DropTarget( onDragDone: handleDrop, onDragExited: (_) => hideReleaseMessage(), diff --git a/lib/widgets/toolbar_text_input.dart b/lib/widgets/toolbar_text_input.dart index 3113bf9d..3946abb3 100644 --- a/lib/widgets/toolbar_text_input.dart +++ b/lib/widgets/toolbar_text_input.dart @@ -45,10 +45,10 @@ class _ToolbarTextInputState extends State { if (!_focusNode.hasFocus) { // When the TextBox loses focus, round and update final inputValue = int.tryParse(_controller.text.replaceAll(RegExp(r'[^0-9]'), '')) ?? 0; - // final rounded = _nearestPowerOfTwo(inputValue); - // - // _controller.text = rounded.toString(); - widget.onChanged!(inputValue); + final rounded = _nearestPowerOfTwo(inputValue); + + _controller.text = rounded.toString(); + widget.onChanged!(rounded); } } From 09fda770b9ca0304c8efd79f1547bb11e03d36d3 Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Wed, 22 Jan 2025 14:11:26 +0100 Subject: [PATCH 10/15] Add test --- .../vlm/widgets/assistant_message_test.dart | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 test/pages/vlm/widgets/assistant_message_test.dart diff --git a/test/pages/vlm/widgets/assistant_message_test.dart b/test/pages/vlm/widgets/assistant_message_test.dart new file mode 100644 index 00000000..bd5cb547 --- /dev/null +++ b/test/pages/vlm/widgets/assistant_message_test.dart @@ -0,0 +1,70 @@ +// Copyright (c) 2024 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:flutter_test/flutter_test.dart'; +import 'package:inference/pages/vlm/widgets/assistant_message.dart'; +import 'package:inference/project.dart'; +import 'package:inference/providers/vlm_inference_provider.dart'; +import 'package:provider/provider.dart'; + +void main() { + TestWidgetsFlutterBinding.ensureInitialized(); + + group('AssistantMessage Widget Tests', () { + late Message testMessage; + late VLMInferenceProvider inferenceProvider; + late PublicProject project; + late Image thumbnail; + setUp(() { + thumbnail = Image.asset('images/logo_50.png'); + project = PublicProject("id", "model_id", "app_version", "name", + "creation_time", ProjectType.vlm, "/path/", thumbnail, null); + + inferenceProvider = VLMInferenceProvider(project, "auto"); + + testMessage = Message( + Speaker.assistant, + "Test message", + null, + DateTime.now(), + true, + ); + }); + + Widget createTestWidget(Widget child) { + return FluentApp( + home: ChangeNotifierProvider.value( + value: inferenceProvider, + child: ScaffoldPage(content: child), + ), + ); + } + + testWidgets('renders correctly with message and icon', + (WidgetTester tester) async { + await tester.pumpWidget(createTestWidget(AssistantMessage(testMessage))); + + expect(find.text("Test message"), findsOneWidget); + + // Find all `Container` widgets and pick the first one + final containerFinder = find.descendant( + of: find.byType(AssistantMessage), + matching: find.byWidgetPredicate( + (widget) => widget is Container && widget.decoration is BoxDecoration, + ), + ); + final container = tester.widget(containerFinder.first); + + final BoxDecoration? decoration = container.decoration as BoxDecoration?; + + expect(decoration, isNotNull, reason: 'BoxDecoration should not be null'); + expect(decoration!.image!.image, isA()); + + final AssetImage image = decoration.image!.image as AssetImage; + final AssetImage image2 = thumbnail.image as AssetImage; + expect(image.assetName, equals(image2.assetName)); + }); + }); +} From cd262937660b1181c37c025908cfee2f4fb9d1b1 Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Wed, 22 Jan 2025 14:50:35 +0100 Subject: [PATCH 11/15] Add test --- lib/pages/vlm/live_inference_pane.dart | 2 +- lib/widgets/toolbar_text_input.dart | 2 +- .../vlm/widgets/assistant_message_test.dart | 36 ++++++++++++++++--- 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/lib/pages/vlm/live_inference_pane.dart b/lib/pages/vlm/live_inference_pane.dart index 73724c8b..c1fbdac6 100644 --- a/lib/pages/vlm/live_inference_pane.dart +++ b/lib/pages/vlm/live_inference_pane.dart @@ -121,7 +121,7 @@ class _VLMPlaygroundState extends State { labelText: "Max new tokens", suffix: "", initialValue: provider.maxTokens, - roundPowerOfTwo: true, + roundPowerOfTwo: false, onChanged: (value) { provider.maxTokens = value; }), diff --git a/lib/widgets/toolbar_text_input.dart b/lib/widgets/toolbar_text_input.dart index 3946abb3..84e58ef3 100644 --- a/lib/widgets/toolbar_text_input.dart +++ b/lib/widgets/toolbar_text_input.dart @@ -45,7 +45,7 @@ class _ToolbarTextInputState extends State { if (!_focusNode.hasFocus) { // When the TextBox loses focus, round and update final inputValue = int.tryParse(_controller.text.replaceAll(RegExp(r'[^0-9]'), '')) ?? 0; - final rounded = _nearestPowerOfTwo(inputValue); + final rounded = (widget.roundPowerOfTwo ?? false) ? _nearestPowerOfTwo(inputValue) : inputValue; _controller.text = rounded.toString(); widget.onChanged!(rounded); diff --git a/test/pages/vlm/widgets/assistant_message_test.dart b/test/pages/vlm/widgets/assistant_message_test.dart index bd5cb547..41636f72 100644 --- a/test/pages/vlm/widgets/assistant_message_test.dart +++ b/test/pages/vlm/widgets/assistant_message_test.dart @@ -2,7 +2,11 @@ // // SPDX-License-Identifier: Apache-2.0 +import 'dart:ui'; + import 'package:fluent_ui/fluent_ui.dart'; +import 'package:flutter/services.dart'; +import 'package:flutter_markdown/flutter_markdown.dart'; import 'package:flutter_test/flutter_test.dart'; import 'package:inference/pages/vlm/widgets/assistant_message.dart'; import 'package:inference/project.dart'; @@ -26,7 +30,7 @@ void main() { testMessage = Message( Speaker.assistant, - "Test message", + "Test message Test message Test message", null, DateTime.now(), true, @@ -42,11 +46,11 @@ void main() { ); } - testWidgets('renders correctly with message and icon', - (WidgetTester tester) async { + testWidgets('renders correctly with message', (WidgetTester tester) async { await tester.pumpWidget(createTestWidget(AssistantMessage(testMessage))); - expect(find.text("Test message"), findsOneWidget); + expect( + find.text("Test message Test message Test message"), findsOneWidget); // Find all `Container` widgets and pick the first one final containerFinder = find.descendant( @@ -66,5 +70,29 @@ void main() { final AssetImage image2 = thumbnail.image as AssetImage; expect(image.assetName, equals(image2.assetName)); }); + + testWidgets('copies message to clipboard when copy button is pressed', + (WidgetTester tester) async { + await tester.pumpWidget(createTestWidget(AssistantMessage(testMessage))); + + final assistantMessageFinder = find.byType(AssistantMessage); + expect(assistantMessageFinder, findsOneWidget); + + final markdown = find.byType(MarkdownBody); + + final gesture = await tester.createGesture(kind: PointerDeviceKind.mouse); + await gesture.addPointer(); + final center = tester.getCenter(markdown.first); + print(center); + await gesture.moveTo(center); + + await tester.pumpAndSettle(); + + await tester.tap(find.byIcon(FluentIcons.copy)); + await tester.pump(); + + final clipboardData = await Clipboard.getData(Clipboard.kTextPlain); + expect(clipboardData!.text, "Test message Test message Test message"); + }); }); } From 73c2d4fdeb326f63489826be6639726d0ecc9467 Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Wed, 22 Jan 2025 15:15:14 +0100 Subject: [PATCH 12/15] Fix test --- test/pages/vlm/widgets/assistant_message_test.dart | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/pages/vlm/widgets/assistant_message_test.dart b/test/pages/vlm/widgets/assistant_message_test.dart index 41636f72..8986e837 100644 --- a/test/pages/vlm/widgets/assistant_message_test.dart +++ b/test/pages/vlm/widgets/assistant_message_test.dart @@ -81,18 +81,19 @@ void main() { final markdown = find.byType(MarkdownBody); final gesture = await tester.createGesture(kind: PointerDeviceKind.mouse); + addTearDown(gesture.removePointer); + await gesture.addPointer(); final center = tester.getCenter(markdown.first); - print(center); await gesture.moveTo(center); await tester.pumpAndSettle(); - await tester.tap(find.byIcon(FluentIcons.copy)); - await tester.pump(); + final clipboardIcon = find.byIcon(FluentIcons.copy); + await tester.pumpAndSettle(); + + expect(clipboardIcon, findsOneWidget); - final clipboardData = await Clipboard.getData(Clipboard.kTextPlain); - expect(clipboardData!.text, "Test message Test message Test message"); }); }); } From a3175c3f7abb8b521d010c01f3b71aab57dbb87e Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Wed, 22 Jan 2025 15:15:43 +0100 Subject: [PATCH 13/15] Fix test --- test/pages/vlm/widgets/assistant_message_test.dart | 1 - 1 file changed, 1 deletion(-) diff --git a/test/pages/vlm/widgets/assistant_message_test.dart b/test/pages/vlm/widgets/assistant_message_test.dart index 8986e837..e1566154 100644 --- a/test/pages/vlm/widgets/assistant_message_test.dart +++ b/test/pages/vlm/widgets/assistant_message_test.dart @@ -5,7 +5,6 @@ import 'dart:ui'; import 'package:fluent_ui/fluent_ui.dart'; -import 'package:flutter/services.dart'; import 'package:flutter_markdown/flutter_markdown.dart'; import 'package:flutter_test/flutter_test.dart'; import 'package:inference/pages/vlm/widgets/assistant_message.dart'; From b2b10747c9d44c46df766a839fdc731209b57b0f Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Thu, 23 Jan 2025 10:54:08 +0100 Subject: [PATCH 14/15] Simplify image loading and solve memory leak --- openvino_bindings/src/vlm/load_image.cpp | 39 +++++------------------- 1 file changed, 7 insertions(+), 32 deletions(-) diff --git a/openvino_bindings/src/vlm/load_image.cpp b/openvino_bindings/src/vlm/load_image.cpp index 37aaceae..14e19490 100644 --- a/openvino_bindings/src/vlm/load_image.cpp +++ b/openvino_bindings/src/vlm/load_image.cpp @@ -39,47 +39,22 @@ ov::Tensor utils::load_image(const std::filesystem::path& image_path) { std::vector buffer((std::istreambuf_iterator(file)), std::istreambuf_iterator()); cv::Mat cv_image = cv::imdecode(buffer, cv::IMREAD_COLOR); - cv::cvtColor(cv_image, cv_image, cv::COLOR_BGR2RGB); if (cv_image.empty()) { throw std::runtime_error{"Failed to load the image."}; } - // Ensure the image is converted to the desired number of channels + // Convert to RGB + cv::cvtColor(cv_image, cv_image, cv::COLOR_BGR2RGB); + if (cv_image.channels() != desired_channels) { throw std::runtime_error{"The loaded image does not have the desired number of channels."}; } - int width = cv_image.cols; - int height = cv_image.rows; - - struct SharedImageAllocator { - unsigned char* image; - int channels, height, width; - - void* allocate(size_t bytes, size_t) const { - if (image && static_cast(channels * height * width) == bytes) { - return image; - } - throw std::runtime_error{"Unexpected number of bytes was requested to allocate."}; - } - - void deallocate(void*, size_t bytes, size_t) { - if (static_cast(channels * height * width) != bytes) { - throw std::runtime_error{"Unexpected number of bytes was requested to deallocate."}; - } - image = nullptr; // Prevent dangling pointer - } - - bool is_equal(const SharedImageAllocator& other) const noexcept { - return this == &other; - } - }; - - // Wrap OpenCV image data into the custom allocator + // Create OpenVINO tensor directly from OpenCV image data return ov::Tensor( ov::element::u8, - ov::Shape{1, static_cast(height), static_cast(width), static_cast(desired_channels)}, - SharedImageAllocator{cv_image.data, desired_channels, height, width} + ov::Shape{1, static_cast(cv_image.rows), static_cast(cv_image.cols), static_cast(desired_channels)}, + cv_image.data // Directly pass OpenCV's memory buffer ); -} +} \ No newline at end of file From b0cdbd9ab4270c9392fdd8c886686485c7715ca6 Mon Sep 17 00:00:00 2001 From: Arend Jan Kramer Date: Thu, 23 Jan 2025 12:00:31 +0100 Subject: [PATCH 15/15] Fix vlm test --- openvino_bindings/src/vlm/vlm_inference_test.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/openvino_bindings/src/vlm/vlm_inference_test.cc b/openvino_bindings/src/vlm/vlm_inference_test.cc index a926208c..b14c9d6f 100644 --- a/openvino_bindings/src/vlm/vlm_inference_test.cc +++ b/openvino_bindings/src/vlm/vlm_inference_test.cc @@ -6,11 +6,12 @@ #include "gtest/gtest.h" -#include "tti_inference.h" +#include "vlm_inference.h" TEST(VLMInference, Sanity) { - std::string model_path = "data/TinyLlama-1.1B-Chat-v1.0-int4-ov"; - LLMInference inference(model_path, "CPU"); - std::string output = inference.prompt("What is the color of the sun?", 1.0f, 1.0f); - EXPECT_STREQ(output.c_str(), "The color of the sun is a beautiful and awe-inspiring yellow-amber color. It is a natural, radiant, and beautiful color that is associated with warmth, light, and lightning. The sun is often depicted as a radiant, yellow-amber ball of light that shines down on the earth, illuminating the world and inspiring wonder and awe in all who see it."); + std::string model_path = std::filesystem::absolute("data/OpenGVLab-InternVL2-4B-ov-fp16"); + VLMInference inference(model_path, "CPU"); + inference.setImagePaths({ std::filesystem::absolute("data/images/cat-in-box.jpg")}); + VLMStringWithMetrics output = inference.prompt("what do you see", 200); + EXPECT_STREQ(output.string, "In the image, there is a cat lying comfortably inside a cardboard box. The cat appears to be relaxed and content, with its eyes closed and a peaceful expression on its face. The box is placed on a carpeted floor, and in the background, there is a white sofa or couch. The overall setting suggests a cozy and homely environment."); }