Skip to content

Commit

Permalink
Badly working knowledge base selector in llm
Browse files Browse the repository at this point in the history
However, it's rough: it assumes a hardcoded path to an embedder.
State handling is not well done, so it doesn't know when it's done.
There are a lot of problems to fix before release, but it is a working prototype.
  • Loading branch information
RHeckerIntel committed Dec 4, 2024
1 parent 69d2450 commit fce62b3
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 5 deletions.
1 change: 1 addition & 0 deletions lib/pages/knowledge_base/widgets/documents_list.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import 'package:inference/interop/sentence_transformer.dart';
import 'package:inference/langchain/object_box/embedding_entity.dart';
import 'package:inference/langchain/object_box/object_box.dart';
import 'package:inference/objectbox.g.dart';
import 'package:inference/pages/knowledge_base/widgets/experiment.dart';
import 'package:inference/pages/knowledge_base/widgets/import_dialog.dart';
import 'package:inference/pages/models/widgets/grid_container.dart';
import 'package:inference/theme_fluent.dart';
Expand Down
4 changes: 3 additions & 1 deletion lib/pages/text_generation/playground.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import 'package:inference/pages/models/widgets/grid_container.dart';
import 'package:inference/pages/text_generation/widgets/assistant_message.dart';
import 'package:inference/pages/text_generation/widgets/model_properties.dart';
import 'package:inference/pages/text_generation/widgets/user_message.dart';
import 'package:inference/pages/text_generation/widgets/knowledge_base_selector.dart';
import 'package:inference/project.dart';
import 'package:inference/providers/text_inference_provider.dart';
import 'package:inference/theme_fluent.dart';
Expand Down Expand Up @@ -122,6 +123,7 @@ class _PlaygroundState extends State<Playground> {
max: 1.0,
min: 0.1,
),
const KnowledgeBaseSelector()
],
)
),
Expand Down Expand Up @@ -283,4 +285,4 @@ class _PlaygroundState extends State<Playground> {
],
);
}
}
}
63 changes: 63 additions & 0 deletions lib/pages/text_generation/widgets/knowledge_base_selector.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import 'package:fluent_ui/fluent_ui.dart';
import 'package:inference/langchain/object_box/embedding_entity.dart';
import 'package:inference/langchain/object_box/object_box.dart';
import 'package:inference/objectbox.g.dart';
import 'package:inference/providers/text_inference_provider.dart';
import 'package:inference/widgets/controls/no_outline_button.dart';
import 'package:provider/provider.dart';

/// Dropdown that lets the user pick the active knowledge base for the
/// text-generation playground.
///
/// The chosen [KnowledgeGroup] (or null for "no retrieval") is written to
/// the surrounding [TextInferenceProvider].
class KnowledgeBaseSelector extends StatefulWidget {
  const KnowledgeBaseSelector({super.key});

  @override
  State<KnowledgeBaseSelector> createState() =>
      _KnowledgeBaseSelectorState();
}

class _KnowledgeBaseSelectorState extends State<KnowledgeBaseSelector> {
  // All knowledge groups known to ObjectBox, fetched once when the widget
  // is inserted into the tree.
  late List<KnowledgeGroup> groups;

  @override
  void initState() {
    super.initState();
    final box = ObjectBox.instance.store.box<KnowledgeGroup>();
    groups = box.getAll();
  }

  /// Label shown on the dropdown button, reflecting the current selection.
  Widget _label(TextInferenceProvider inference) {
    final group = inference.knowledgeGroup;
    if (group == null) {
      return const Text("Knowledge Base");
    }
    return Text("Knowledge Base: ${group.name}");
  }

  /// The flyout entries: a "None" option followed by one entry per group.
  List<MenuFlyoutItem> _items(TextInferenceProvider inference) {
    return [
      MenuFlyoutItem(
        text: const Text("None"),
        onPressed: () {
          inference.knowledgeGroup = null;
        },
      ),
      for (final group in groups)
        MenuFlyoutItem(
          text: Text(group.name),
          onPressed: () {
            inference.knowledgeGroup = group;
          },
        ),
    ];
  }

  @override
  Widget build(BuildContext context) {
    return Consumer<TextInferenceProvider>(
      builder: (context, inference, child) {
        return DropDownButton(
          buttonBuilder: (context, callback) {
            return NoOutlineButton(
              onPressed: callback,
              child: Padding(
                padding: const EdgeInsets.all(8.0),
                child: Row(
                  children: [
                    _label(inference),
                    const Padding(
                      padding: EdgeInsets.only(left: 8),
                      child: Icon(FluentIcons.chevron_down, size: 12),
                    ),
                  ],
                ),
              ),
            );
          },
          items: _items(inference),
        );
      },
    );
  }
}
72 changes: 68 additions & 4 deletions lib/providers/text_inference_provider.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@ import 'dart:async';
import 'package:flutter/material.dart';
import 'package:inference/interop/generated_bindings.dart';
import 'package:inference/interop/llm_inference.dart';
import 'package:inference/langchain/object_box/embedding_entity.dart';
import 'package:inference/langchain/object_box_store.dart';
import 'package:inference/langchain/openvino_embeddings.dart';
import 'package:inference/langchain/openvino_llm.dart';
import 'package:inference/project.dart';
import 'package:langchain/langchain.dart';
import 'package:path/path.dart';
import 'package:path_provider/path_provider.dart';

