diff --git a/backend/file_management/file_management_helper.py b/backend/file_management/file_management_helper.py
index e1bf7111b..4bbc6ccf8 100644
--- a/backend/file_management/file_management_helper.py
+++ b/backend/file_management/file_management_helper.py
@@ -139,12 +139,20 @@ def upload_file(
# adding filename with path
file_path += file_name
with fs.open(file_path, mode="wb") as remote_file:
- remote_file.write(file.read())
+ if isinstance(file, bytes):
+ remote_file.write(file)
+ else:
+ remote_file.write(file.read())
@staticmethod
@deprecated(reason="Use remote FS APIs from SDK")
- def fetch_file_contents(file_system: UnstractFileSystem, file_path: str) -> Any:
+ def fetch_file_contents(
+ file_system: UnstractFileSystem,
+ file_path: str,
+ allowed_content_types: list[str],
+ ) -> Any:
fs = file_system.get_fsspec_fs()
+
try:
file_info = fs.info(file_path)
except FileNotFoundError:
@@ -152,8 +160,10 @@ def fetch_file_contents(file_system: UnstractFileSystem, file_path: str) -> Any:
file_content_type = file_info.get("ContentType")
file_type = file_info.get("type")
+
if file_type != "file":
raise InvalidFileType
+
try:
if not file_content_type:
file_content_type, _ = mimetypes.guess_type(file_path)
@@ -165,19 +175,26 @@ def fetch_file_contents(file_system: UnstractFileSystem, file_path: str) -> Any:
except ApiRequestError as exception:
logger.error(f"ApiRequestError from {file_info} {exception}")
raise ConnectorApiRequestError
+
+ data = ""
+ # Check if the file type is in the allowed list
+ if file_content_type not in allowed_content_types:
+ raise InvalidFileType(f"File type '{file_content_type}' is not allowed.")
+
+ # Handle allowed file types
if file_content_type == "application/pdf":
- # Read contents of PDF file into a string
with fs.open(file_path, "rb") as file:
- encoded_string = base64.b64encode(file.read())
- return encoded_string
+ data = base64.b64encode(file.read())
elif file_content_type == "text/plain":
with fs.open(file_path, "r") as file:
logger.info(f"Reading text file: {file_path}")
- text_content = file.read()
- return text_content
+ data = file.read()
+
else:
- raise InvalidFileType
+ logger.warning(f"File type '{file_content_type}' is not handled.")
+
+ return {"data": data, "mime_type": file_content_type}
@staticmethod
def _delete_file(fs, file_path):
diff --git a/backend/prompt_studio/prompt_studio_core_v2/serializers.py b/backend/prompt_studio/prompt_studio_core_v2/serializers.py
index d6e79483a..30bc76199 100644
--- a/backend/prompt_studio/prompt_studio_core_v2/serializers.py
+++ b/backend/prompt_studio/prompt_studio_core_v2/serializers.py
@@ -4,7 +4,6 @@
from account_v2.models import User
from account_v2.serializer import UserSerializer
from django.core.exceptions import ObjectDoesNotExist
-from file_management.constants import FileInformationKey
from prompt_studio.prompt_profile_manager_v2.models import ProfileManager
from prompt_studio.prompt_studio_core_v2.constants import ToolStudioKeys as TSKeys
from prompt_studio.prompt_studio_core_v2.exceptions import DefaultProfileError
@@ -23,6 +22,13 @@
logger = logging.getLogger(__name__)
+try:
+ from plugins.processor.file_converter.constants import (
+ ExtentedFileInformationKey as FileKey,
+ )
+except ImportError:
+ from file_management.constants import FileInformationKey as FileKey
+
class CustomToolSerializer(IntegrityErrorMixin, AuditSerializer):
shared_users = serializers.PrimaryKeyRelatedField(
@@ -151,10 +157,10 @@ class FileUploadIdeSerializer(serializers.Serializer):
required=True,
validators=[
FileValidator(
- allowed_extensions=FileInformationKey.FILE_UPLOAD_ALLOWED_EXT,
- allowed_mimetypes=FileInformationKey.FILE_UPLOAD_ALLOWED_MIME,
+ allowed_extensions=FileKey.FILE_UPLOAD_ALLOWED_EXT,
+ allowed_mimetypes=FileKey.FILE_UPLOAD_ALLOWED_MIME,
min_size=0,
- max_size=FileInformationKey.FILE_UPLOAD_MAX_SIZE,
+ max_size=FileKey.FILE_UPLOAD_MAX_SIZE,
)
],
)
diff --git a/backend/prompt_studio/prompt_studio_core_v2/views.py b/backend/prompt_studio/prompt_studio_core_v2/views.py
index 46228b7bc..87c1b034a 100644
--- a/backend/prompt_studio/prompt_studio_core_v2/views.py
+++ b/backend/prompt_studio/prompt_studio_core_v2/views.py
@@ -6,6 +6,7 @@
from django.db import IntegrityError
from django.db.models import QuerySet
from django.http import HttpRequest
+from file_management.constants import FileInformationKey as FileKey
from file_management.exceptions import FileNotFound
from file_management.file_management_helper import FileManagerHelper
from permissions.permission import IsOwner, IsOwnerOrSharedUser
@@ -404,7 +405,14 @@ def fetch_contents_ide(self, request: HttpRequest, pk: Any = None) -> Response:
document: DocumentManager = DocumentManager.objects.get(pk=document_id)
file_name: str = document.document_name
view_type: str = serializer.validated_data.get("view_type")
+ file_converter = get_plugin_class_by_name(
+ name="file_converter",
+ plugins=self.processor_plugins,
+ )
+ allowed_content_types = FileKey.FILE_UPLOAD_ALLOWED_MIME
+ if file_converter:
+ allowed_content_types = file_converter.get_extented_file_information_key()
filename_without_extension = file_name.rsplit(".", 1)[0]
if view_type == FileViewTypes.EXTRACT:
file_name = (
@@ -430,7 +438,9 @@ def fetch_contents_ide(self, request: HttpRequest, pk: Any = None) -> Response:
file_path += file_name
# Temporary Hack for frictionless onboarding as the user id will be empty
try:
- contents = FileManagerHelper.fetch_file_contents(file_system, file_path)
+ contents = FileManagerHelper.fetch_file_contents(
+ file_system, file_path, allowed_content_types
+ )
except FileNotFound:
file_path = file_path = (
FileManagerHelper.handle_sub_directory_for_tenants(
@@ -462,10 +472,23 @@ def upload_for_ide(self, request: HttpRequest, pk: Any = None) -> Response:
serializer = FileUploadIdeSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
uploaded_files: Any = serializer.validated_data.get("file")
+ file_converter = get_plugin_class_by_name(
+ name="file_converter",
+ plugins=self.processor_plugins,
+ )
+
documents = []
for uploaded_file in uploaded_files:
# Store file
file_name = uploaded_file.name
+ file_data = uploaded_file
+ file_type = uploaded_file.content_type
+ # Convert non-PDF files
+ if file_converter and file_type != "application/pdf":
+ file_data, file_name = file_converter.process_file(
+ uploaded_file, file_name
+ )
+
logger.info(
f"Uploading file: {file_name}" if file_name else "Uploading file"
)
@@ -480,7 +503,7 @@ def upload_for_ide(self, request: HttpRequest, pk: Any = None) -> Response:
FileManagerHelper.upload_file(
file_system,
file_path,
- uploaded_file,
+ file_data,
file_name,
)
else:
@@ -488,7 +511,7 @@ def upload_for_ide(self, request: HttpRequest, pk: Any = None) -> Response:
org_id=UserSessionUtils.get_organization_id(request),
user_id=custom_tool.created_by.user_id,
tool_id=str(custom_tool.tool_id),
- uploaded_file=uploaded_file,
+ uploaded_file=file_data,
)
# Create a record in the db for the file
diff --git a/frontend/src/components/custom-tools/document-manager/DocumentManager.jsx b/frontend/src/components/custom-tools/document-manager/DocumentManager.jsx
index e5e724db5..2cafd2dca 100644
--- a/frontend/src/components/custom-tools/document-manager/DocumentManager.jsx
+++ b/frontend/src/components/custom-tools/document-manager/DocumentManager.jsx
@@ -13,7 +13,10 @@ import { useEffect, useState } from "react";
import { useParams } from "react-router-dom";
import "./DocumentManager.css";
-import { base64toBlob, docIndexStatus } from "../../../helpers/GetStaticData";
+import {
+ base64toBlobWithMime,
+ docIndexStatus,
+} from "../../../helpers/GetStaticData";
import { useAxiosPrivate } from "../../../hooks/useAxiosPrivate";
import { useCustomToolStore } from "../../../store/custom-tool-store";
import { useSessionStore } from "../../../store/session-store";
@@ -22,6 +25,7 @@ import { ManageDocsModal } from "../manage-docs-modal/ManageDocsModal";
import { PdfViewer } from "../pdf-viewer/PdfViewer";
import { TextViewerPre } from "../text-viewer-pre/TextViewerPre";
import usePostHogEvents from "../../../hooks/usePostHogEvents";
+import { TextViewer } from "../text-viewer/TextViewer";
let items = [
{
@@ -103,6 +107,20 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) {
const { id } = useParams();
const highlightData = selectedHighlight?.highlight || [];
+ const [blobFileUrl, setBlobFileUrl] = useState("");
+ const [fileData, setFileData] = useState({});
+
+ useEffect(() => {
+ // Create an object URL from the file blob
+ if (fileData.blob) {
+ const objectUrl = URL.createObjectURL(fileData.blob);
+ setBlobFileUrl(objectUrl);
+
+ // Revoke the object URL when fileData changes or the component unmounts
+ return () => URL.revokeObjectURL(objectUrl);
+ }
+ }, [fileData]);
+
useEffect(() => {
if (isSimplePromptStudio) {
items = [
@@ -199,7 +217,8 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) {
getDocsFunc(details?.tool_id, selectedDoc?.document_id, viewType)
.then((res) => {
const data = res?.data?.data || "";
- processGetDocsResponse(data, viewType);
+ const mimeType = res?.data?.mime_type || "";
+ processGetDocsResponse(data, viewType, mimeType);
})
.catch((err) => {
handleGetDocsError(err, viewType);
@@ -226,11 +245,19 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) {
});
};
- const processGetDocsResponse = (data, viewType) => {
+ const processGetDocsResponse = (data, viewType, mimeType) => {
if (viewType === viewTypes.original) {
const base64String = data || "";
- const blob = base64toBlob(base64String);
- setFileUrl(URL.createObjectURL(blob));
+ const blob = base64toBlobWithMime(base64String, mimeType);
+ setFileData({ blob, mimeType });
+ const reader = new FileReader();
+ reader.readAsDataURL(blob);
+ reader.onload = () => {
+ setFileUrl(reader.result);
+ };
+ reader.onerror = () => {
+ throw new Error("Fail to load the file");
+ };
} else if (viewType === viewTypes.extract) {
setExtractTxt(data);
}
@@ -317,6 +344,19 @@ function DocumentManager({ generateIndex, handleUpdateTool, handleDocChange }) {
}
};
+ const renderDoc = (docName, fileUrl, highlightData) => {
+ const fileType = docName?.split(".").pop().toLowerCase(); // Get the file extension
+ switch (fileType) {
+ case "pdf":
+ return