diff --git a/lib/pages/knowledge_base/widgets/documents_list.dart b/lib/pages/knowledge_base/widgets/documents_list.dart
index 9684f212..d4bb16b1 100644
--- a/lib/pages/knowledge_base/widgets/documents_list.dart
+++ b/lib/pages/knowledge_base/widgets/documents_list.dart
@@ -6,6 +6,7 @@ import 'package:inference/interop/sentence_transformer.dart';
 import 'package:inference/langchain/object_box/embedding_entity.dart';
 import 'package:inference/langchain/object_box/object_box.dart';
 import 'package:inference/objectbox.g.dart';
+import 'package:inference/pages/knowledge_base/widgets/experiment.dart';
 import 'package:inference/pages/knowledge_base/widgets/import_dialog.dart';
 import 'package:inference/pages/models/widgets/grid_container.dart';
 import 'package:inference/theme_fluent.dart';
diff --git a/lib/pages/text_generation/playground.dart b/lib/pages/text_generation/playground.dart
index 84b6d44d..85bfdebb 100644
--- a/lib/pages/text_generation/playground.dart
+++ b/lib/pages/text_generation/playground.dart
@@ -6,6 +6,7 @@ import 'package:inference/pages/models/widgets/grid_container.dart';
 import 'package:inference/pages/text_generation/widgets/assistant_message.dart';
 import 'package:inference/pages/text_generation/widgets/model_properties.dart';
 import 'package:inference/pages/text_generation/widgets/user_message.dart';
+import 'package:inference/pages/text_generation/widgets/knowledge_base_selector.dart';
 import 'package:inference/project.dart';
 import 'package:inference/providers/text_inference_provider.dart';
 import 'package:inference/theme_fluent.dart';
@@ -122,6 +123,7 @@ class _PlaygroundState extends State<Playground> {
                 max: 1.0,
                 min: 0.1,
               ),
+              const KnowledgeBaseSelector()
             ],
           )
         ),
@@ -283,4 +285,4 @@ class _PlaygroundState extends State<Playground> {
       ],
     );
   }
-}
\ No newline at end of file
+}
diff --git a/lib/pages/text_generation/widgets/knowledge_base_selector.dart b/lib/pages/text_generation/widgets/knowledge_base_selector.dart
new file mode 100644
index 00000000..a0776c89
--- /dev/null
+++ b/lib/pages/text_generation/widgets/knowledge_base_selector.dart
@@ -0,0 +1,63 @@
+import 'package:fluent_ui/fluent_ui.dart';
+import 'package:inference/langchain/object_box/embedding_entity.dart';
+import 'package:inference/langchain/object_box/object_box.dart';
+import 'package:inference/objectbox.g.dart';
+import 'package:inference/providers/text_inference_provider.dart';
+import 'package:inference/widgets/controls/no_outline_button.dart';
+import 'package:provider/provider.dart';
+
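+/// Dropdown for the playground sidebar that selects which [KnowledgeGroup]
+/// (or "None") the [TextInferenceProvider] should retrieve context from.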
const Text("Knowledge Base") + : Text("Knowledge Base: ${inference.knowledgeGroup!.name}") + ), + const Padding( + padding: EdgeInsets.only(left: 8), + child: Icon(FluentIcons.chevron_down, size: 12), + ), + ], + ), + ), + ); + }, + items: [ + MenuFlyoutItem(text: const Text("None"), onPressed: () { + inference.knowledgeGroup = null; + }), + for (final group in groups) + + MenuFlyoutItem(text: Text(group.name), onPressed: () { + inference.knowledgeGroup = group; + }) + ] + ); + } + ); + } +} diff --git a/lib/providers/text_inference_provider.dart b/lib/providers/text_inference_provider.dart index 2089c299..9c0a250d 100644 --- a/lib/providers/text_inference_provider.dart +++ b/lib/providers/text_inference_provider.dart @@ -3,7 +3,14 @@ import 'dart:async'; import 'package:flutter/material.dart'; import 'package:inference/interop/generated_bindings.dart'; import 'package:inference/interop/llm_inference.dart'; +import 'package:inference/langchain/object_box/embedding_entity.dart'; +import 'package:inference/langchain/object_box_store.dart'; +import 'package:inference/langchain/openvino_embeddings.dart'; +import 'package:inference/langchain/openvino_llm.dart'; import 'package:inference/project.dart'; +import 'package:langchain/langchain.dart'; +import 'package:path/path.dart'; +import 'package:path_provider/path_provider.dart'; enum Speaker { system, assistant, user } @@ -37,6 +44,43 @@ class Message { const Message(this.speaker, this.message, this.metrics, this.time); } +Future buildChain(LLMInference inference, KnowledgeGroup? group) async { + final platformContext = Context(style: Style.platform); + final directory = await getApplicationSupportDirectory(); + const device = "CPU"; + final embeddingsModelPath = platformContext.join(directory.path, "test", "all-MiniLM-L6-v2", "fp16"); + final embeddingsModel = await OpenVINOEmbeddings.init(embeddingsModelPath, device); + + if (group != null) { + final vs = ObjectBoxStore(embeddings: embeddingsModel, group: group); + final model = OpenVINOLLM(inference, defaultOptions: const OpenVINOLLMOptions(temperature: 1, topP: 1, applyTemplate: false)); + + + final promptTemplate = ChatPromptTemplate.fromTemplate(''' +<|system|> +Answer the question based only on the following context without specifically naming that it's from that context: +{context} + +<|user|> +{question} +<|assistant|> + '''); + final retriever = vs.asRetriever(); + + return Runnable.fromMap({ + 'context': retriever | Runnable.mapInput((docs) => docs.map((d) => d.pageContent).join('\n')), + 'question': Runnable.passthrough(), + }) | promptTemplate | model | const StringOutputParser(); + } else { + final model = OpenVINOLLM(inference, defaultOptions: const OpenVINOLLMOptions(temperature: 1, topP: 1, applyTemplate: true)); + final promptTemplate = ChatPromptTemplate.fromTemplate("{question}"); + + return Runnable.fromMap({ + 'question': Runnable.passthrough(), + }) | promptTemplate | model | const StringOutputParser(); + } +} + class TextInferenceProvider extends ChangeNotifier { Completer loaded = Completer(); @@ -48,6 +92,15 @@ class TextInferenceProvider extends ChangeNotifier { String? get device => _device; Metrics? get metrics => _messages.lastOrNull?.metrics; + Future? chain; + + KnowledgeGroup? _knowledgeGroup; + KnowledgeGroup? get knowledgeGroup => _knowledgeGroup; + set knowledgeGroup(KnowledgeGroup? 
+Future<Runnable> buildChain(LLMInference inference, KnowledgeGroup? group) async {
+  final platformContext = Context(style: Style.platform);
+  final directory = await getApplicationSupportDirectory();
+  const device = "CPU";
+  final embeddingsModelPath = platformContext.join(directory.path, "test", "all-MiniLM-L6-v2", "fp16");
+  final embeddingsModel = await OpenVINOEmbeddings.init(embeddingsModelPath, device);
+
+  if (group != null) {
+    final vs = ObjectBoxStore(embeddings: embeddingsModel, group: group);
+    final model = OpenVINOLLM(inference, defaultOptions: const OpenVINOLLMOptions(temperature: 1, topP: 1, applyTemplate: false));
+
+    final promptTemplate = ChatPromptTemplate.fromTemplate('''
+<|system|>
+Answer the question based only on the following context without specifically naming that it's from that context:
+{context}
+
+<|user|>
+{question}
+<|assistant|>
+  ''');
+    final retriever = vs.asRetriever();
+
+    return Runnable.fromMap({
+      'context': retriever | Runnable.mapInput((docs) => docs.map((d) => d.pageContent).join('\n')),
+      'question': Runnable.passthrough(),
+    }) | promptTemplate | model | const StringOutputParser();
+  } else {
+    final model = OpenVINOLLM(inference, defaultOptions: const OpenVINOLLMOptions(temperature: 1, topP: 1, applyTemplate: true));
+    final promptTemplate = ChatPromptTemplate.fromTemplate("{question}");
+
+    return Runnable.fromMap({
+      'question': Runnable.passthrough(),
+    }) | promptTemplate | model | const StringOutputParser();
+  }
+}
+
 class TextInferenceProvider extends ChangeNotifier {
   Completer<void> loaded = Completer<void>();
@@ -48,6 +92,15 @@ class TextInferenceProvider extends ChangeNotifier {
   String? get device => _device;
   Metrics? get metrics => _messages.lastOrNull?.metrics;
 
+  Future<Runnable>? chain;
+
+  KnowledgeGroup? _knowledgeGroup;
+  KnowledgeGroup? get knowledgeGroup => _knowledgeGroup;
+  set knowledgeGroup(KnowledgeGroup? group) {
+    _knowledgeGroup = group;
+    notifyListeners();
+  }
+
   double _temperature = 1;
   double get temperature => _temperature;
   set temperature(double v) {
@@ -73,8 +126,8 @@ class TextInferenceProvider extends ChangeNotifier {
   Future<void> loadModel() async {
     if (project != null && device != null) {
-      _inference = await LLMInference.init(project!.storagePath, device!)
-        ..setListener(onMessage);
+      _inference = await LLMInference.init(project!.storagePath, device!);
+      chain = buildChain(_inference!, knowledgeGroup);
       loaded.complete();
       notifyListeners();
     }
   }
@@ -149,13 +202,25 @@ class TextInferenceProvider extends ChangeNotifier {
   }
 
   Future<void> message(String message) async {
+    _response = "...";
     _messages.add(Message(Speaker.user, message, null, DateTime.now()));
     notifyListeners();
-    final response = await _inference!.prompt(message, true, temperature, topP);
+    // Rebuild the chain so a knowledge-group change since loadModel() takes effect.
+    chain = buildChain(_inference!, knowledgeGroup);
+    final runnable = (await chain)!;
+    //final response = await _inference!.prompt(message, true, temperature, topP);
+
+    String modelOutput = "";
+    await for (final output in runnable.stream(message)) {
+      final token = output.toString();
+      modelOutput += token;
+      onMessage(token);
+    }
+    print("end...");
     if (_messages.isNotEmpty) {
-      _messages.add(Message(Speaker.assistant, response.content, response.metrics, DateTime.now()));
+      _messages.add(Message(Speaker.assistant, modelOutput, null, DateTime.now()));
     }
     _response = null;
     n = 0;