-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathhpm.py
73 lines (62 loc) · 2.67 KB
/
hpm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import dataclasses
import json
import os
from typing import Dict, Any, List
from hivemind_bus_client.message import HiveMessage, HiveMessageType
from hivemind_core.protocol import AgentProtocol
from ovos_bus_client import MessageBusClient
from ovos_bus_client.message import Message
from ovos_bus_client.session import SessionManager
from ovos_utils.fakebus import FakeBus
from ovos_utils.log import LOG
from ovos_persona import Persona
@dataclasses.dataclass()
class PersonaProtocol(AgentProtocol):
    """HiveMind agent protocol that answers utterances with an ovos-persona chatbot.

    Every ``recognizer_loop:utterance`` message seen on ``bus`` is answered by
    ``persona``. The reply is sent back to the originating HiveMind peer as a
    ``speak`` message. A per-session chat history is kept in ``sessions`` so the
    persona is given the whole conversation on each turn.
    """

    bus: MessageBusClient = dataclasses.field(default_factory=FakeBus)
    # Protocol configuration; only the "persona" key (path to a persona JSON
    # file) is read by this class.
    config: Dict[str, Any] = dataclasses.field(default_factory=dict)
    # session_id -> list of {"role": "user"|"assistant", "content": text} turns.
    sessions: Dict[str, List[Dict[str, str]]] = dataclasses.field(default_factory=dict)
    # Persona used to answer; built from ``config`` in __post_init__ when not
    # supplied. NOTE(review): effectively Optional[Persona].
    persona: Persona = None

    @staticmethod
    def _default_persona_config() -> Dict[str, Any]:
        """Return a fresh copy of the fallback persona configuration."""
        # NOTE(review): placeholder API key and a hard-coded third-party
        # endpoint; real deployments are expected to set config["persona"].
        return {
            "name": "ChatGPT",
            "solvers": [
                "ovos-solver-openai-persona-plugin"
            ],
            "ovos-solver-openai-persona-plugin": {
                "api_url": "https://llama.smartgic.io/v1",
                "key": "sk-xxxx",
                "persona": "helpful, creative, clever, and very friendly."
            }
        }

    def _load_persona(self) -> Persona:
        """Build a Persona from ``config["persona"]`` (a JSON file path) or the default config."""
        persona_json = self.config.get("persona")
        if not persona_json:
            persona_cfg = self._default_persona_config()
            name = persona_cfg.get("name")
        else:
            with open(persona_json, encoding="utf-8") as f:
                persona_cfg = json.load(f)
            # Fall back to the file name when the JSON does not define a name.
            name = persona_cfg.get("name") or os.path.basename(persona_json)
        return Persona(name=name, config=persona_cfg)

    def __post_init__(self):
        """Create the persona if none was given and subscribe to utterances on the bus."""
        if self.persona is None:
            self.persona = self._load_persona()
        LOG.debug("registering internal OVOS bus handlers")
        self.bus.on("recognizer_loop:utterance", self.handle_utterance)  # catch all

    def handle_utterance(self, message: Message):
        """Answer an utterance with the persona and send the reply to its HiveMind peer.

        Args:
            message (Message): mycroft bus message object. The first entry of
                ``data["utterances"]`` is answered, and ``context["source"]``
                names the HiveMind peer the ``speak`` reply is sent to.
        """
        utterances = message.data.get("utterances") or []
        if not utterances:
            LOG.warning("received utterance message without utterances, ignoring")
            return
        utt = utterances[0]
        sess = SessionManager.get(message)

        # Track message history per session. Only completed exchanges are
        # stored: if the persona raises or produces no answer, the history is
        # left untouched instead of keeping a dangling user turn.
        history = self.sessions.setdefault(sess.session_id, [])
        user_turn = {"role": "user", "content": utt}
        raw_answer = self.persona.chat(history + [user_turn], lang=sess.lang)
        # chat() may yield None/empty when no solver answered -- guard before strip().
        answer = (raw_answer or "").strip()
        if not answer:
            LOG.warning(f"persona produced no answer for session {sess.session_id}")
            return
        history.append(user_turn)
        history.append({"role": "assistant", "content": answer})

        # The reply goes back to whichever peer injected the utterance; that
        # peer may be missing from the context or may have disconnected.
        peer = message.context.get("source")
        try:
            client = self.clients[peer]
        except KeyError:
            LOG.warning(f"no connected HiveMind client for peer {peer!r}, dropping answer")
            return
        msg = HiveMessage(
            HiveMessageType.BUS,
            source_peer=self.hm_protocol.peer,
            target_peers=[peer],
            payload=message.reply("speak", {"utterance": answer}),
        )
        client.send(msg)