Skip to content

Commit

Permalink
Make Categorize see more chat history. (infiniflow#4538)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

infiniflow#4521

### Type of change
- [x] Performance Improvement
  • Loading branch information
KevinHuSh authored Jan 20, 2025
1 parent 2962284 commit 367babd
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
11 changes: 6 additions & 5 deletions agent/component/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,11 +482,12 @@ def get_input(self):
continue

if q["component_id"].lower().find("answer") == 0:
for r, c in self._canvas.history[::-1]:
if r == "user":
self._param.inputs.append({"content": c, "component_id": q["component_id"]})
outs.append(pd.DataFrame([{"content": c}]))
break
txt = []
for r, c in self._canvas.history[::-1][:self._param.message_history_window_size]:
txt.append(f"{r.upper()}: {c}")
txt = "\n".join(txt)
self._param.inputs.append({"content": txt, "component_id": q["component_id"]})
outs.append(pd.DataFrame([{"content": txt}]))
continue

outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
Expand Down
13 changes: 8 additions & 5 deletions agent/component/categorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ def check(self):
if not v.get("to"):
raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")

def get_prompt(self):
def get_prompt(self, chat_hist):
cate_lines = []
for c, desc in self.category_description.items():
for line in desc.get("examples", "").split("\n"):
if not line:
continue
cate_lines.append("Question: {}\tCategory: {}".format(line, c))
cate_lines.append("USER: {}\nCategory: {}".format(line, c))
descriptions = []
for c, desc in self.category_description.items():
if desc.get("description"):
Expand All @@ -62,11 +62,15 @@ def get_prompt(self):
{}
You could learn from the above examples.
Just mention the category names, no need for any additional words.
---- Real Data ----
{}
""".format(
len(self.category_description.keys()),
"/".join(list(self.category_description.keys())),
"\n".join(descriptions),
"- ".join(cate_lines)
"- ".join(cate_lines),
chat_hist
)
return self.prompt

Expand All @@ -76,9 +80,8 @@ class Categorize(Generate, ABC):

def _run(self, history, **kwargs):
input = self.get_input()
input = "Question: " + (list(input["content"])[-1] if "content" in input else "") + "\tCategory: "
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": input}],
ans = chat_mdl.chat(self._param.get_prompt(input), [{"role": "user", "content": "\nCategory: "}],
self._param.gen_conf())
logging.debug(f"input: {input}, answer: {str(ans)}")
for c in self._param.category_description.keys():
Expand Down

0 comments on commit 367babd

Please sign in to comment.