From 806177c1ac8a1daa791cae6c789d069a89a56142 Mon Sep 17 00:00:00 2001 From: Dmitry Kalinin Date: Thu, 28 Nov 2024 14:05:23 +0100 Subject: [PATCH 1/5] Added text generation screen --- lib/pages/models/inference.dart | 3 +- lib/pages/text_generation/playground.dart | 245 ++++++++++++++++++ .../text_generation/text_generation.dart | 125 +++++++++ .../widgets/assistant_message.dart | 160 ++++++++++++ .../widgets/model_properties.dart | 58 +++++ .../text_generation/widgets/user_message.dart | 32 +++ lib/providers/project_provider.dart | 2 - lib/theme_fluent.dart | 3 + pubspec.lock | 16 ++ pubspec.yaml | 1 + 10 files changed, 642 insertions(+), 3 deletions(-) create mode 100644 lib/pages/text_generation/playground.dart create mode 100644 lib/pages/text_generation/text_generation.dart create mode 100644 lib/pages/text_generation/widgets/assistant_message.dart create mode 100644 lib/pages/text_generation/widgets/model_properties.dart create mode 100644 lib/pages/text_generation/widgets/user_message.dart diff --git a/lib/pages/models/inference.dart b/lib/pages/models/inference.dart index 1445420b..5e7231e9 100644 --- a/lib/pages/models/inference.dart +++ b/lib/pages/models/inference.dart @@ -1,5 +1,6 @@ import 'package:fluent_ui/fluent_ui.dart'; import 'package:inference/pages/computer_vision/computer_vision.dart'; +import 'package:inference/pages/text_generation/text_generation.dart'; import 'package:inference/project.dart'; class InferencePage extends StatelessWidget { @@ -12,7 +13,7 @@ class InferencePage extends StatelessWidget { case ProjectType.image: return ComputerVisionPage(project); case ProjectType.text: - return Container(); + return TextGenerationPage(project); case ProjectType.speech: return Container(); } diff --git a/lib/pages/text_generation/playground.dart b/lib/pages/text_generation/playground.dart new file mode 100644 index 00000000..8100c290 --- /dev/null +++ b/lib/pages/text_generation/playground.dart @@ -0,0 +1,245 @@ + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:flutter/services.dart'; +import 'package:inference/pages/models/widgets/grid_container.dart'; +import 'package:inference/pages/text_generation/widgets/assistant_message.dart'; +import 'package:inference/pages/text_generation/widgets/model_properties.dart'; +import 'package:inference/pages/text_generation/widgets/user_message.dart'; +import 'package:inference/project.dart'; +import 'package:inference/providers/text_inference_provider.dart'; +import 'package:inference/theme_fluent.dart'; +import 'package:inference/widgets/device_selector.dart'; +import 'package:intl/intl.dart'; +import 'package:provider/provider.dart'; + +class Playground extends StatefulWidget { + final Project project; + + const Playground({required this.project, super.key}); + + + @override + _PlaygroundState createState() => _PlaygroundState(); +} + +class SubmitMessageIntent extends Intent {} + +class _PlaygroundState extends State { + final textController = TextEditingController(); + final scrollController = ScrollController(); + bool attachedToBottom = true; + + void jumpToBottom({ offset = 0 }) { + if (scrollController.hasClients) { + scrollController.jumpTo(scrollController.position.maxScrollExtent + offset); + } + } + + void message(String message) { + if (message.isEmpty) return; + final provider = Provider.of(context, listen: false); + if (!provider.initialized || provider.response != null) return; + textController.text = ''; + jumpToBottom(offset: 110); //move to bottom including both + // TODO: add error handling + provider.message(message).catchError((e) { print(e); }); + } + + @override + void initState() { + super.initState(); + scrollController.addListener(() { + setState(() { + attachedToBottom = scrollController.position.pixels + 0.001 >= scrollController.position.maxScrollExtent; + }); + }); + } + + @override + void dispose() { + textController.dispose(); + scrollController.dispose(); + super.dispose(); + } + + @override + void didChangeDependencies() { + super.didChangeDependencies(); + if (attachedToBottom) { + jumpToBottom(); + } + } + + @override + Widget build(BuildContext context) { + Locale locale = Localizations.localeOf(context); + final nf = NumberFormat.decimalPatternDigits( + locale: locale.languageCode, decimalDigits: 2); + final theme = FluentTheme.of(context); + + return Row( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Consumer(builder: (context, provider, child) => + Expanded(child: Column( + children: [ + SizedBox( + height: 64, + child: GridContainer( + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 16), + child: Row( + children: [ + const DeviceSelector(), + const Divider(size: 24,direction: Axis.vertical,), + const SizedBox(width: 24,), + const Text('Temperature: '), + Slider( + value: provider.temperature, + onChanged: (value) { provider.temperature = value; }, + label: nf.format(provider.temperature), + min: 0.1, + max: 2.0, + ), + const SizedBox(width: 24,), + const Text('Top P: '), + Slider( + value: provider.topP, + onChanged: (value) { provider.topP = value; }, + label: nf.format(provider.topP), + max: 1.0, + min: 0.1, + ), + ], + ) + ), + ), + ), + Expanded(child: DecoratedBox( + decoration: BoxDecoration( + color: theme.brightness.isDark ? backgroundColor.dark : theme.scaffoldBackgroundColor + ), + child: GridContainer(child: SizedBox( + width: double.infinity, + child: Builder(builder: (context) { + if (!provider.initialized) { + return const Column( + mainAxisAlignment: MainAxisAlignment.center, + children: [ + SizedBox(width: 64,height: 64, child: ProgressRing()), + Padding( + padding: EdgeInsets.only(top: 18), + child: Text("Loading model..."), + ) + ], + ); + } + return Column( + children: [ + Expanded( + child: Builder(builder: (context) { + if (provider.messages.isEmpty) { + return Center( + child: Text("Start chatting with ${provider.project?.name ?? "the model"}!"), + ); + } + return Stack( + alignment: Alignment.bottomCenter, + children: [ + SingleChildScrollView( + controller: scrollController, + child: Padding(padding: const EdgeInsets.symmetric(horizontal: 64, vertical: 20), child: Column( + children: provider.messages.map((message) { switch (message.speaker) { + case Speaker.user: return UserMessage(message); + case Speaker.system: return Text('System: ${message.message}'); + case Speaker.assistant: return AssistantMessage(message); + }}).toList(), + ),), + ), + Positioned( + bottom: 10, + child: Builder(builder: (context) => attachedToBottom + ? const SizedBox() + : Padding( + padding: const EdgeInsets.only(top:2), + child: FilledButton(child: const Row( + children: [ + Icon(FluentIcons.chevron_down, size: 12), + SizedBox(width: 4), + Text('Scroll to bottom'), + ], + ), onPressed: () { + jumpToBottom(); + setState(() { + attachedToBottom = true; + }); + }), + ) + ), + ) + ], + ); + }), + ), + Padding( + padding: const EdgeInsets.symmetric(horizontal: 64, vertical: 24), + child: Row( + crossAxisAlignment: CrossAxisAlignment.end, + mainAxisAlignment: MainAxisAlignment.center, + children: [ + Tooltip( + message: "Create new thread", + child: Button(child: const Icon(FluentIcons.rocket, size: 18,), onPressed: () {}), + ), + Expanded( + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 8), + child: Shortcuts( + shortcuts: { + LogicalKeySet(LogicalKeyboardKey.meta, LogicalKeyboardKey.enter): SubmitMessageIntent(), + LogicalKeySet(LogicalKeyboardKey.control, LogicalKeyboardKey.enter): SubmitMessageIntent(), + }, + child: Actions( + actions: >{ + SubmitMessageIntent: CallbackAction( + onInvoke: (SubmitMessageIntent intent) => message(textController.text), + ), + }, + child: TextBox( + placeholder: "Type a message...", + keyboardType: TextInputType.text, + controller: textController, + maxLines: null, + expands: true, + onSubmitted: message, + autofocus: true, + ), + ), + ), + ), + ), + Builder(builder: (context) => provider.interimResponse != null + ? Tooltip( + message: "Stop", + child: Button(child: const Icon(FluentIcons.stop, size: 18,), onPressed: () { provider.forceStop(); }), + ) + : Tooltip( + message: "Send message", + child: Button(child: const Icon(FluentIcons.send, size: 18,), onPressed: () { message(textController.text); }), + ) + ) + ] + ), + ) + ], + ); + }), + )), + )), + ], + ))), + const ModelProperties(), + ], + ); + } +} \ No newline at end of file diff --git a/lib/pages/text_generation/text_generation.dart b/lib/pages/text_generation/text_generation.dart new file mode 100644 index 00000000..fa9ba926 --- /dev/null +++ b/lib/pages/text_generation/text_generation.dart @@ -0,0 +1,125 @@ + +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:go_router/go_router.dart'; +import 'package:inference/pages/text_generation/playground.dart'; +import 'package:inference/project.dart'; +import 'package:inference/providers/preference_provider.dart'; +import 'package:inference/providers/text_inference_provider.dart'; +import 'package:provider/provider.dart'; + +class TextGenerationPage extends StatefulWidget { + final Project project; + const TextGenerationPage(this.project, {super.key}); + + @override + State createState() => _TextGenerationPageState(); +} + +class _TextGenerationPageState extends State { + int selected = 0; + + @override + Widget build(BuildContext context) { + final theme = FluentTheme.of(context); + final updatedTheme = theme.copyWith( + navigationPaneTheme: theme.navigationPaneTheme.merge(NavigationPaneThemeData( + backgroundColor: theme.scaffoldBackgroundColor, + )) + ); + + return ChangeNotifierProxyProvider( + create: (_) { + return TextInferenceProvider(widget.project, null); + }, + update: (_, preferences, textInferenceProvider) { + final init = textInferenceProvider == null || + !textInferenceProvider.sameProps(widget.project, preferences.device); + if (init) { + final textInferenceProvider = TextInferenceProvider(widget.project, preferences.device); + textInferenceProvider.loadModel().catchError((e) { + // TODO: Error handling + print(e); + }); + return textInferenceProvider; + } + if (!textInferenceProvider.sameProps(widget.project, preferences.device)) { + return TextInferenceProvider(widget.project, preferences.device); + } + return textInferenceProvider; + }, + child: Stack( + children: [ + FluentTheme( + data: updatedTheme, + child: NavigationView( + pane: NavigationPane( + size: const NavigationPaneSize(topHeight: 64), + header: Row( + children: [ + Padding( + padding: const EdgeInsets.only(left: 12.0), + child: ClipRRect( + borderRadius: BorderRadius.circular(4.0), + child: Container( + width: 40, + height: 40, + decoration: BoxDecoration( + image: DecorationImage( + image: widget.project.thumbnailImage(), + fit: BoxFit.cover), + ), + ), + ), + ), + Padding( + padding: const EdgeInsets.symmetric(horizontal: 16), + child: Text(widget.project.name, + style: const TextStyle(fontSize: 20, fontWeight: FontWeight.bold), + ), + ), + ], + ), + selected: selected, + onChanged: (i) => setState(() {selected = i;}), + displayMode: PaneDisplayMode.top, + items: [ + PaneItem( + icon: const Icon(FluentIcons.game), + title: const Text("Playground"), + body: Playground(project: widget.project), + ), + ], + ) + ), + ), + SizedBox( + height: 64, + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 25), + child: Row( + mainAxisAlignment: MainAxisAlignment.end, + children: [ + Padding( + padding: const EdgeInsets.all(4), + child: OutlinedButton( + style: ButtonStyle( + shape:WidgetStatePropertyAll(RoundedRectangleBorder( + borderRadius: BorderRadius.circular(4.0), + side: const BorderSide(color: Color(0XFF545454)), + )), + ), + child: const Text("Close"), + onPressed: () => GoRouter.of(context).canPop() + ? GoRouter.of(context).pop() + : GoRouter.of(context).push('/models'), + ), + ), + ] + ), + ), + ) + ], + ), + ); + } +} \ No newline at end of file diff --git a/lib/pages/text_generation/widgets/assistant_message.dart b/lib/pages/text_generation/widgets/assistant_message.dart new file mode 100644 index 00000000..dd13cf84 --- /dev/null +++ b/lib/pages/text_generation/widgets/assistant_message.dart @@ -0,0 +1,160 @@ +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:flutter/services.dart'; +import 'package:flutter_markdown/flutter_markdown.dart'; +import 'package:inference/providers/text_inference_provider.dart'; +import 'package:inference/theme_fluent.dart'; +import 'package:intl/intl.dart'; +import 'package:provider/provider.dart'; + +class AssistantMessage extends StatefulWidget { + final Message message; + const AssistantMessage(this.message, {super.key}); + + @override + _AssistantMessageState createState() => _AssistantMessageState(); +} + +class _AssistantMessageState extends State { + bool _hovering = false; + + @override + Widget build(BuildContext context) { + Locale locale = Localizations.localeOf(context); + final nf = NumberFormat.decimalPatternDigits( + locale: locale.languageCode, decimalDigits: 0); + final theme = FluentTheme.of(context); + final backgroundColor = theme.brightness.isDark + ? theme.scaffoldBackgroundColor + : const Color(0xFFF5F5F5); + + return Consumer(builder: (context, inferenceProvider, child) => + Align( + alignment: Alignment.centerLeft, + child: Padding( + padding: const EdgeInsets.only(bottom: 20), + child: Row( + crossAxisAlignment: CrossAxisAlignment.end, + children: [ + Padding( + padding: const EdgeInsets.only(right: 10, bottom: 36), + child: ClipRRect( + borderRadius: BorderRadius.circular(16.0), + child: Container( + width: 32, + height: 32, + decoration: BoxDecoration( + image: DecorationImage( + image: inferenceProvider.project!.thumbnailImage(), + fit: BoxFit.cover, + ), + ), + ), + ), + ), + Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Padding( + padding: const EdgeInsets.only(bottom: 4), + child: Text( + inferenceProvider.project!.name, + style: TextStyle( + color: subtleTextColor.of(theme), + ), + ), + ), + MouseRegion( + onEnter: (_) => setState(() { _hovering = true; }), + onExit: (_) => setState(() { _hovering = false; }), + child: Column( + crossAxisAlignment: CrossAxisAlignment.end, + children: [ + Container( + constraints: BoxConstraints(maxWidth: MediaQuery.of(context).size.width - 502), + decoration: BoxDecoration( + color: backgroundColor, + borderRadius: BorderRadius.circular(4), + ), + child: Markdown( + data: widget.message.message, + selectable: true, + shrinkWrap: true, + padding: const EdgeInsets.all(8), + physics: const NeverScrollableScrollPhysics(), + ), + ), + if (_hovering) + Padding( + padding: const EdgeInsets.only(top: 4), + child: Row( + children: [ + if (widget.message.metrics != null) Padding( + padding: const EdgeInsets.only(right: 8), + child: Tooltip( + message: 'Time to first token', + child: Text( + 'TTF: ${nf.format(widget.message.metrics!.ttft)}ms', + style: TextStyle( + fontSize: 12, + color: subtleTextColor.of(theme), + ), + ), + ), + ), + if (widget.message.metrics != null) Padding( + padding: const EdgeInsets.only(right: 8), + child: Tooltip( + message: 'Time per output token', + child: Text( + 'TPOT: ${nf.format(widget.message.metrics!.tpot)}ms', + style: TextStyle( + fontSize: 12, + color: subtleTextColor.of(theme), + ), + ), + ), + ), + if (widget.message.metrics != null) Padding( + padding: const EdgeInsets.only(right: 8), + child: Tooltip( + message: 'Generate total duration', + child: Text( + 'Generate: ${nf.format(widget.message.metrics!.generate_time/1000)}s', + style: TextStyle( + fontSize: 12, + color: subtleTextColor.of(theme), + ), + ), + ), + ), + IconButton( + icon: const Icon(FluentIcons.copy), + onPressed: () async{ + await displayInfoBar(context, builder: (context, close) => + InfoBar( + title: const Text('Copied to clipboard'), + severity: InfoBarSeverity.info, + action: IconButton( + icon: const Icon(FluentIcons.clear), + onPressed: close, + ), + ), + ); + Clipboard.setData(ClipboardData(text: widget.message.message)); + }, + ), + ], + ), + ) else const SizedBox(height: 34) + ], + ), + ), + ], + ), + ], + ), + ) + ), + ); + } +} \ No newline at end of file diff --git a/lib/pages/text_generation/widgets/model_properties.dart b/lib/pages/text_generation/widgets/model_properties.dart new file mode 100644 index 00000000..cd6bc53b --- /dev/null +++ b/lib/pages/text_generation/widgets/model_properties.dart @@ -0,0 +1,58 @@ +import 'package:flutter/widgets.dart'; +import 'package:inference/pages/computer_vision/widgets/model_properties.dart'; +import 'package:inference/pages/models/widgets/grid_container.dart'; +import 'package:inference/providers/text_inference_provider.dart'; +import 'package:intl/intl.dart'; +import 'package:provider/provider.dart'; + +class ModelProperties extends StatelessWidget { + const ModelProperties({super.key}); + + @override + Widget build(BuildContext context) { + Locale locale = Localizations.localeOf(context); + final nf = NumberFormat.decimalPatternDigits( + locale: locale.languageCode, decimalDigits: 2); + + return Consumer(builder: (context, inference, child) { + return SizedBox( + width: 280, + child: GridContainer( + padding: const EdgeInsets.symmetric(vertical: 18, horizontal: 24), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + const Text("Model parameters", style: TextStyle( + fontSize: 20, + )), + Container( + padding: const EdgeInsets.only(top: 16), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + ModelProperty( + title: "Model name", + value: inference.project!.name, + ), + ModelProperty( + title: "Architecture", + value: inference.project!.architecture, + ), + ModelProperty( + title: "Temperature", + value: nf.format(inference.temperature), + ), + ModelProperty( + title: "Top P", + value: nf.format(inference.topP), + ), + ] + ) + ), + ] + ) + ) + ); + }); + } +} \ No newline at end of file diff --git a/lib/pages/text_generation/widgets/user_message.dart b/lib/pages/text_generation/widgets/user_message.dart new file mode 100644 index 00000000..e440d0fc --- /dev/null +++ b/lib/pages/text_generation/widgets/user_message.dart @@ -0,0 +1,32 @@ +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:flutter/material.dart'; +import 'package:inference/providers/text_inference_provider.dart'; +import 'package:inference/theme_fluent.dart'; + +class UserMessage extends StatelessWidget { + final Message message; + const UserMessage(this.message, {super.key}); + + @override + Widget build(BuildContext context) { + final theme = FluentTheme.of(context); + return Align( + alignment: Alignment.centerRight, + child: Padding(padding: const EdgeInsets.only(bottom: 20), child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Container( + decoration: BoxDecoration( + color: cosmosBackground.of(theme), + borderRadius: BorderRadius.circular(4), + ), + child: Padding( + padding: const EdgeInsets.all(8.0), + child: SelectableText(message.message,), + ), + ) + ], + ),), + ); + } +} \ No newline at end of file diff --git a/lib/providers/project_provider.dart b/lib/providers/project_provider.dart index e9d94542..cc4670c7 100644 --- a/lib/providers/project_provider.dart +++ b/lib/providers/project_provider.dart @@ -1,5 +1,3 @@ -import 'dart:async'; - import 'package:collection/collection.dart'; import 'package:flutter/material.dart'; import 'package:inference/deployment_processor.dart'; diff --git a/lib/theme_fluent.dart b/lib/theme_fluent.dart index 25fff0a4..1c17a629 100644 --- a/lib/theme_fluent.dart +++ b/lib/theme_fluent.dart @@ -103,12 +103,15 @@ class DarkLightColor { const borderColor = DarkLightColor(Color(0xFFF0F0F0), Color(0xFF3B3B3B)); const backgroundColor = DarkLightColor(Color(0xFFF9F9F9), Color(0xFF282828)); const subtleTextColor = DarkLightColor(Color(0xFF616161), Color(0xFF9F9F9F)); +const neutralBackground = DarkLightColor(Color(0xFFF5F5F5), Color(0xFF343434)); +const cosmosBackground = DarkLightColor(Color(0xFFEFEAFF), Color(0xFF463d66)); final AccentColor electricCosmos = AccentColor.swatch(const { 'normal': Color(0xFF7000FF), }); final AccentColor cosmos = AccentColor.swatch(const { + 'darkest': Color(0xFF463d66), 'normal': Color(0xFFAF98FF), 'lightest': Color(0xFFEFEAFF), }); diff --git a/pubspec.lock b/pubspec.lock index 9e93ab40..498a26de 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -341,6 +341,14 @@ packages: description: flutter source: sdk version: "0.0.0" + flutter_markdown: + dependency: "direct main" + description: + name: flutter_markdown + sha256: "255b00afa1a7bad19727da6a7780cf3db6c3c12e68d302d85e0ff1fdf173db9e" + url: "https://pub.dev" + source: hosted + version: "0.7.4+3" flutter_plugin_android_lifecycle: dependency: transitive description: @@ -545,6 +553,14 @@ packages: url: "https://pub.dev" source: hosted version: "0.1.2-main.4" + markdown: + dependency: transitive + description: + name: markdown + sha256: ef2a1298144e3f985cc736b22e0ccdaf188b5b3970648f2d9dc13efd1d9df051 + url: "https://pub.dev" + source: hosted + version: "7.2.2" matcher: dependency: transitive description: diff --git a/pubspec.yaml b/pubspec.yaml index 5e35670e..6c06fb78 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -56,6 +56,7 @@ dependencies: fluent_ui: ^4.9.2 system_theme: ^3.1.2 flutter_acrylic: ^1.1.4 + flutter_markdown: ^0.7.4+3 dev_dependencies: flutter_test: From d4f149f298ceb61f88a897a409636e4e8e182f50 Mon Sep 17 00:00:00 2001 From: Dmitry Kalinin Date: Fri, 29 Nov 2024 12:09:50 +0100 Subject: [PATCH 2/5] Fixed minor issues --- lib/pages/text_generation/playground.dart | 124 +++++++++++------- .../widgets/assistant_message.dart | 40 ++++-- .../text_generation/widgets/user_message.dart | 42 +++--- lib/providers/text_inference_provider.dart | 18 ++- pubspec.lock | 2 +- pubspec.yaml | 1 + 6 files changed, 148 insertions(+), 79 deletions(-) diff --git a/lib/pages/text_generation/playground.dart b/lib/pages/text_generation/playground.dart index 8100c290..be882913 100644 --- a/lib/pages/text_generation/playground.dart +++ b/lib/pages/text_generation/playground.dart @@ -1,3 +1,4 @@ +import 'dart:io'; import 'package:fluent_ui/fluent_ui.dart'; import 'package:flutter/services.dart'; @@ -41,8 +42,18 @@ class _PlaygroundState extends State { if (!provider.initialized || provider.response != null) return; textController.text = ''; jumpToBottom(offset: 110); //move to bottom including both - // TODO: add error handling - provider.message(message).catchError((e) { print(e); }); + provider.message(message).catchError((e) async { + // ignore: use_build_context_synchronously + await displayInfoBar(context, builder: (context, close) => InfoBar( + title: const Text("An error occurred processing the message"), + content: Text(e.toString()), + severity: InfoBarSeverity.error, + action: IconButton( + icon: const Icon(FluentIcons.clear), + onPressed: close, + ), + )); + }); } @override @@ -150,7 +161,10 @@ class _PlaygroundState extends State { controller: scrollController, child: Padding(padding: const EdgeInsets.symmetric(horizontal: 64, vertical: 20), child: Column( children: provider.messages.map((message) { switch (message.speaker) { - case Speaker.user: return UserMessage(message); + case Speaker.user: return Padding( + padding: const EdgeInsets.only(left: 42), + child: UserMessage(message), + ); case Speaker.system: return Text('System: ${message.message}'); case Speaker.assistant: return AssistantMessage(message); }}).toList(), @@ -183,52 +197,74 @@ class _PlaygroundState extends State { ), Padding( padding: const EdgeInsets.symmetric(horizontal: 64, vertical: 24), - child: Row( - crossAxisAlignment: CrossAxisAlignment.end, - mainAxisAlignment: MainAxisAlignment.center, + child: Column( children: [ - Tooltip( - message: "Create new thread", - child: Button(child: const Icon(FluentIcons.rocket, size: 18,), onPressed: () {}), - ), - Expanded( - child: Padding( - padding: const EdgeInsets.symmetric(horizontal: 8), - child: Shortcuts( - shortcuts: { - LogicalKeySet(LogicalKeyboardKey.meta, LogicalKeyboardKey.enter): SubmitMessageIntent(), - LogicalKeySet(LogicalKeyboardKey.control, LogicalKeyboardKey.enter): SubmitMessageIntent(), - }, - child: Actions( - actions: >{ - SubmitMessageIntent: CallbackAction( - onInvoke: (SubmitMessageIntent intent) => message(textController.text), + Row( + crossAxisAlignment: CrossAxisAlignment.end, + mainAxisAlignment: MainAxisAlignment.center, + children: [ + Padding( + padding: const EdgeInsets.only(bottom: 20), + child: Tooltip( + message: "Create new thread", + child: Button(child: const Icon(FluentIcons.rocket, size: 18,), onPressed: () { provider.reset(); }), + ), + ), + Expanded( + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 8), + child: Shortcuts( + shortcuts: { + LogicalKeySet(LogicalKeyboardKey.meta, LogicalKeyboardKey.enter): SubmitMessageIntent(), + LogicalKeySet(LogicalKeyboardKey.control, LogicalKeyboardKey.enter): SubmitMessageIntent(), + }, + child: Actions( + actions: >{ + SubmitMessageIntent: CallbackAction( + onInvoke: (SubmitMessageIntent intent) => message(textController.text), + ), + }, + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + TextBox( + placeholder: "Type a message...", + keyboardType: TextInputType.multiline, + controller: textController, + maxLines: null, + expands: true, + onSubmitted: message, + autofocus: true, + ), + Padding( + padding: const EdgeInsets.only(top: 6, left: 10), + child: Text( + 'Press ${Platform.isMacOS ? '⌘' : 'Ctrl'} + Enter to submit, Enter for newline', + style: TextStyle(fontSize: 11, color: subtleTextColor.of(theme)), + ), + ), + ], + ), ), - }, - child: TextBox( - placeholder: "Type a message...", - keyboardType: TextInputType.text, - controller: textController, - maxLines: null, - expands: true, - onSubmitted: message, - autofocus: true, ), - ), ), - ), + ), + Padding( + padding: const EdgeInsets.only(bottom: 20), + child: Builder(builder: (context) => provider.interimResponse != null + ? Tooltip( + message: "Stop", + child: Button(child: const Icon(FluentIcons.stop, size: 18,), onPressed: () { provider.forceStop(); }), + ) + : Tooltip( + message: "Send message", + child: Button(child: const Icon(FluentIcons.send, size: 18,), onPressed: () { message(textController.text); }), + ) + ), + ) + ] ), - Builder(builder: (context) => provider.interimResponse != null - ? Tooltip( - message: "Stop", - child: Button(child: const Icon(FluentIcons.stop, size: 18,), onPressed: () { provider.forceStop(); }), - ) - : Tooltip( - message: "Send message", - child: Button(child: const Icon(FluentIcons.send, size: 18,), onPressed: () { message(textController.text); }), - ) - ) - ] + ], ), ) ], diff --git a/lib/pages/text_generation/widgets/assistant_message.dart b/lib/pages/text_generation/widgets/assistant_message.dart index dd13cf84..1ff268f1 100644 --- a/lib/pages/text_generation/widgets/assistant_message.dart +++ b/lib/pages/text_generation/widgets/assistant_message.dart @@ -1,6 +1,7 @@ import 'package:fluent_ui/fluent_ui.dart'; import 'package:flutter/services.dart'; import 'package:flutter_markdown/flutter_markdown.dart'; +import 'package:markdown/markdown.dart' as md; import 'package:inference/providers/text_inference_provider.dart'; import 'package:inference/theme_fluent.dart'; import 'package:intl/intl.dart'; @@ -31,7 +32,7 @@ class _AssistantMessageState extends State { Align( alignment: Alignment.centerLeft, child: Padding( - padding: const EdgeInsets.only(bottom: 20), + padding: const EdgeInsets.only(bottom: 8), child: Row( crossAxisAlignment: CrossAxisAlignment.end, children: [ @@ -56,11 +57,21 @@ class _AssistantMessageState extends State { children: [ Padding( padding: const EdgeInsets.only(bottom: 4), - child: Text( - inferenceProvider.project!.name, - style: TextStyle( - color: subtleTextColor.of(theme), - ), + child: Row( + children: [ + Text( + inferenceProvider.project!.name, + style: TextStyle( + color: subtleTextColor.of(theme), + ), + ), + if (widget.message.time != null) Text( + DateFormat(' | yyyy-MM-dd HH:mm:ss').format(widget.message.time!), + style: TextStyle( + color: subtleTextColor.of(theme), + ), + ), + ], ), ), MouseRegion( @@ -79,8 +90,11 @@ class _AssistantMessageState extends State { data: widget.message.message, selectable: true, shrinkWrap: true, - padding: const EdgeInsets.all(8), - physics: const NeverScrollableScrollPhysics(), + padding: const EdgeInsets.all(12), + extensionSet: md.ExtensionSet( + md.ExtensionSet.gitHubFlavored.blockSyntaxes, + [md.EmojiSyntax(), ...md.ExtensionSet.gitHubFlavored.inlineSyntaxes], + ), ), ), if (_hovering) @@ -134,10 +148,10 @@ class _AssistantMessageState extends State { InfoBar( title: const Text('Copied to clipboard'), severity: InfoBarSeverity.info, - action: IconButton( - icon: const Icon(FluentIcons.clear), - onPressed: close, - ), + action: IconButton( + icon: const Icon(FluentIcons.clear), + onPressed: close, + ), ), ); Clipboard.setData(ClipboardData(text: widget.message.message)); @@ -153,7 +167,7 @@ class _AssistantMessageState extends State { ), ], ), - ) + ), ), ); } diff --git a/lib/pages/text_generation/widgets/user_message.dart b/lib/pages/text_generation/widgets/user_message.dart index e440d0fc..5e73496b 100644 --- a/lib/pages/text_generation/widgets/user_message.dart +++ b/lib/pages/text_generation/widgets/user_message.dart @@ -1,5 +1,6 @@ import 'package:fluent_ui/fluent_ui.dart'; -import 'package:flutter/material.dart'; +import 'package:flutter_markdown/flutter_markdown.dart'; +import 'package:markdown/markdown.dart' as md; import 'package:inference/providers/text_inference_provider.dart'; import 'package:inference/theme_fluent.dart'; @@ -12,21 +13,30 @@ class UserMessage extends StatelessWidget { final theme = FluentTheme.of(context); return Align( alignment: Alignment.centerRight, - child: Padding(padding: const EdgeInsets.only(bottom: 20), child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Container( - decoration: BoxDecoration( - color: cosmosBackground.of(theme), - borderRadius: BorderRadius.circular(4), - ), - child: Padding( - padding: const EdgeInsets.all(8.0), - child: SelectableText(message.message,), - ), - ) - ], - ),), + child: ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 1000), + child: Padding(padding: const EdgeInsets.only(bottom: 20), child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Container( + decoration: BoxDecoration( + color: cosmosBackground.of(theme), + borderRadius: BorderRadius.circular(4), + ), + child: Markdown( + data: message.message, + selectable: true, + shrinkWrap: true, + padding: const EdgeInsets.all(12), + extensionSet: md.ExtensionSet( + md.ExtensionSet.gitHubFlavored.blockSyntaxes, + [md.EmojiSyntax(), ...md.ExtensionSet.gitHubFlavored.inlineSyntaxes], + ), + ), + ) + ], + ),), + ), ); } } \ No newline at end of file diff --git a/lib/providers/text_inference_provider.dart b/lib/providers/text_inference_provider.dart index bfa278d7..443464f3 100644 --- a/lib/providers/text_inference_provider.dart +++ b/lib/providers/text_inference_provider.dart @@ -33,7 +33,8 @@ class Message { final Speaker speaker; final String message; final Metrics? metrics; - const Message(this.speaker, this.message, this.metrics); + final DateTime? time; + const Message(this.speaker, this.message, this.metrics, this.time); } class TextInferenceProvider extends ChangeNotifier { @@ -137,7 +138,7 @@ class TextInferenceProvider extends ChangeNotifier { if (_response == null) { return null; } - return Message(Speaker.assistant, response!, null); + return Message(Speaker.assistant, response!, null, null); } List get messages { @@ -149,12 +150,12 @@ class TextInferenceProvider extends ChangeNotifier { Future message(String message) async { _response = "..."; - _messages.add(Message(Speaker.user, message, null)); + _messages.add(Message(Speaker.user, message, null, DateTime.now())); notifyListeners(); final response = await _inference!.prompt(message, temperature, topP); if (_messages.isNotEmpty) { - _messages.add(Message(Speaker.assistant, response.content, response.metrics)); + _messages.add(Message(Speaker.assistant, response.content, response.metrics, DateTime.now())); } _response = null; n = 0; @@ -173,7 +174,14 @@ class TextInferenceProvider extends ChangeNotifier { } void forceStop() { - _inference?.forceStop(); //TODO + _inference?.forceStop(); + if (_response != '...') { + _messages.add(Message(Speaker.assistant, _response!, null, DateTime.now())); + } + _response = null; + if (hasListeners) { + notifyListeners(); + } } void reset() { diff --git a/pubspec.lock b/pubspec.lock index 498a26de..0bb8d121 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -554,7 +554,7 @@ packages: source: hosted version: "0.1.2-main.4" markdown: - dependency: transitive + dependency: "direct main" description: name: markdown sha256: ef2a1298144e3f985cc736b22e0ccdaf188b5b3970648f2d9dc13efd1d9df051 diff --git a/pubspec.yaml b/pubspec.yaml index 6c06fb78..6ecc19bd 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -57,6 +57,7 @@ dependencies: system_theme: ^3.1.2 flutter_acrylic: ^1.1.4 flutter_markdown: ^0.7.4+3 + markdown: ^7.2.2 dev_dependencies: flutter_test: From 93a8a15bdcf047171e0f266a7d75c49bef685173 Mon Sep 17 00:00:00 2001 From: Dmitry Kalinin Date: Fri, 29 Nov 2024 13:09:47 +0100 Subject: [PATCH 3/5] Added proper selection area --- lib/pages/text_generation/playground.dart | 49 +++--- .../widgets/assistant_message.dart | 142 +++++++++--------- .../text_generation/widgets/user_message.dart | 33 ++-- 3 files changed, 115 insertions(+), 109 deletions(-) diff --git a/lib/pages/text_generation/playground.dart b/lib/pages/text_generation/playground.dart index be882913..84b6d44d 100644 --- a/lib/pages/text_generation/playground.dart +++ b/lib/pages/text_generation/playground.dart @@ -26,13 +26,14 @@ class Playground extends StatefulWidget { class SubmitMessageIntent extends Intent {} class _PlaygroundState extends State { - final textController = TextEditingController(); - final scrollController = ScrollController(); + final _textController = TextEditingController(); + final _scrollController = ScrollController(); + final _focusNode = FocusNode(); bool attachedToBottom = true; void jumpToBottom({ offset = 0 }) { - if (scrollController.hasClients) { - scrollController.jumpTo(scrollController.position.maxScrollExtent + offset); + if (_scrollController.hasClients) { + _scrollController.jumpTo(_scrollController.position.maxScrollExtent + offset); } } @@ -40,7 +41,7 @@ class _PlaygroundState extends State { if (message.isEmpty) return; final provider = Provider.of(context, listen: false); if (!provider.initialized || provider.response != null) return; - textController.text = ''; + _textController.text = ''; jumpToBottom(offset: 110); //move to bottom including both provider.message(message).catchError((e) async { // ignore: use_build_context_synchronously @@ -59,17 +60,17 @@ class _PlaygroundState extends State { @override void initState() { super.initState(); - scrollController.addListener(() { + _scrollController.addListener(() { setState(() { - attachedToBottom = scrollController.position.pixels + 0.001 >= scrollController.position.maxScrollExtent; + attachedToBottom = _scrollController.position.pixels + 0.001 >= _scrollController.position.maxScrollExtent; }); }); } @override void dispose() { - textController.dispose(); - scrollController.dispose(); + _textController.dispose(); + _scrollController.dispose(); super.dispose(); } @@ -158,16 +159,20 @@ class _PlaygroundState extends State { alignment: Alignment.bottomCenter, children: [ SingleChildScrollView( - controller: scrollController, - child: Padding(padding: const EdgeInsets.symmetric(horizontal: 64, vertical: 20), child: Column( - children: provider.messages.map((message) { switch (message.speaker) { - case Speaker.user: return Padding( - padding: const EdgeInsets.only(left: 42), - child: UserMessage(message), - ); - case Speaker.system: return Text('System: ${message.message}'); - case Speaker.assistant: return AssistantMessage(message); - }}).toList(), + controller: _scrollController, + child: Padding(padding: const EdgeInsets.symmetric(horizontal: 64, vertical: 20), child: SelectionArea( + child: SelectionArea( + child: Column( + children: provider.messages.map((message) { switch (message.speaker) { + case Speaker.user: return Padding( + padding: const EdgeInsets.only(left: 42), + child: UserMessage(message), + ); + case Speaker.system: return Text('System: ${message.message}'); + case Speaker.assistant: return AssistantMessage(message); + }}).toList(), + ), + ), ),), ), Positioned( @@ -221,7 +226,7 @@ class _PlaygroundState extends State { child: Actions( actions: >{ SubmitMessageIntent: CallbackAction( - onInvoke: (SubmitMessageIntent intent) => message(textController.text), + onInvoke: (SubmitMessageIntent intent) => message(_textController.text), ), }, child: Column( @@ -230,7 +235,7 @@ class _PlaygroundState extends State { TextBox( placeholder: "Type a message...", keyboardType: TextInputType.multiline, - controller: textController, + controller: _textController, maxLines: null, expands: true, onSubmitted: message, @@ -258,7 +263,7 @@ class _PlaygroundState extends State { ) : Tooltip( message: "Send message", - child: Button(child: const Icon(FluentIcons.send, size: 18,), onPressed: () { message(textController.text); }), + child: Button(child: const Icon(FluentIcons.send, size: 18,), onPressed: () { message(_textController.text); }), ) ), ) diff --git a/lib/pages/text_generation/widgets/assistant_message.dart b/lib/pages/text_generation/widgets/assistant_message.dart index 1ff268f1..6ff98113 100644 --- a/lib/pages/text_generation/widgets/assistant_message.dart +++ b/lib/pages/text_generation/widgets/assistant_message.dart @@ -57,21 +57,23 @@ class _AssistantMessageState extends State { children: [ Padding( padding: const EdgeInsets.only(bottom: 4), - child: Row( - children: [ - Text( - inferenceProvider.project!.name, - style: TextStyle( - color: subtleTextColor.of(theme), + child: SelectionContainer.disabled( + child: Row( + children: [ + Text( + inferenceProvider.project!.name, + style: TextStyle( + color: subtleTextColor.of(theme), + ), ), - ), - if (widget.message.time != null) Text( - DateFormat(' | yyyy-MM-dd HH:mm:ss').format(widget.message.time!), - style: TextStyle( - color: subtleTextColor.of(theme), + if (widget.message.time != null) Text( + DateFormat(' | yyyy-MM-dd HH:mm:ss').format(widget.message.time!), + style: TextStyle( + color: subtleTextColor.of(theme), + ), ), - ), - ], + ], + ), ), ), MouseRegion( @@ -86,78 +88,80 @@ class _AssistantMessageState extends State { color: backgroundColor, borderRadius: BorderRadius.circular(4), ), - child: Markdown( - data: widget.message.message, - selectable: true, - shrinkWrap: true, - padding: const EdgeInsets.all(12), - extensionSet: md.ExtensionSet( - md.ExtensionSet.gitHubFlavored.blockSyntaxes, - [md.EmojiSyntax(), ...md.ExtensionSet.gitHubFlavored.inlineSyntaxes], + child: Padding( + padding: const EdgeInsets.all(12.0), + child: MarkdownBody( + data: widget.message.message, + extensionSet: md.ExtensionSet( + md.ExtensionSet.gitHubFlavored.blockSyntaxes, + [md.EmojiSyntax(), ...md.ExtensionSet.gitHubFlavored.inlineSyntaxes], + ), ), ), ), if (_hovering) Padding( padding: const EdgeInsets.only(top: 4), - child: Row( - children: [ - if (widget.message.metrics != null) Padding( - padding: const EdgeInsets.only(right: 8), - child: Tooltip( - message: 'Time to first token', - child: Text( - 'TTF: ${nf.format(widget.message.metrics!.ttft)}ms', - style: TextStyle( - fontSize: 12, - color: subtleTextColor.of(theme), + child: SelectionContainer.disabled( + child: Row( + children: [ + if (widget.message.metrics != null) Padding( + padding: const EdgeInsets.only(right: 8), + child: Tooltip( + message: 'Time to first token', + child: Text( + 'TTF: ${nf.format(widget.message.metrics!.ttft)}ms', + style: TextStyle( + fontSize: 12, + color: subtleTextColor.of(theme), + ), ), ), ), - ), - if (widget.message.metrics != null) Padding( - padding: const EdgeInsets.only(right: 8), - child: Tooltip( - message: 'Time per output token', - child: Text( - 'TPOT: ${nf.format(widget.message.metrics!.tpot)}ms', - style: TextStyle( - fontSize: 12, - color: subtleTextColor.of(theme), + if (widget.message.metrics != null) Padding( + padding: const EdgeInsets.only(right: 8), + child: Tooltip( + message: 'Time per output token', + child: Text( + 'TPOT: ${nf.format(widget.message.metrics!.tpot)}ms', + style: TextStyle( + fontSize: 12, + color: subtleTextColor.of(theme), + ), ), ), ), - ), - if (widget.message.metrics != null) Padding( - padding: const EdgeInsets.only(right: 8), - child: Tooltip( - message: 'Generate total duration', - child: Text( - 'Generate: ${nf.format(widget.message.metrics!.generate_time/1000)}s', - style: TextStyle( - fontSize: 12, - color: subtleTextColor.of(theme), + if (widget.message.metrics != null) Padding( + padding: const EdgeInsets.only(right: 8), + child: Tooltip( + message: 'Generate total duration', + child: Text( + 'Generate: ${nf.format(widget.message.metrics!.generate_time/1000)}s', + style: TextStyle( + fontSize: 12, + color: subtleTextColor.of(theme), + ), ), ), ), - ), - IconButton( - icon: const Icon(FluentIcons.copy), - onPressed: () async{ - await displayInfoBar(context, builder: (context, close) => - InfoBar( - title: const Text('Copied to clipboard'), - severity: InfoBarSeverity.info, - action: IconButton( - icon: const Icon(FluentIcons.clear), - onPressed: close, + IconButton( + icon: const Icon(FluentIcons.copy), + onPressed: () async{ + await displayInfoBar(context, builder: (context, close) => + InfoBar( + title: const Text('Copied to clipboard'), + severity: InfoBarSeverity.info, + action: IconButton( + icon: const Icon(FluentIcons.clear), + onPressed: close, + ), ), - ), - ); - Clipboard.setData(ClipboardData(text: widget.message.message)); - }, - ), - ], + ); + Clipboard.setData(ClipboardData(text: widget.message.message)); + }, + ), + ], + ), ), ) else const SizedBox(height: 34) ], diff --git a/lib/pages/text_generation/widgets/user_message.dart b/lib/pages/text_generation/widgets/user_message.dart index 5e73496b..dd09a36e 100644 --- a/lib/pages/text_generation/widgets/user_message.dart +++ b/lib/pages/text_generation/widgets/user_message.dart @@ -13,30 +13,27 @@ class UserMessage extends StatelessWidget { final theme = FluentTheme.of(context); return Align( alignment: Alignment.centerRight, - child: ConstrainedBox( - constraints: const BoxConstraints(maxWidth: 1000), - child: Padding(padding: const EdgeInsets.only(bottom: 20), child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Container( - decoration: BoxDecoration( - color: cosmosBackground.of(theme), - borderRadius: BorderRadius.circular(4), - ), - child: Markdown( + child: Padding(padding: const EdgeInsets.only(bottom: 20), child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Container( + decoration: BoxDecoration( + color: cosmosBackground.of(theme), + borderRadius: BorderRadius.circular(4), + ), + child: Padding( + padding: const EdgeInsets.all(12.0), + child: MarkdownBody( data: message.message, - selectable: true, - shrinkWrap: true, - padding: const EdgeInsets.all(12), extensionSet: md.ExtensionSet( md.ExtensionSet.gitHubFlavored.blockSyntaxes, [md.EmojiSyntax(), ...md.ExtensionSet.gitHubFlavored.inlineSyntaxes], ), ), - ) - ], - ),), - ), + ), + ) + ], + ),), ); } } \ No newline at end of file From 2866937de1bcc565e79f713ed84c1cebffcc2938 Mon Sep 17 00:00:00 2001 From: Dmitry Kalinin Date: Mon, 2 Dec 2024 12:24:14 +0100 Subject: [PATCH 4/5] Fixed minor issues --- lib/importers/manifest_importer.dart | 2 +- .../widgets/assistant_message.dart | 5 ++--- lib/router.dart | 2 -- .../widgets/user_message_test.dart | 21 +++++++++++++++++++ 4 files changed, 24 insertions(+), 6 deletions(-) create mode 100644 test/pages/text_generation/widgets/user_message_test.dart diff --git a/lib/importers/manifest_importer.dart b/lib/importers/manifest_importer.dart index 7391c7d5..5c0d3f82 100644 --- a/lib/importers/manifest_importer.dart +++ b/lib/importers/manifest_importer.dart @@ -51,7 +51,7 @@ class Model { ); } - Future convertToProject() async { + Future convertToProject() async { final directory = await getApplicationSupportDirectory(); final projectId = const Uuid().v4(); final storagePath = platformContext.join(directory.path, projectId.toString()); diff --git a/lib/pages/text_generation/widgets/assistant_message.dart b/lib/pages/text_generation/widgets/assistant_message.dart index 6ff98113..321f6683 100644 --- a/lib/pages/text_generation/widgets/assistant_message.dart +++ b/lib/pages/text_generation/widgets/assistant_message.dart @@ -30,14 +30,13 @@ class _AssistantMessageState extends State { return Consumer(builder: (context, inferenceProvider, child) => Align( - alignment: Alignment.centerLeft, child: Padding( padding: const EdgeInsets.only(bottom: 8), child: Row( - crossAxisAlignment: CrossAxisAlignment.end, + crossAxisAlignment: CrossAxisAlignment.start, children: [ Padding( - padding: const EdgeInsets.only(right: 10, bottom: 36), + padding: const EdgeInsets.only(right: 10, top: 20), child: ClipRRect( borderRadius: BorderRadius.circular(16.0), child: Container( diff --git a/lib/router.dart b/lib/router.dart index 1c32c68f..48423cfb 100644 --- a/lib/router.dart +++ b/lib/router.dart @@ -7,8 +7,6 @@ import 'package:inference/pages/home/home.dart'; import 'package:inference/pages/import/import.dart'; import 'package:inference/pages/models/models.dart'; import 'package:inference/project.dart'; -import 'package:inference/providers/download_provider.dart'; -import 'package:provider/provider.dart'; import 'package:inference/pages/models/inference.dart'; final rootNavigatorKey = GlobalKey(); diff --git a/test/pages/text_generation/widgets/user_message_test.dart b/test/pages/text_generation/widgets/user_message_test.dart new file mode 100644 index 00000000..7d605a11 --- /dev/null +++ b/test/pages/text_generation/widgets/user_message_test.dart @@ -0,0 +1,21 @@ +import 'package:flutter_test/flutter_test.dart'; +import 'package:fluent_ui/fluent_ui.dart'; +import 'package:inference/providers/text_inference_provider.dart'; +import 'package:inference/pages/text_generation/widgets/user_message.dart'; + +void main() { + testWidgets('UserMessage renders text correctly', (WidgetTester tester) async { + // Create a sample message + const message = Message(Speaker.user, 'Hello, this is a test message!', null, null); + + // Build the UserMessage widget + await tester.pumpWidget( + const FluentApp( + home: UserMessage(message), + ), + ); + + // Verify if the text is rendered correctly + expect(find.text('Hello, this is a test message!'), findsOneWidget); + }); +} \ No newline at end of file From 2d62523e68397d28734c1ce2a7a452c078a50b7d Mon Sep 17 00:00:00 2001 From: Dmitry Kalinin Date: Mon, 2 Dec 2024 14:01:28 +0100 Subject: [PATCH 5/5] Changed push to go --- lib/pages/text_generation/text_generation.dart | 2 +- lib/widgets/import_model_button.dart | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/pages/text_generation/text_generation.dart b/lib/pages/text_generation/text_generation.dart index fa9ba926..637683c9 100644 --- a/lib/pages/text_generation/text_generation.dart +++ b/lib/pages/text_generation/text_generation.dart @@ -111,7 +111,7 @@ class _TextGenerationPageState extends State { child: const Text("Close"), onPressed: () => GoRouter.of(context).canPop() ? GoRouter.of(context).pop() - : GoRouter.of(context).push('/models'), + : GoRouter.of(context).go('/models'), ), ), ] diff --git a/lib/widgets/import_model_button.dart b/lib/widgets/import_model_button.dart index 68ecc759..92b4dd6c 100644 --- a/lib/widgets/import_model_button.dart +++ b/lib/widgets/import_model_button.dart @@ -52,7 +52,7 @@ class ImportModelButton extends StatelessWidget { return FilledDropDownButton( title: const Text('Import model'), items: [ - MenuFlyoutItem(text: const Text('Hugging Face'), onPressed: () { GoRouter.of(context).push('/models/import'); }), + MenuFlyoutItem(text: const Text('Hugging Face'), onPressed: () { GoRouter.of(context).go('/models/import'); }), MenuFlyoutItem(text: const Text('Local disk'), onPressed: () { addProject(context); }), ] );