diff --git a/lib/langchain/loaders/html_loader.dart b/lib/langchain/loaders/html_loader.dart
new file mode 100644
index 00000000..4596d22c
--- /dev/null
+++ b/lib/langchain/loaders/html_loader.dart
@@ -0,0 +1,44 @@
+import 'dart:io';
+import 'dart:math';
+import 'package:html/parser.dart';
+
+import 'package:langchain/langchain.dart';
+
+/// Loads an HTML file from disk and splits the text of its body into
+/// fixed-size chunks.
+///
+/// Each chunk becomes one [Document] whose metadata records the originating
+/// file under the `"source"` key.
+class HTMLLoader extends BaseDocumentLoader {
+  /// Location of the HTML file to read.
+  final String path;
+
+  /// Maximum number of UTF-16 code units per emitted chunk.
+  final int windowSize;
+
+  const HTMLLoader(this.path, this.windowSize);
+
+  /// Streams the body text of [path] in consecutive windows of at most
+  /// [windowSize] code units.
+  ///
+  /// Emits nothing when the parsed document has no body. Throws an
+  /// [ArgumentError] when [windowSize] is not positive: a zero window would
+  /// never advance and a negative one would request an invalid substring.
+  @override
+  Stream<Document> lazyLoad() async* {
+    if (windowSize <= 0) {
+      throw ArgumentError.value(windowSize, "windowSize", "must be positive");
+    }
+    final data = await File(path).readAsString();
+    final text = parse(data).body?.text;
+    if (text == null) {
+      return;
+    }
+    for (var i = 0; i < text.length; i += windowSize) {
+      yield Document(
+        pageContent: text.substring(i, min(i + windowSize, text.length)),
+        metadata: {"source": path},
+      );
+    }
+  }
+}
diff --git a/lib/langchain/pdf_loader.dart b/lib/langchain/loaders/pdf_loader.dart
similarity index 100%
rename from lib/langchain/pdf_loader.dart
rename to lib/langchain/loaders/pdf_loader.dart
diff --git a/lib/langchain/loaders/text_loader.dart b/lib/langchain/loaders/text_loader.dart
new file mode 100644
index 00000000..240f60ad
--- /dev/null
+++ b/lib/langchain/loaders/text_loader.dart
@@ -0,0 +1,38 @@
+import 'dart:io';
+import 'dart:math';
+
+import 'package:langchain/langchain.dart';
+
+/// Loads a text file from disk and splits its content into fixed-size chunks.
+///
+/// Each chunk becomes one [Document] whose metadata records the originating
+/// file under the `"source"` key.
+class TextLoader extends BaseDocumentLoader {
+  /// Location of the text file to read.
+  final String path;
+
+  /// Maximum number of UTF-16 code units per emitted chunk.
+  final int windowSize;
+
+  const TextLoader(this.path, this.windowSize);
+
+  /// Streams the content of [path] in consecutive windows of at most
+  /// [windowSize] code units.
+  ///
+  /// An empty file produces no documents. Throws an [ArgumentError] when
+  /// [windowSize] is not positive: a zero window would never advance and a
+  /// negative one would request an invalid substring.
+  @override
+  Stream<Document> lazyLoad() async* {
+    if (windowSize <= 0) {
+      throw ArgumentError.value(windowSize, "windowSize", "must be positive");
+    }
+    final text = await File(path).readAsString();
+    for (var i = 0; i < text.length; i += windowSize) {
+      yield Document(
+        pageContent: text.substring(i, min(i + windowSize, text.length)),
+        metadata: {"source": path},
+      );
+    }
+  }
+}
diff --git a/lib/pages/knowledge_base/knowledge_base.dart b/lib/pages/knowledge_base/knowledge_base.dart
index 1ddf7f42..27e9e8be 100644
--- a/lib/pages/knowledge_base/knowledge_base.dart
+++ b/lib/pages/knowledge_base/knowledge_base.dart
@@ -68,10 +68,13 @@ class KnowledgeBase extends StatelessWidget {
Expanded(
child: Consumer(
builder: (context, data, child) {
- if (data.activeGroup == null) {
- return const Center(child: Text("Select a group from the list to the left"));
+ if (data.activeGroup != null) {
+ final group = data.groupBox.get(data.activeGroup!);
+ if (group != null) {
+ return DocumentsList(group: group, key: Key(group.internalId.toString()));
+ }
}
- return DocumentsList(group: data.activeGroup!);
+ return const Center(child: Text("Select a group from the list to the left"));
}
),
),
diff --git a/lib/pages/knowledge_base/providers/knowledge_base_provider.dart b/lib/pages/knowledge_base/providers/knowledge_base_provider.dart
index 9bd948b2..f5781b7a 100644
--- a/lib/pages/knowledge_base/providers/knowledge_base_provider.dart
+++ b/lib/pages/knowledge_base/providers/knowledge_base_provider.dart
@@ -8,14 +8,15 @@ class KnowledgeBaseProvider extends ChangeNotifier {
List _groups = [];
List get groups => _groups;
+
set groups(List value) {
_groups = value;
notifyListeners();
}
- KnowledgeGroup? _activeGroup;
- KnowledgeGroup? get activeGroup => _activeGroup;
- set activeGroup(KnowledgeGroup? value) {
+ int? _activeGroup;
+ int? get activeGroup => _activeGroup;
+ set activeGroup(int? value) {
_activeGroup = value;
notifyListeners();
}
@@ -47,16 +48,19 @@ class KnowledgeBaseProvider extends ChangeNotifier {
void addGroup() {
isEditingId = groupBox.put(KnowledgeGroup("New group"));
groups = groupBox.getAll();
+ if (groups.length == 1) {
+ activeGroup = isEditingId;
+ }
}
void setActiveGroup(KnowledgeGroup group) {
- activeGroup = group;
+ activeGroup = group.internalId;
}
KnowledgeBaseProvider({required this.groupBox}) {
groupBox.getAllAsync().then((value) {
groups = value;
- activeGroup = groups.firstOrNull;
+ activeGroup = groups.firstOrNull?.internalId;
});
}
}
diff --git a/lib/pages/knowledge_base/utils/loader_selector.dart b/lib/pages/knowledge_base/utils/loader_selector.dart
new file mode 100644
index 00000000..e7df99e0
--- /dev/null
+++ b/lib/pages/knowledge_base/utils/loader_selector.dart
@@ -0,0 +1,44 @@
+import 'package:inference/langchain/loaders/html_loader.dart';
+import 'package:inference/langchain/loaders/pdf_loader.dart';
+import 'package:inference/langchain/loaders/text_loader.dart';
+import 'package:langchain/langchain.dart';
+import 'package:path/path.dart';
+
+/// Window size passed to every loader created by [loaderFromName].
+const windowSize = 400;
+
+/// Returns the name of the loader suggested for the file at [path].
+///
+/// The choice is based on the file extension, compared case-insensitively so
+/// that e.g. `REPORT.PDF` is not mistaken for plain text. Anything that is not
+/// recognised falls back to `"TextLoader"`. Every returned name is accepted by
+/// [loaderFromName].
+String defaultLoaderSelector(String path) {
+  // TODO: add a dedicated loader for ".json" files.
+  switch (extension(path).toLowerCase()) {
+    case ".pdf":
+      return "PdfLoader";
+    case ".html":
+    case ".htm":
+      return "HTMLLoader";
+    default:
+      return "TextLoader";
+  }
+}
+
+/// Creates the loader registered under [name] for the file at [path].
+///
+/// Throws an [Exception] when [name] is not one of `"PdfLoader"`,
+/// `"HTMLLoader"` or `"TextLoader"`.
+BaseDocumentLoader loaderFromName(String name, String path) {
+  switch (name) {
+    case "PdfLoader":
+      return PdfLoader(path, windowSize);
+    case "HTMLLoader":
+      return HTMLLoader(path, windowSize);
+    case "TextLoader":
+      return TextLoader(path, windowSize);
+    default:
+      throw Exception("Unknown loader name: $name");
+  }
+}
diff --git a/lib/pages/knowledge_base/widgets/documents_list.dart b/lib/pages/knowledge_base/widgets/documents_list.dart
index e105f0be..9684f212 100644
--- a/lib/pages/knowledge_base/widgets/documents_list.dart
+++ b/lib/pages/knowledge_base/widgets/documents_list.dart
@@ -1,15 +1,12 @@
import 'dart:convert';
+import 'dart:io';
import 'package:fluent_ui/fluent_ui.dart';
-import 'package:inference/interop/llm_inference.dart';
import 'package:inference/interop/sentence_transformer.dart';
import 'package:inference/langchain/object_box/embedding_entity.dart';
import 'package:inference/langchain/object_box/object_box.dart';
-import 'package:inference/langchain/object_box_store.dart';
-import 'package:inference/langchain/openvino_embeddings.dart';
-import 'package:inference/langchain/openvino_llm.dart';
-import 'package:inference/langchain/pdf_loader.dart';
import 'package:inference/objectbox.g.dart';
+import 'package:inference/pages/knowledge_base/widgets/import_dialog.dart';
import 'package:inference/pages/models/widgets/grid_container.dart';
import 'package:inference/theme_fluent.dart';
import 'package:inference/widgets/controls/drop_area.dart';
@@ -29,22 +26,25 @@ class DocumentsList extends StatefulWidget {
class _DocumentsListState extends State {
late Box documentBox;
late Box embeddingsBox;
+ Future? transformerFuture;
- void addDocument(String path) async {
+ late List documents;
+ Map? filesToImport;
+
+ Future addDocument(String path, BaseDocumentLoader loader) async {
print("importing $path");
final document = KnowledgeDocument(path);
document.group.target = widget.group;
documentBox.put(document);
-
- final platformContext = Context(style: Style.platform);
- final directory = await getApplicationSupportDirectory();
- final embeddingsModelPath = platformContext.join(directory.path, "test", "all-MiniLM-L6-v2", "fp16");
- final transformer = await SentenceTransformer.init(embeddingsModelPath, "CPU");
-
const uuid = Uuid();
+ final transformer = await transformerFuture;
+    // transformerFuture is nullable, so the awaited model can be null here.
+    if (transformer == null) {
+      throw Exception("Could not load transformer");
+    }
- final lcDocuments = await PdfLoader(path, 400).load();
+ final lcDocuments = await loader.load();
List entities = [];
for (final lcDocument in lcDocuments) {
final embeddings = await transformer.generate(lcDocument.pageContent);
@@ -55,6 +55,27 @@ class _DocumentsListState extends State {
embeddingsBox.putMany(entities);
print("Added ${entities.length} embeddings for $path");
+ return document;
+ }
+
+
+ Future initSentenceTransformer() async {
+ //bit hacky, perhaps move to init of provider?
+ final platformContext = Context(style: Style.platform);
+ final directory = await getApplicationSupportDirectory();
+ final embeddingsModelPath = platformContext.join(directory.path, "test", "all-MiniLM-L6-v2", "fp16");
+
+ return SentenceTransformer.init(embeddingsModelPath, "CPU");
+ }
+
+ void processUpload(BuildContext context, String path) async {
+ final files = await importDialog(context, path);
+ for (final file in files.keys){
+ final newDocument = await addDocument(file, files[file]!);
+ setState(() {
+ documents.add(newDocument);
+ });
+ }
}
@override
@@ -62,6 +83,8 @@ class _DocumentsListState extends State {
super.initState();
documentBox = ObjectBox.instance.store.box();
embeddingsBox = ObjectBox.instance.store.box();
+ documents = widget.group.documents;
+ transformerFuture = initSentenceTransformer();
}
@override
Widget build(BuildContext context) {
@@ -80,142 +103,32 @@ class _DocumentsListState extends State {
),
),
),
- Padding(
- padding: const EdgeInsets.all(8.0),
- child: Builder(
- builder: (context) {
- return Column(
- crossAxisAlignment: CrossAxisAlignment.start,
- children: [
- Button(
- onPressed: () {
-
- },
- child: const Text("add document"),
- ),
- DropArea(
- type: "a document",
- showChild: widget.group.documents.isNotEmpty,
- onUpload: (file) => addDocument(file),
- child: Column(
- children: [
- for (final document in widget.group.documents)
- Row(
- mainAxisAlignment: MainAxisAlignment.spaceBetween,
- children: [
- Text(document.source),
- Text("embeddings: ${document.sections.length}")
- ],
- )
- ],
- )
- ),
- Experiment(group: widget.group),
- ],
- );
- }
+ Expanded(
+ child: GridContainer(
+ color: backgroundColor.of(theme),
+ padding: const EdgeInsets.all(16),
+ child: Center(
+ child: DropArea(
+ type: "a document or folder",
+ showChild: documents.isNotEmpty,
+ onUpload: (file) => processUpload(context, file),
+ child: Column(
+ children: [
+ for (final document in documents)
+ Row(
+ mainAxisAlignment: MainAxisAlignment.spaceBetween,
+ children: [
+ Text(document.source),
+ Text("embeddings: ${document.sections.length}")
+ ],
+ )
+ ],
+ )
+ ),
+ ),
),
),
],
);
}
}
-
-class Experiment extends StatefulWidget {
- final KnowledgeGroup group;
- const Experiment({super.key, required this.group});
-
- @override
- State createState() => _ExperimentState();
-}
-
-class _ExperimentState extends State {
- VectorStore? vs;
- Future? chain;
-
- String? response;
-
- Future initMemoryStore() async {
- final platformContext = Context(style: Style.platform);
- final directory = await getApplicationSupportDirectory();
- const device = "CPU";
- final embeddingsModelPath = platformContext.join(directory.path, "test", "all-MiniLM-L6-v2", "fp16");
- final llmModelPath = platformContext.join(directory.path, "test", "TinyLlama-1.1B-Chat-v1.0-int8-ov");
- final embeddingsModel = await OpenVINOEmbeddings.init(embeddingsModelPath, device);
- vs = ObjectBoxStore(embeddings: embeddingsModel, group: widget.group);
- final model = OpenVINOLLM(await LLMInference.init(llmModelPath, device),
- defaultOptions: const OpenVINOLLMOptions(temperature: 1, topP: 1, applyTemplate: false)
- );
- final promptTemplate = ChatPromptTemplate.fromTemplate('''
-<|system|>
-Answer the question based only on the following context without specifically naming that it's from that context:
-{context}
-
-<|user|>
-{question}
-<|assistant|>
-''');
- final retriever = vs!.asRetriever();
-
- return Runnable.fromMap({
- 'context': retriever | Runnable.mapInput((docs) => docs.map((d) => d.pageContent).join('\n')),
- 'question': Runnable.passthrough(),
- }) | promptTemplate | model | const StringOutputParser();
- }
-
-
- void runChain(String text) async {
- final runnable = (await chain)!;
- setState(() {
- response = "";
- });
- await for (final output in runnable.stream(text)) {
- setState(() {
- response = (response ?? "") + output.toString();
- });
- }
- }
-
- @override
- void initState() {
- super.initState();
- chain = initMemoryStore();
- }
-
- @override
- Widget build(BuildContext context) {
- return FutureBuilder(
- future: chain,
- builder: (context, snapshot) {
- if (!snapshot.hasData) {
- return Center(
- child: Image.asset('images/intel-loading.gif', width: 100)
- );
- }
- return Padding(
- padding: const EdgeInsets.all(8.0),
- child: Column(
- crossAxisAlignment: CrossAxisAlignment.stretch,
- children: [
- const Padding(
- padding: EdgeInsets.only(top: 20),
- child: Text("Experiment area"),
- ),
- TextBox(
- onSubmitted: runChain,
- ),
- Builder(
- builder: (context) {
- if (response != null) {
- return SingleChildScrollView(child: SelectableText(response!.isEmpty ? "..." : response!));
- }
- return const Text("Type a message to test RAG");
- }
- )
- ],
- ),
- );
- }
- );
- }
-}
diff --git a/lib/pages/knowledge_base/widgets/experiment.dart b/lib/pages/knowledge_base/widgets/experiment.dart
new file mode 100644
index 00000000..26221ef5
--- /dev/null
+++ b/lib/pages/knowledge_base/widgets/experiment.dart
@@ -0,0 +1,129 @@
+import 'package:fluent_ui/fluent_ui.dart';
+import 'package:inference/interop/llm_inference.dart';
+import 'package:inference/langchain/object_box/embedding_entity.dart';
+import 'package:inference/langchain/object_box_store.dart';
+import 'package:inference/langchain/openvino_embeddings.dart';
+import 'package:inference/langchain/openvino_llm.dart';
+import 'package:langchain/langchain.dart';
+import 'package:path/path.dart';
+import 'package:path_provider/path_provider.dart';
+
+/// Scratch area for asking retrieval-augmented questions about one group.
+class Experiment extends StatefulWidget {
+  /// Knowledge group handed to the [ObjectBoxStore] behind the retriever.
+  final KnowledgeGroup group;
+  const Experiment({super.key, required this.group});
+
+  @override
+  State<Experiment> createState() => _ExperimentState();
+}
+
+class _ExperimentState extends State<Experiment> {
+  VectorStore? vs;
+  Future? chain;
+  String? response;
+
+  /// Loads the embeddings model and the LLM, then builds the chain that
+  /// retrieves context, fills the prompt, runs the model and parses the
+  /// output into a string.
+  Future initMemoryStore() async {
+    final platformContext = Context(style: Style.platform);
+    final directory = await getApplicationSupportDirectory();
+    const device = "CPU";
+    final embeddingsModelPath = platformContext.join(directory.path, "test", "all-MiniLM-L6-v2", "fp16");
+    final llmModelPath = platformContext.join(directory.path, "test", "TinyLlama-1.1B-Chat-v1.0-int8-ov");
+    final embeddingsModel = await OpenVINOEmbeddings.init(embeddingsModelPath, device);
+    vs = ObjectBoxStore(embeddings: embeddingsModel, group: widget.group);
+    final model = OpenVINOLLM(await LLMInference.init(llmModelPath, device),
+      defaultOptions: const OpenVINOLLMOptions(temperature: 1, topP: 1, applyTemplate: false)
+    );
+    final promptTemplate = ChatPromptTemplate.fromTemplate('''
+<|system|>
+Answer the question based only on the following context without specifically naming that it's from that context:
+{context}
+
+<|user|>
+{question}
+<|assistant|>
+''');
+    final retriever = vs!.asRetriever();
+
+    return Runnable.fromMap({
+      'context': retriever | Runnable.mapInput((docs) => docs.map((d) => d.pageContent).join('\n')),
+      'question': Runnable.passthrough(),
+    }) | promptTemplate | model | const StringOutputParser();
+  }
+
+  /// Streams the chain's answer to [text] into [response].
+  ///
+  /// The widget may be disposed while the chain is loading or generating, so
+  /// [mounted] is checked before every [setState].
+  void runChain(String text) async {
+    final runnable = (await chain)!;
+    if (!mounted) {
+      return;
+    }
+    setState(() {
+      response = "";
+    });
+    await for (final output in runnable.stream(text)) {
+      if (!mounted) {
+        // Leaving the loop cancels the stream subscription.
+        return;
+      }
+      setState(() {
+        response = (response ?? "") + output.toString();
+      });
+    }
+  }
+
+  @override
+  void initState() {
+    super.initState();
+    chain = initMemoryStore();
+  }
+
+  @override
+  Widget build(BuildContext context) {
+    return FutureBuilder(
+      future: chain,
+      builder: (context, snapshot) {
+        // Without this branch a failed model load would show the loading
+        // animation forever.
+        if (snapshot.hasError) {
+          return Center(
+            child: Text("Could not initialize the experiment: ${snapshot.error}"),
+          );
+        }
+        if (!snapshot.hasData) {
+          return Center(
+            child: Image.asset('images/intel-loading.gif', width: 100)
+          );
+        }
+        return Padding(
+          padding: const EdgeInsets.all(8.0),
+          child: Column(
+            crossAxisAlignment: CrossAxisAlignment.stretch,
+            children: [
+              const Padding(
+                padding: EdgeInsets.only(top: 20),
+                child: Text("Experiment area"),
+              ),
+              TextBox(
+                onSubmitted: runChain,
+              ),
+              Builder(
+                builder: (context) {
+                  if (response != null) {
+                    return SingleChildScrollView(child: SelectableText(response!.isEmpty ? "..." : response!));
+                  }
+                  return const Text("Type a message to test RAG");
+                }
+              )
+            ],
+          ),
+        );
+      }
+    );
+  }
+}
diff --git a/lib/pages/knowledge_base/widgets/group_item.dart b/lib/pages/knowledge_base/widgets/group_item.dart
index b81fe866..f0edf27f 100644
--- a/lib/pages/knowledge_base/widgets/group_item.dart
+++ b/lib/pages/knowledge_base/widgets/group_item.dart
@@ -38,14 +38,6 @@ class _GroupItemState extends State {
Widget build(BuildContext context) {
final theme = FluentTheme.of(context);
- if (widget.editable) {
- return TextBox(
- controller: controller,
- onSubmitted: (value) {
- widget.onRename?.call(value);
- },
- );
- }
return GestureDetector(
behavior: HitTestBehavior.opaque,
onTap: () {
@@ -55,8 +47,9 @@ class _GroupItemState extends State {
widget.onMakeEditable?.call();
},
child: Container(
- margin: const EdgeInsets.all(2),
- padding: const EdgeInsets.only(left: 20),
+ padding: const EdgeInsets.all(4),
+ margin: const EdgeInsets.all(4),
+ height: 40,
decoration: BoxDecoration(
border: Border(
left: BorderSide(
@@ -65,14 +58,30 @@ class _GroupItemState extends State {
),
)
),
- child: Row(
- mainAxisAlignment: MainAxisAlignment.spaceBetween,
- children: [
- Text(widget.group.name),
- IconButton(icon: const Icon(FluentIcons.delete), onPressed: () {
- widget.onDelete?.call();
- }),
- ],
+ child: Builder(
+ builder: (context) {
+ if (widget.editable) {
+ return TextBox(
+ controller: controller,
+ onSubmitted: (value) {
+ widget.onRename?.call(value);
+ },
+ );
+ }
+
+ return Padding(
+ padding: const EdgeInsets.only(left: 10, bottom: 5.5),
+ child: Row(
+ mainAxisAlignment: MainAxisAlignment.spaceBetween,
+ children: [
+ Text(widget.group.name),
+ IconButton(icon: const Icon(FluentIcons.delete), onPressed: () {
+ widget.onDelete?.call();
+ }),
+ ],
+ ),
+ );
+ }
),
),
);
diff --git a/lib/pages/knowledge_base/widgets/import_dialog.dart b/lib/pages/knowledge_base/widgets/import_dialog.dart
new file mode 100644
index 00000000..27e754a4
--- /dev/null
+++ b/lib/pages/knowledge_base/widgets/import_dialog.dart
@@ -0,0 +1,153 @@
+import 'dart:io';
+
+import 'package:fluent_ui/fluent_ui.dart';
+import 'package:inference/pages/computer_vision/widgets/horizontal_rule.dart';
+import 'package:inference/pages/knowledge_base/utils/loader_selector.dart';
+import 'package:langchain/langchain.dart';
+import 'package:path/path.dart';
+
+Future