diff --git a/lib/langchain/loaders/html_loader.dart b/lib/langchain/loaders/html_loader.dart
new file mode 100644
index 00000000..4596d22c
--- /dev/null
+++ b/lib/langchain/loaders/html_loader.dart
@@ -0,0 +1,43 @@
+import 'dart:io';
+import 'dart:math';
+import 'package:html/parser.dart';
+
+import 'package:langchain/langchain.dart';
+
+/// Loads an HTML file and emits the text of its body in fixed-size chunks.
+///
+/// Every chunk is yielded as a [Document] whose `source` metadata is [path].
+class HTMLLoader extends BaseDocumentLoader {
+  /// Location of the HTML file on disk.
+  final String path;
+
+  /// Maximum length of one chunk, in UTF-16 code units (the unit
+  /// [String.substring] indexes by).
+  final int windowSize;
+
+  const HTMLLoader(this.path, this.windowSize);
+
+  /// Reads [path], parses it as HTML and yields the body text window by window.
+  ///
+  /// Yields nothing when the parsed document has no body or the body text is
+  /// empty. The stream fails with an [ArgumentError] when [windowSize] is not
+  /// positive: a zero step would never advance the loop below.
+  @override
+  Stream<Document> lazyLoad() async* {
+    if (windowSize <= 0) {
+      throw ArgumentError.value(windowSize, 'windowSize', 'must be positive');
+    }
+    final data = await File(path).readAsString();
+    final text = parse(data).body?.text;
+    if (text == null) {
+      return;
+    }
+    for (var start = 0; start < text.length; start += windowSize) {
+      yield Document(
+        // The last window is clamped so it never runs past the end.
+        pageContent: text.substring(start, min(start + windowSize, text.length)),
+        metadata: {"source": path},
+      );
+    }
+  }
+}
diff --git a/lib/langchain/pdf_loader.dart b/lib/langchain/loaders/pdf_loader.dart
similarity index 100%
rename from lib/langchain/pdf_loader.dart
rename to lib/langchain/loaders/pdf_loader.dart
diff --git a/lib/langchain/loaders/text_loader.dart b/lib/langchain/loaders/text_loader.dart
new file mode 100644
index 00000000..240f60ad
--- /dev/null
+++ b/lib/langchain/loaders/text_loader.dart
@@ -0,0 +1,38 @@
+import 'dart:io';
+import 'dart:math';
+
+import 'package:langchain/langchain.dart';
+
+/// Loads a plain-text file and emits its content in fixed-size chunks.
+///
+/// Every chunk is yielded as a [Document] whose `source` metadata is [path].
+class TextLoader extends BaseDocumentLoader {
+  /// Location of the text file on disk.
+  final String path;
+
+  /// Maximum length of one chunk, in UTF-16 code units (the unit
+  /// [String.substring] indexes by).
+  final int windowSize;
+
+  const TextLoader(this.path, this.windowSize);
+
+  /// Reads [path] as a string and yields it window by window.
+  ///
+  /// Yields nothing for an empty file. The stream fails with an
+  /// [ArgumentError] when [windowSize] is not positive: a zero step would
+  /// never advance the loop below.
+  @override
+  Stream<Document> lazyLoad() async* {
+    if (windowSize <= 0) {
+      throw ArgumentError.value(windowSize, 'windowSize', 'must be positive');
+    }
+    final text = await File(path).readAsString();
+    for (var start = 0; start < text.length; start += windowSize) {
+      yield Document(
+        // The last window is clamped so it never runs past the end.
+        pageContent: text.substring(start, min(start + windowSize, text.length)),
+        metadata: {"source": path},
+      );
+    }
+  }
+}
diff --git a/lib/pages/knowledge_base/knowledge_base.dart b/lib/pages/knowledge_base/knowledge_base.dart index 1ddf7f42..27e9e8be 100644 --- a/lib/pages/knowledge_base/knowledge_base.dart +++ 
b/lib/pages/knowledge_base/knowledge_base.dart @@ -68,10 +68,13 @@ class KnowledgeBase extends StatelessWidget { Expanded( child: Consumer( builder: (context, data, child) { - if (data.activeGroup == null) { - return const Center(child: Text("Select a group from the list to the left")); + if (data.activeGroup != null) { + final group = data.groupBox.get(data.activeGroup!); + if (group != null) { + return DocumentsList(group: group, key: Key(group.internalId.toString())); + } } - return DocumentsList(group: data.activeGroup!); + return const Center(child: Text("Select a group from the list to the left")); } ), ), diff --git a/lib/pages/knowledge_base/providers/knowledge_base_provider.dart b/lib/pages/knowledge_base/providers/knowledge_base_provider.dart index 9bd948b2..f5781b7a 100644 --- a/lib/pages/knowledge_base/providers/knowledge_base_provider.dart +++ b/lib/pages/knowledge_base/providers/knowledge_base_provider.dart @@ -8,14 +8,15 @@ class KnowledgeBaseProvider extends ChangeNotifier { List _groups = []; List get groups => _groups; + set groups(List value) { _groups = value; notifyListeners(); } - KnowledgeGroup? _activeGroup; - KnowledgeGroup? get activeGroup => _activeGroup; - set activeGroup(KnowledgeGroup? value) { + int? _activeGroup; + int? get activeGroup => _activeGroup; + set activeGroup(int? 
value) { _activeGroup = value; notifyListeners(); } @@ -47,16 +48,19 @@ class KnowledgeBaseProvider extends ChangeNotifier { void addGroup() { isEditingId = groupBox.put(KnowledgeGroup("New group")); groups = groupBox.getAll(); + if (groups.length == 1) { + activeGroup = isEditingId; + } } void setActiveGroup(KnowledgeGroup group) { - activeGroup = group; + activeGroup = group.internalId; } KnowledgeBaseProvider({required this.groupBox}) { groupBox.getAllAsync().then((value) { groups = value; - activeGroup = groups.firstOrNull; + activeGroup = groups.firstOrNull?.internalId; }); } }
diff --git a/lib/pages/knowledge_base/utils/loader_selector.dart b/lib/pages/knowledge_base/utils/loader_selector.dart
new file mode 100644
index 00000000..e7df99e0
--- /dev/null
+++ b/lib/pages/knowledge_base/utils/loader_selector.dart
@@ -0,0 +1,43 @@
+import 'package:inference/langchain/loaders/html_loader.dart';
+import 'package:inference/langchain/loaders/pdf_loader.dart';
+import 'package:inference/langchain/loaders/text_loader.dart';
+import 'package:langchain/langchain.dart';
+import 'package:path/path.dart';
+
+/// Chunk size passed to every loader created by [loaderFromName].
+const windowSize = 400;
+
+/// Returns the default loader name for the file at [path].
+///
+/// The decision is based on the file extension, compared case-insensitively
+/// so that `REPORT.PDF` is handled like `report.pdf`. Extensions other than
+/// `.pdf`, `.html` and `.htm` fall back to `"TextLoader"`.
+String defaultLoaderSelector(String path) {
+  switch (extension(path).toLowerCase()) {
+    case ".pdf":
+      return "PdfLoader";
+    case ".html":
+    case ".htm":
+      return "HTMLLoader";
+    // TODO: return a JSON loader for ".json" once one is available.
+    default:
+      return "TextLoader";
+  }
+}
+
+/// Builds the loader called [name] for the file at [path].
+///
+/// [name] must be one of the strings produced by [defaultLoaderSelector];
+/// any other value throws an [Exception].
+BaseDocumentLoader loaderFromName(String name, String path) {
+  switch (name) {
+    case "PdfLoader":
+      return PdfLoader(path, windowSize);
+    case "HTMLLoader":
+      return HTMLLoader(path, windowSize);
+    case "TextLoader":
+      return TextLoader(path, windowSize);
+    default:
+      throw Exception("Unknown loader name: $name");
+  }
+}
diff --git a/lib/pages/knowledge_base/widgets/documents_list.dart b/lib/pages/knowledge_base/widgets/documents_list.dart index e105f0be..9684f212 100644 --- a/lib/pages/knowledge_base/widgets/documents_list.dart +++ 
b/lib/pages/knowledge_base/widgets/documents_list.dart @@ -1,15 +1,12 @@ import 'dart:convert'; +import 'dart:io'; import 'package:fluent_ui/fluent_ui.dart'; -import 'package:inference/interop/llm_inference.dart'; import 'package:inference/interop/sentence_transformer.dart'; import 'package:inference/langchain/object_box/embedding_entity.dart'; import 'package:inference/langchain/object_box/object_box.dart'; -import 'package:inference/langchain/object_box_store.dart'; -import 'package:inference/langchain/openvino_embeddings.dart'; -import 'package:inference/langchain/openvino_llm.dart'; -import 'package:inference/langchain/pdf_loader.dart'; import 'package:inference/objectbox.g.dart'; +import 'package:inference/pages/knowledge_base/widgets/import_dialog.dart'; import 'package:inference/pages/models/widgets/grid_container.dart'; import 'package:inference/theme_fluent.dart'; import 'package:inference/widgets/controls/drop_area.dart'; @@ -29,22 +26,25 @@ class DocumentsList extends StatefulWidget { class _DocumentsListState extends State { late Box documentBox; late Box embeddingsBox; + Future? transformerFuture; - void addDocument(String path) async { + late List documents; + Map? 
filesToImport; + + Future addDocument(String path, BaseDocumentLoader loader) async { print("importing $path"); final document = KnowledgeDocument(path); document.group.target = widget.group; documentBox.put(document); - - final platformContext = Context(style: Style.platform); - final directory = await getApplicationSupportDirectory(); - final embeddingsModelPath = platformContext.join(directory.path, "test", "all-MiniLM-L6-v2", "fp16"); - final transformer = await SentenceTransformer.init(embeddingsModelPath, "CPU"); - const uuid = Uuid(); + final transformer = await transformerFuture; + + if (transformer == null){ + throw Exception("Could not loading transformer"); + } - final lcDocuments = await PdfLoader(path, 400).load(); + final lcDocuments = await loader.load(); List entities = []; for (final lcDocument in lcDocuments) { final embeddings = await transformer.generate(lcDocument.pageContent); @@ -55,6 +55,27 @@ class _DocumentsListState extends State { embeddingsBox.putMany(entities); print("Added ${entities.length} embeddings for $path"); + return document; + } + + + Future initSentenceTransformer() async { + //bit hacky, perhaps move to init of provider? 
+ final platformContext = Context(style: Style.platform); + final directory = await getApplicationSupportDirectory(); + final embeddingsModelPath = platformContext.join(directory.path, "test", "all-MiniLM-L6-v2", "fp16"); + + return SentenceTransformer.init(embeddingsModelPath, "CPU"); + } + + void processUpload(BuildContext context, String path) async { + final files = await importDialog(context, path); + for (final file in files.keys){ + final newDocument = await addDocument(file, files[file]!); + setState(() { + documents.add(newDocument); + }); + } } @override @@ -62,6 +83,8 @@ class _DocumentsListState extends State { super.initState(); documentBox = ObjectBox.instance.store.box(); embeddingsBox = ObjectBox.instance.store.box(); + documents = widget.group.documents; + transformerFuture = initSentenceTransformer(); } @override Widget build(BuildContext context) { @@ -80,142 +103,32 @@ class _DocumentsListState extends State { ), ), ), - Padding( - padding: const EdgeInsets.all(8.0), - child: Builder( - builder: (context) { - return Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Button( - onPressed: () { - - }, - child: const Text("add document"), - ), - DropArea( - type: "a document", - showChild: widget.group.documents.isNotEmpty, - onUpload: (file) => addDocument(file), - child: Column( - children: [ - for (final document in widget.group.documents) - Row( - mainAxisAlignment: MainAxisAlignment.spaceBetween, - children: [ - Text(document.source), - Text("embeddings: ${document.sections.length}") - ], - ) - ], - ) - ), - Experiment(group: widget.group), - ], - ); - } + Expanded( + child: GridContainer( + color: backgroundColor.of(theme), + padding: const EdgeInsets.all(16), + child: Center( + child: DropArea( + type: "a document or folder", + showChild: documents.isNotEmpty, + onUpload: (file) => processUpload(context, file), + child: Column( + children: [ + for (final document in documents) + Row( + mainAxisAlignment: 
MainAxisAlignment.spaceBetween, + children: [ + Text(document.source), + Text("embeddings: ${document.sections.length}") + ], + ) + ], + ) + ), + ), ), ), ], ); } } - -class Experiment extends StatefulWidget { - final KnowledgeGroup group; - const Experiment({super.key, required this.group}); - - @override - State createState() => _ExperimentState(); -} - -class _ExperimentState extends State { - VectorStore? vs; - Future? chain; - - String? response; - - Future initMemoryStore() async { - final platformContext = Context(style: Style.platform); - final directory = await getApplicationSupportDirectory(); - const device = "CPU"; - final embeddingsModelPath = platformContext.join(directory.path, "test", "all-MiniLM-L6-v2", "fp16"); - final llmModelPath = platformContext.join(directory.path, "test", "TinyLlama-1.1B-Chat-v1.0-int8-ov"); - final embeddingsModel = await OpenVINOEmbeddings.init(embeddingsModelPath, device); - vs = ObjectBoxStore(embeddings: embeddingsModel, group: widget.group); - final model = OpenVINOLLM(await LLMInference.init(llmModelPath, device), - defaultOptions: const OpenVINOLLMOptions(temperature: 1, topP: 1, applyTemplate: false) - ); - final promptTemplate = ChatPromptTemplate.fromTemplate(''' -<|system|> -Answer the question based only on the following context without specifically naming that it's from that context: -{context} - -<|user|> -{question} -<|assistant|> -'''); - final retriever = vs!.asRetriever(); - - return Runnable.fromMap({ - 'context': retriever | Runnable.mapInput((docs) => docs.map((d) => d.pageContent).join('\n')), - 'question': Runnable.passthrough(), - }) | promptTemplate | model | const StringOutputParser(); - } - - - void runChain(String text) async { - final runnable = (await chain)!; - setState(() { - response = ""; - }); - await for (final output in runnable.stream(text)) { - setState(() { - response = (response ?? 
"") + output.toString(); - }); - } - } - - @override - void initState() { - super.initState(); - chain = initMemoryStore(); - } - - @override - Widget build(BuildContext context) { - return FutureBuilder( - future: chain, - builder: (context, snapshot) { - if (!snapshot.hasData) { - return Center( - child: Image.asset('images/intel-loading.gif', width: 100) - ); - } - return Padding( - padding: const EdgeInsets.all(8.0), - child: Column( - crossAxisAlignment: CrossAxisAlignment.stretch, - children: [ - const Padding( - padding: EdgeInsets.only(top: 20), - child: Text("Experiment area"), - ), - TextBox( - onSubmitted: runChain, - ), - Builder( - builder: (context) { - if (response != null) { - return SingleChildScrollView(child: SelectableText(response!.isEmpty ? "..." : response!)); - } - return const Text("Type a message to test RAG"); - } - ) - ], - ), - ); - } - ); - } -}
diff --git a/lib/pages/knowledge_base/widgets/experiment.dart b/lib/pages/knowledge_base/widgets/experiment.dart
new file mode 100644
index 00000000..26221ef5
--- /dev/null
+++ b/lib/pages/knowledge_base/widgets/experiment.dart
@@ -0,0 +1,134 @@
+import 'package:fluent_ui/fluent_ui.dart';
+import 'package:inference/interop/llm_inference.dart';
+import 'package:inference/langchain/object_box/embedding_entity.dart';
+import 'package:inference/langchain/object_box_store.dart';
+import 'package:inference/langchain/openvino_embeddings.dart';
+import 'package:inference/langchain/openvino_llm.dart';
+import 'package:langchain/langchain.dart';
+import 'package:path/path.dart';
+import 'package:path_provider/path_provider.dart';
+
+/// Text box plus answer area for trying retrieval-augmented questions
+/// against one knowledge group.
+class Experiment extends StatefulWidget {
+  /// Group handed to the [ObjectBoxStore] that the retriever is built from.
+  final KnowledgeGroup group;
+
+  const Experiment({super.key, required this.group});
+
+  @override
+  State<Experiment> createState() => _ExperimentState();
+}
+
+class _ExperimentState extends State<Experiment> {
+  VectorStore? vs;
+  Future<Runnable>? chain;
+  String? response;
+
+  /// Loads the embedding model and the LLM and combines them into a
+  /// retrieve-then-answer chain.
+  ///
+  /// Both model folders are resolved below `test` in the application support
+  /// directory, and both models are initialised for the `"CPU"` device.
+  Future<Runnable> initMemoryStore() async {
+    final platformContext = Context(style: Style.platform);
+    final directory = await getApplicationSupportDirectory();
+    const device = "CPU";
+    final embeddingsModelPath = platformContext.join(directory.path, "test", "all-MiniLM-L6-v2", "fp16");
+    final llmModelPath = platformContext.join(directory.path, "test", "TinyLlama-1.1B-Chat-v1.0-int8-ov");
+    final embeddingsModel = await OpenVINOEmbeddings.init(embeddingsModelPath, device);
+    vs = ObjectBoxStore(embeddings: embeddingsModel, group: widget.group);
+    final model = OpenVINOLLM(await LLMInference.init(llmModelPath, device),
+      defaultOptions: const OpenVINOLLMOptions(temperature: 1, topP: 1, applyTemplate: false)
+    );
+    final promptTemplate = ChatPromptTemplate.fromTemplate('''
+<|system|>
+Answer the question based only on the following context without specifically naming that it's from that context:
+{context}
+
+<|user|>
+{question}
+<|assistant|>
+''');
+    final retriever = vs!.asRetriever();
+
+    // The retrieved documents are joined into one string for {context};
+    // the user's input is passed through unchanged as {question}.
+    return Runnable.fromMap({
+      'context': retriever | Runnable.mapInput((docs) => docs.map((d) => d.pageContent).join('\n')),
+      'question': Runnable.passthrough(),
+    }) | promptTemplate | model | const StringOutputParser();
+  }
+
+  /// Streams the chain's answer to [text] into [response].
+  ///
+  /// Stops as soon as the widget is unmounted, because calling [setState] on
+  /// a disposed [State] throws.
+  void runChain(String text) async {
+    final runnable = (await chain)!;
+    if (!mounted) {
+      return;
+    }
+    setState(() {
+      response = "";
+    });
+    await for (final output in runnable.stream(text)) {
+      if (!mounted) {
+        return;
+      }
+      setState(() {
+        response = (response ?? "") + output.toString();
+      });
+    }
+  }
+
+  @override
+  void initState() {
+    super.initState();
+    chain = initMemoryStore();
+  }
+
+  @override
+  Widget build(BuildContext context) {
+    return FutureBuilder<Runnable>(
+      future: chain,
+      builder: (context, snapshot) {
+        if (snapshot.hasError) {
+          // Without this branch a failed model load would leave the loading
+          // animation on screen indefinitely.
+          return Center(
+            child: Text("Could not set up the experiment: ${snapshot.error}"),
+          );
+        }
+        if (!snapshot.hasData) {
+          return Center(
+            child: Image.asset('images/intel-loading.gif', width: 100)
+          );
+        }
+        return Padding(
+          padding: const EdgeInsets.all(8.0),
+          child: Column(
+            crossAxisAlignment: CrossAxisAlignment.stretch,
+            children: [
+              const Padding(
+                padding: EdgeInsets.only(top: 20),
+                child: Text("Experiment area"),
+              ),
+              TextBox(
+                onSubmitted: runChain,
+              ),
+              Builder(
+                builder: (context) {
+                  if (response != null) {
+                    return SingleChildScrollView(child: SelectableText(response!.isEmpty ? "..." : response!));
+                  }
+                  return const Text("Type a message to test RAG");
+                }
+              )
+            ],
+          ),
+        );
+      }
+    );
+  }
+}
diff --git a/lib/pages/knowledge_base/widgets/group_item.dart b/lib/pages/knowledge_base/widgets/group_item.dart index b81fe866..f0edf27f 100644 --- a/lib/pages/knowledge_base/widgets/group_item.dart +++ b/lib/pages/knowledge_base/widgets/group_item.dart @@ -38,14 +38,6 @@ class _GroupItemState extends State { Widget build(BuildContext context) { final theme = FluentTheme.of(context); - if (widget.editable) { - return TextBox( - controller: controller, - onSubmitted: (value) { - widget.onRename?.call(value); - }, - ); - } return GestureDetector( behavior: HitTestBehavior.opaque, onTap: () { @@ -55,8 +47,9 @@ class _GroupItemState extends State { widget.onMakeEditable?.call(); }, child: Container( - margin: const EdgeInsets.all(2), - padding: const EdgeInsets.only(left: 20), + padding: const EdgeInsets.all(4), + margin: const EdgeInsets.all(4), + height: 40, decoration: BoxDecoration( border: Border( left: BorderSide( @@ -65,14 +58,30 @@ class _GroupItemState extends State { ), ) ), - child: Row( - mainAxisAlignment: MainAxisAlignment.spaceBetween, - children: [ - Text(widget.group.name), - 
IconButton(icon: const Icon(FluentIcons.delete), onPressed: () { - widget.onDelete?.call(); - }), - ], + child: Builder( + builder: (context) { + if (widget.editable) { + return TextBox( + controller: controller, + onSubmitted: (value) { + widget.onRename?.call(value); + }, + ); + } + + return Padding( + padding: const EdgeInsets.only(left: 10, bottom: 5.5), + child: Row( + mainAxisAlignment: MainAxisAlignment.spaceBetween, + children: [ + Text(widget.group.name), + IconButton(icon: const Icon(FluentIcons.delete), onPressed: () { + widget.onDelete?.call(); + }), + ], + ), + ); + } ), ), );
diff --git a/lib/pages/knowledge_base/widgets/import_dialog.dart b/lib/pages/knowledge_base/widgets/import_dialog.dart
new file mode 100644
index 00000000..27e754a4
--- /dev/null
+++ b/lib/pages/knowledge_base/widgets/import_dialog.dart
@@ -0,0 +1,168 @@
+import 'dart:io';
+
+import 'package:fluent_ui/fluent_ui.dart';
+import 'package:inference/pages/computer_vision/widgets/horizontal_rule.dart';
+import 'package:inference/pages/knowledge_base/utils/loader_selector.dart';
+import 'package:langchain/langchain.dart';
+import 'package:path/path.dart';
+
+/// Shows the import dialog for [path] and returns the confirmed selection.
+///
+/// The result maps every file path the user kept to a loader built by
+/// [loaderFromName]. It is empty when the dialog is cancelled or dismissed.
+Future<Map<String, BaseDocumentLoader>> importDialog(BuildContext context, String path) async {
+  final result = await showDialog<Map<String, String>?>(
+    context: context,
+    builder: (context) => ImportRAGWidget(
+      path: path,
+    ),
+  );
+  if (result == null) {
+    return {};
+  }
+  return result.map((key, value) => MapEntry(key, loaderFromName(value, key)));
+}
+
+/// Dialog content that lists the files found at [path] and lets the user pick
+/// a loader for each one or drop it from the import.
+class ImportRAGWidget extends StatefulWidget {
+  /// File or directory that was offered for import.
+  final String path;
+
+  const ImportRAGWidget({
+    super.key,
+    required this.path,
+  });
+
+  @override
+  State<ImportRAGWidget> createState() => _ImportRAGWidgetState();
+}
+
+class _ImportRAGWidgetState extends State<ImportRAGWidget> {
+  /// Directory shown in the header; its length is cut off the front of every
+  /// path displayed in the list.
+  late final String baseDir;
+
+  /// Selected loader name per file path, edited in place by the dialog.
+  late final Map<String, String> files;
+
+  @override
+  void initState() {
+    super.initState();
+    final dir = Directory(widget.path);
+    if (dir.existsSync()) {
+      // A directory: offer every file below it. Sub-directories returned by
+      // the recursive listing are skipped because a loader can only read files.
+      baseDir = widget.path;
+      files = {
+        for (final file in dir.listSync(recursive: true).whereType<File>())
+          file.path: defaultLoaderSelector(file.path),
+      };
+    } else {
+      // A single file: list it under its parent directory.
+      baseDir = dirname(widget.path);
+      files = {widget.path: defaultLoaderSelector(widget.path)};
+    }
+  }
+
+  @override
+  Widget build(BuildContext context) {
+    return ContentDialog(
+      constraints: const BoxConstraints(
+        maxWidth: 756,
+        maxHeight: 500,
+      ),
+      title: const Text('Import files?'),
+      content: Column(
+        crossAxisAlignment: CrossAxisAlignment.start,
+        children: [
+          Text("From directory: $baseDir"),
+          const HorizontalRule(),
+          Expanded(
+            child: SingleChildScrollView(
+              child: Column(
+                children: [
+                  for (final key in files.keys)
+                    Padding(
+                      padding: const EdgeInsets.all(4),
+                      child: Row(
+                        mainAxisAlignment: MainAxisAlignment.spaceBetween,
+                        children: [
+                          Text(key.substring(baseDir.length)),
+                          Row(
+                            children: [
+                              LoaderSelector(
+                                value: files[key]!,
+                                onChange: (value) {
+                                  setState(() {
+                                    files[key] = value;
+                                  });
+                                },
+                              ),
+                              IconButton(
+                                onPressed: () {
+                                  setState(() {
+                                    files.remove(key);
+                                  });
+                                },
+                                icon: const Icon(FluentIcons.delete),
+                              )
+                            ],
+                          ),
+                        ]
+                      ),
+                    ),
+                ]
+              ),
+            ),
+          ),
+        ],
+      ),
+      actions: [
+        FilledButton(
+          child: const Text('Import'),
+          // Closes the dialog with the current selection as its result.
+          onPressed: () => Navigator.pop(context, files),
+        ),
+        Button(
+          child: const Text('Cancel'),
+          onPressed: () => Navigator.pop(context, null),
+        ),
+      ],
+    );
+  }
+}
+
+/// Loader names offered by [LoaderSelector].
+const List<String> loaders = ["PdfLoader", "TextLoader", "HTMLLoader"];
+
+/// Combo box for choosing one of [loaders].
+class LoaderSelector extends StatelessWidget {
+  /// Currently selected loader name.
+  final String value;
+
+  /// Called with the newly selected loader name.
+  final Function(String) onChange;
+
+  const LoaderSelector({super.key, required this.value, required this.onChange});
+
+  @override
+  Widget build(BuildContext context) {
+    return ComboBox<String>(
+      value: value,
+      items: [
+        for (final loader in loaders)
+          ComboBoxItem<String>(
+            value: loader,
+            child: Text(loader),
+          ),
+      ],
+      onChanged: (selected) {
+        // Ignore a null selection instead of force-unwrapping it.
+        if (selected != null) {
+          onChange(selected);
+        }
+      },
+    );
+  }
+}
diff --git a/lib/pages/knowledge_base/widgets/tree.dart b/lib/pages/knowledge_base/widgets/tree.dart index 9e7a7002..7e2bbe07 100644 --- a/lib/pages/knowledge_base/widgets/tree.dart +++ b/lib/pages/knowledge_base/widgets/tree.dart @@ -29,11 +29,11 @@ class _TreeState extends State { children: [ for (final group in data.groups) GroupItem( - isActive: data.activeGroup == group, + isActive: data.activeGroup == group.internalId, editable: data.isEditingId == group.internalId, group: group, onActivate: () { - data.activeGroup = group; + data.activeGroup = group.internalId; }, onRename: (value) => data.renameGroup(group, value), onMakeEditable: () { diff --git a/pubspec.lock b/pubspec.lock index 61a8479a..d65cef92 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -437,7 +437,7 @@ packages: source: hosted version: "2.3.2" html: - dependency: "direct dev" + dependency: "direct main" description: name: html sha256: "1fc58edeaec4307368c60d59b7e15b9d658b57d7f3125098b6294153c75337ec" diff --git a/pubspec.yaml b/pubspec.yaml index 6130d7a9..5c0e8bdb 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -60,6 +60,7 @@ dependencies: langchain: ^0.7.7+1 objectbox: ^4.0.3 objectbox_flutter_libs: any + html: ^0.15.5 dev_dependencies: flutter_test: @@ -73,7 +74,6 @@ dev_dependencies: flutter_lints: ^3.0.0 msix: ^3.16.7 ffigen: ^13.0.0 - html: ^0.15.4 integration_test: sdk: flutter path_provider_platform_interface: ^2.1.2