# tech-writer
#
# @Author: clvgt12
# @URL: https://github.com/clvgt12/tech-writer
# @License: MIT
#
# Implements a simple spelling- and grammar-checking use case using a
# customized Llama 2 model with a specialized system prompt. That model
# is named 'tech-writer'.
#
# LLM requirements: Ollama installed, with the tech-writer model created
# pip requirements: requests, streamlit
#
# Invoke this application as: streamlit run tech-writer.py
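#
# A minimal sketch of how the tech-writer model might be created (the
# Modelfile contents below are an assumption; only the base model family
# and the model name come from this file):
#
#   # Modelfile
#   FROM llama2
#   SYSTEM """You are a technical writer. Correct the spelling and
#   grammar of the text you are given."""
#
#   ollama create tech-writer -f Modelfile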

import os
import json

import requests
import streamlit as st
from streamlit.logger import get_logger

# Initialize an application logging service
logger = get_logger(__name__)

# Define global variables
ollama_base_url = os.getenv('OLLAMA_BASE_URL', 'http://host.docker.internal:11434')
ollama_logo_url = os.getenv('OLLAMA_LOGO_URL', 'https://ollama.com/public/ollama.png')
ollama_model = os.getenv('OLLAMA_MODEL', 'tech-writer:latest')
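
# The defaults above assume the app runs in a container and reaches Ollama on
# the Docker host; override them at launch if your setup differs, e.g.:
#   OLLAMA_BASE_URL=http://localhost:11434 streamlit run tech-writer.py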


class ChatOllama:
    """
    Represents a client for interacting with the Ollama service for
    language model operations.
    """

    def __init__(self, base_url=ollama_base_url, model=ollama_model):
        """
        Initializes the ChatOllama client with the base URL and model name.

        :param base_url: URL of the Ollama API endpoint.
        :param model: Name of the model to use for requests.
        """
        self.base_url = base_url
        self.model = model

    def invoke(self, prompt):
        """
        Invokes the Ollama model with a given prompt to generate a response.
        Enables streaming for real-time processing of the model's output.

        :param prompt: Text prompt to send to the model.
        :return: Yields response fragments from the model as they arrive,
                 or None if the request fails.
        """
        url = f"{self.base_url}/api/generate"
        data = {
            "model": self.model,
            "prompt": prompt,
            "stream": True
        }
        try:
            response = requests.post(url, json=data, stream=True)
            if response.status_code == 200:
                # The API streams one JSON object per line; every line
                # carries a "response" fragment, and the final line sets
                # "done" to true with no further text, so stop before
                # yielding it.
                for line in response.iter_lines():
                    if line:
                        json_line = json.loads(line.decode('utf-8'))
                        if json_line.get("done") is not False:
                            break
                        yield json_line.get("response")
            else:
                logger.error(f"Ollama API call failed with status code {response.status_code}")
                yield None
        except Exception as e:
            logger.error(f"Failed to call Ollama API: {e}")
            yield None
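
# For reference, the NDJSON stream from POST /api/generate looks like the
# following (the shape matches the Ollama streaming API; the token values
# shown are illustrative):
#   {"model":"tech-writer:latest","response":"Hello","done":false}
#   {"model":"tech-writer:latest","response":" world","done":false}
#   {"model":"tech-writer:latest","response":"","done":true}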


def query_ollama(prompt):
    """
    Sends a prompt to the Ollama service and retrieves the streamed response.

    :param prompt: Text prompt to be corrected or processed by the model.
    :return: A generator of streamed response fragments from the model, or
             None values on error.
    """
    llm = ChatOllama()
    return llm.invoke(prompt)


def check_ollama_status():
    """
    Checks the status of the Ollama service by making a GET request to the
    base endpoint and displays a status banner in the Streamlit interface.
    """
    try:
        # A GET on the base URL returns a short plain-text status
        # (e.g. "Ollama is running") when the service is up.
        response = requests.get(ollama_base_url, timeout=5)
        if response.status_code == 200:
            st.markdown(f'<div style="background-color:green;color:white;padding:0.5rem;">{response.text}</div>', unsafe_allow_html=True)
        else:
            st.markdown('<div style="background-color:red;color:white;padding:0.5rem;">Ollama is unavailable.</div>', unsafe_allow_html=True)
    except requests.exceptions.RequestException as e:
        st.markdown('<div style="background-color:red;color:white;padding:0.5rem;">Ollama is unavailable.</div>', unsafe_allow_html=True)
        logger.error(f"Failed to connect to Ollama API: {e}")


def front_end():
    """
    Sets up the Streamlit front-end interface, allowing users to input text
    for processing by the Ollama model, and displays the results.
    """
    st.image(ollama_logo_url, width=56)
    st.title('Grammar and Spelling Correction')
    # Check and display the Ollama service status
    check_ollama_status()
    st.markdown('<style>div.row-widget.stTextArea { padding-top: 0.5rem; }</style>', unsafe_allow_html=True)
    text = st.text_area("Enter text to check:", height=150)
    if st.button('Check Syntax'):
        st.markdown('<style>h2 { font-size: 1.2rem; }</style>', unsafe_allow_html=True)
        if text:
            output_placeholder = st.empty()
            accumulated_text = ""
            # Accumulate streamed fragments and re-render the placeholder so
            # the corrected text appears incrementally. A part of None
            # signals that the API call failed.
            for part in query_ollama(text):
                if part is not None:
                    accumulated_text += part
                    output_placeholder.markdown(accumulated_text)
                else:
                    st.error("Failed to get a response from Ollama")
        else:
            st.error("Please enter some text to check.")


if __name__ == "__main__":
    front_end()
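
# A minimal sketch of a containerized launch (the image name and Dockerfile
# are assumptions; the host.docker.internal default above suggests the app
# runs in Docker and talks to Ollama on the host, and 8501 is Streamlit's
# default port):
#
#   docker build -t tech-writer .
#   docker run -p 8501:8501 \
#       -e OLLAMA_BASE_URL=http://host.docker.internal:11434 \
#       tech-writer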