enum Speaker { system, assistant, user }

Expand Down Expand Up @@ -37,6 +44,43 @@ class Message {
const Message(this.speaker, this.message, this.metrics, this.time);
}

/// Builds the LangChain [Runnable] pipeline used to answer a prompt.
///
/// When [group] is non-null a RAG chain is returned: documents are pulled
/// from the group's ObjectBox vector store and injected as `{context}` into
/// a manually written chat template (so the backend must NOT apply its own
/// template). When [group] is null a plain pass-through chain is returned
/// and the backend applies its template.
///
/// [device] selects the OpenVINO device for the embeddings model and
/// [embeddingsModelPath] overrides the embedder location; both default to
/// the values this prototype shipped with, so existing callers are
/// unaffected.
Future<Runnable> buildChain(
  LLMInference inference,
  KnowledgeGroup? group, {
  String device = "CPU",
  String? embeddingsModelPath,
}) async {
  if (group == null) {
    // No knowledge base selected: skip loading the embeddings model
    // entirely and prompt the LLM directly.
    final model = OpenVINOLLM(
      inference,
      defaultOptions: const OpenVINOLLMOptions(
          temperature: 1, topP: 1, applyTemplate: true),
    );
    final promptTemplate = ChatPromptTemplate.fromTemplate("{question}");

    return Runnable.fromMap<String>({
          'question': Runnable.passthrough(),
        }) |
        promptTemplate |
        model |
        const StringOutputParser();
  }

  // NOTE(review): the default embedder path is hard-coded to a "test"
  // folder under the app support directory — make this configurable
  // before release.
  final platformContext = Context(style: Style.platform);
  final directory = await getApplicationSupportDirectory();
  final modelPath = embeddingsModelPath ??
      platformContext.join(
          directory.path, "test", "all-MiniLM-L6-v2", "fp16");
  final embeddingsModel = await OpenVINOEmbeddings.init(modelPath, device);

  final vs = ObjectBoxStore(embeddings: embeddingsModel, group: group);
  // applyTemplate is false because the chat template is written out by
  // hand in the prompt below.
  final model = OpenVINOLLM(
    inference,
    defaultOptions: const OpenVINOLLMOptions(
        temperature: 1, topP: 1, applyTemplate: false),
  );

  final promptTemplate = ChatPromptTemplate.fromTemplate('''
<|system|>
Answer the question based only on the following context without specifically naming that it's from that context:
{context}
<|user|>
{question}
<|assistant|>
''');
  final retriever = vs.asRetriever();

  return Runnable.fromMap<String>({
        // Retrieved documents are flattened to newline-joined page content.
        'context': retriever |
            Runnable.mapInput(
                (docs) => docs.map((d) => d.pageContent).join('\n')),
        'question': Runnable.passthrough(),
      }) |
      promptTemplate |
      model |
      const StringOutputParser();
}

class TextInferenceProvider extends ChangeNotifier {

Completer<void> loaded = Completer<void>();
Expand All @@ -48,6 +92,15 @@ class TextInferenceProvider extends ChangeNotifier {
String? get device => _device;
Metrics? get metrics => _messages.lastOrNull?.metrics;

  // LangChain pipeline used to answer prompts. It is (re)built from the
  // current knowledgeGroup when a prompt is sent, see message().
  Future<Runnable>? chain;

  // Currently selected knowledge base; null means chat without retrieval.
  KnowledgeGroup? _knowledgeGroup;

  /// The knowledge base used to ground responses, or null to disable RAG.
  KnowledgeGroup? get knowledgeGroup => _knowledgeGroup;

  // NOTE(review): setting a new group does not invalidate [chain] here;
  // message() currently rebuilds the chain on every call as a workaround —
  // confirm and consolidate before release.
  set knowledgeGroup(KnowledgeGroup? group) {
  _knowledgeGroup = group;
  notifyListeners();
  }

double _temperature = 1;
double get temperature => _temperature;
set temperature(double v) {
Expand All @@ -73,8 +126,8 @@ class TextInferenceProvider extends ChangeNotifier {

Future<void> loadModel() async {
if (project != null && device != null) {
_inference = await LLMInference.init(project!.storagePath, device!)
..setListener(onMessage);
_inference = await LLMInference.init(project!.storagePath, device!);
chain = buildChain(_inference!, knowledgeGroup);
loaded.complete();
notifyListeners();
}
Expand Down Expand Up @@ -149,13 +202,24 @@ class TextInferenceProvider extends ChangeNotifier {
}

Future<void> message(String message) async {

_response = "...";
_messages.add(Message(Speaker.user, message, null, DateTime.now()));
notifyListeners();
final response = await _inference!.prompt(message, true, temperature, topP);
chain = buildChain(_inference!, knowledgeGroup);
final runnable = (await chain)!;
//final response = await _inference!.prompt(message, true, temperature, topP);

String modelOutput = "";
await for (final output in runnable.stream(message)) {
final token = output.toString();
modelOutput += token;
onMessage(token);
}
print("end...");

if (_messages.isNotEmpty) {
_messages.add(Message(Speaker.assistant, response.content, response.metrics, DateTime.now()));
_messages.add(Message(Speaker.assistant, modelOutput, null, DateTime.now()));
}
_response = null;
n = 0;
Expand Down

0 comments on commit fce62b3

Please sign in to comment.