diff --git a/pkg/api/transactions.go b/pkg/api/transactions.go index e7099ee5..75c7ca20 100644 --- a/pkg/api/transactions.go +++ b/pkg/api/transactions.go @@ -2,6 +2,7 @@ package api import ( "sort" + "strings" orderedmap "github.com/wk8/go-ordered-map/v2" @@ -45,14 +46,14 @@ func (a *TransactionsApi) TransactionCountHandler(c *core.Context) (any, *errs.E uid := c.GetCurrentUid() - allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionCountReq.AccountId, uid) + allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionCountReq.AccountIds, uid) if err != nil { log.WarnfWithRequestId(c, "[transactions.TransactionCountHandler] get account error, because %s", err.Error()) return nil, errs.Or(err, errs.ErrOperationFailed) } - allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionCountReq.CategoryId, uid) + allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionCountReq.CategoryIds, uid) if err != nil { log.WarnfWithRequestId(c, "[transactions.TransactionCountHandler] get transaction category error, because %s", err.Error()) @@ -101,14 +102,14 @@ func (a *TransactionsApi) TransactionListHandler(c *core.Context) (any, *errs.Er return nil, errs.ErrUserNotFound } - allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionListReq.AccountId, uid) + allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionListReq.AccountIds, uid) if err != nil { log.WarnfWithRequestId(c, "[transactions.TransactionListHandler] get account error, because %s", err.Error()) return nil, errs.Or(err, errs.ErrOperationFailed) } - allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionListReq.CategoryId, uid) + allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionListReq.CategoryIds, uid) if err != nil { log.WarnfWithRequestId(c, "[transactions.TransactionListHandler] get transaction category error, because %s", err.Error()) @@ -192,14 +193,14 @@ func (a *TransactionsApi) TransactionMonthListHandler(c *core.Context) (any, *er return nil, errs.ErrUserNotFound } - allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionListReq.AccountId, uid) + allAccountIds, err := a.getAccountOrSubAccountIds(c, transactionListReq.AccountIds, uid) if err != nil { log.WarnfWithRequestId(c, "[transactions.TransactionMonthListHandler] get account error, because %s", err.Error()) return nil, errs.Or(err, errs.ErrOperationFailed) } - allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionListReq.CategoryId, uid) + allCategoryIds, err := a.getCategoryOrSubCategoryIds(c, transactionListReq.CategoryIds, uid) if err != nil { log.WarnfWithRequestId(c, "[transactions.TransactionMonthListHandler] get transaction category error, because %s", err.Error()) @@ -839,44 +840,104 @@ func (a *TransactionsApi) filterTransactions(c *core.Context, uid int64, transac return finalTransactions } -func (a *TransactionsApi) getAccountOrSubAccountIds(c *core.Context, accountId int64, uid int64) ([]int64, error) { +func (a *TransactionsApi) getAccountOrSubAccountIds(c *core.Context, accountIds string, uid int64) ([]int64, error) { + if accountIds == "" || accountIds == "0" { + return nil, nil + } + + requestAccountIds, err := utils.StringArrayToInt64Array(strings.Split(accountIds, ",")) + + if err != nil { + return nil, errs.Or(err, errs.ErrAccountIdInvalid) + } + var allAccountIds []int64 - if accountId > 0 { - allSubAccounts, err := a.accounts.GetSubAccountsByAccountId(c, uid, accountId) + if len(requestAccountIds) > 0 { + allSubAccounts, err := a.accounts.GetSubAccountsByAccountIds(c, uid, requestAccountIds) if err != nil { return nil, err } - if len(allSubAccounts) > 0 { - for i := 0; i < len(allSubAccounts); i++ { - allAccountIds = append(allAccountIds, allSubAccounts[i].AccountId) + accountIdsMap := make(map[int64]int32, len(requestAccountIds)) + + for i := 0; i < len(requestAccountIds); i++ { + accountIdsMap[requestAccountIds[i]] = 0 + } + + for i := 0; i < len(allSubAccounts); i++ { + subAccount := allSubAccounts[i] + + if refCount, exists := accountIdsMap[subAccount.ParentAccountId]; exists { + accountIdsMap[subAccount.ParentAccountId] = refCount + 1 + } else { + accountIdsMap[subAccount.ParentAccountId] = 1 + } + + if _, exists := accountIdsMap[subAccount.AccountId]; exists { + delete(accountIdsMap, subAccount.AccountId) + } + + allAccountIds = append(allAccountIds, subAccount.AccountId) + } + + for accountId, refCount := range accountIdsMap { + if refCount < 1 { + allAccountIds = append(allAccountIds, accountId) } - } else { - allAccountIds = append(allAccountIds, accountId) } } return allAccountIds, nil } -func (a *TransactionsApi) getCategoryOrSubCategoryIds(c *core.Context, categoryId int64, uid int64) ([]int64, error) { +func (a *TransactionsApi) getCategoryOrSubCategoryIds(c *core.Context, categoryIds string, uid int64) ([]int64, error) { + if categoryIds == "" || categoryIds == "0" { + return nil, nil + } + + requestCategoryIds, err := utils.StringArrayToInt64Array(strings.Split(categoryIds, ",")) + + if err != nil { + return nil, errs.Or(err, errs.ErrTransactionCategoryIdInvalid) + } + var allCategoryIds []int64 - if categoryId > 0 { - allSubCategories, err := a.transactionCategories.GetAllCategoriesByUid(c, uid, 0, categoryId) + if len(requestCategoryIds) > 0 { + allSubCategories, err := a.transactionCategories.GetSubCategoriesByCategoryIds(c, uid, requestCategoryIds) if err != nil { return nil, err } - if len(allSubCategories) > 0 { - for i := 0; i < len(allSubCategories); i++ { - allCategoryIds = append(allCategoryIds, allSubCategories[i].CategoryId) + categoryIdsMap := make(map[int64]int32, len(requestCategoryIds)) + + for i := 0; i < len(requestCategoryIds); i++ { + categoryIdsMap[requestCategoryIds[i]] = 0 + } + + for i := 0; i < len(allSubCategories); i++ { + subCategory := allSubCategories[i] + + if refCount, exists := categoryIdsMap[subCategory.ParentCategoryId]; exists { + categoryIdsMap[subCategory.ParentCategoryId] = refCount + 1 + } else { + categoryIdsMap[subCategory.ParentCategoryId] = 1 + } + + if _, exists := categoryIdsMap[subCategory.CategoryId]; exists { + delete(categoryIdsMap, subCategory.CategoryId) + } + + allCategoryIds = append(allCategoryIds, subCategory.CategoryId) + } + + for accountId, refCount := range categoryIdsMap { + if refCount < 1 { + allCategoryIds = append(allCategoryIds, accountId) } - } else { - allCategoryIds = append(allCategoryIds, categoryId) } } diff --git a/pkg/models/transaction.go b/pkg/models/transaction.go index 9bb5786e..8c6ff7f6 100644 --- a/pkg/models/transaction.go +++ b/pkg/models/transaction.go @@ -95,8 +95,8 @@ type TransactionModifyRequest struct { // TransactionCountRequest represents transaction count request type TransactionCountRequest struct { Type TransactionDbType `form:"type" binding:"min=0,max=4"` - CategoryId int64 `form:"category_id" binding:"min=0"` - AccountId int64 `form:"account_id" binding:"min=0"` + CategoryIds string `form:"category_ids"` + AccountIds string `form:"account_ids"` AmountFilter string `form:"amount_filter" binding:"validAmountFilter"` Keyword string `form:"keyword"` MaxTime int64 `form:"max_time" binding:"min=0"` @@ -106,8 +106,8 @@ type TransactionCountRequest struct { // TransactionListByMaxTimeRequest represents all parameters of transaction listing by max time request type TransactionListByMaxTimeRequest struct { Type TransactionDbType `form:"type" binding:"min=0,max=4"` - CategoryId int64 `form:"category_id" binding:"min=0"` - AccountId int64 `form:"account_id" binding:"min=0"` + CategoryIds string `form:"category_ids"` + AccountIds string `form:"account_ids"` AmountFilter string `form:"amount_filter" binding:"validAmountFilter"` Keyword string `form:"keyword"` MaxTime int64 `form:"max_time" binding:"min=0"` @@ -125,8 +125,8 @@ type TransactionListInMonthByPageRequest struct { Year int32 `form:"year" binding:"required,min=1"` Month int32 `form:"month" binding:"required,min=1"` Type TransactionDbType `form:"type" binding:"min=0,max=4"` - CategoryId int64 `form:"category_id" binding:"min=0"` - AccountId int64 `form:"account_id" binding:"min=0"` + CategoryIds string `form:"category_ids"` + AccountIds string `form:"account_ids"` AmountFilter string `form:"amount_filter" binding:"validAmountFilter"` Keyword string `form:"keyword"` TrimAccount bool `form:"trim_account"` diff --git a/pkg/services/accounts.go b/pkg/services/accounts.go index 08959fea..dcb27869 100644 --- a/pkg/services/accounts.go +++ b/pkg/services/accounts.go @@ -1,6 +1,7 @@ package services import ( + "strings" "time" "xorm.io/xorm" @@ -86,6 +87,48 @@ func (s *AccountService) GetSubAccountsByAccountId(c *core.Context, uid int64, a return accounts, err } +// GetSubAccountsByAccountIds returns sub-account models according to account ids +func (s *AccountService) GetSubAccountsByAccountIds(c *core.Context, uid int64, accountIds []int64) ([]*models.Account, error) { + if uid <= 0 { + return nil, errs.ErrUserIdInvalid + } + + if len(accountIds) <= 0 { + return nil, errs.ErrAccountIdInvalid + } + + condition := "uid=? AND deleted=?" + conditionParams := make([]any, 0, len(accountIds)+2) + conditionParams = append(conditionParams, uid) + conditionParams = append(conditionParams, false) + + var accountIdConditions strings.Builder + + for i := 0; i < len(accountIds); i++ { + if accountIds[i] <= 0 { + return nil, errs.ErrAccountIdInvalid + } + + if accountIdConditions.Len() > 0 { + accountIdConditions.WriteString(",") + } + + accountIdConditions.WriteString("?") + conditionParams = append(conditionParams, accountIds[i]) + } + + if accountIdConditions.Len() > 1 { + condition = condition + " AND parent_account_id IN (" + accountIdConditions.String() + ")" + } else { + condition = condition + " AND parent_account_id = " + accountIdConditions.String() + } + + var accounts []*models.Account + err := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...).OrderBy("display_order asc").Find(&accounts) + + return accounts, err +} + // GetAccountsByAccountIds returns account models according to account ids func (s *AccountService) GetAccountsByAccountIds(c *core.Context, uid int64, accountIds []int64) (map[int64]*models.Account, error) { if uid <= 0 { diff --git a/pkg/services/transaction_categories.go b/pkg/services/transaction_categories.go index f493e173..115155dd 100644 --- a/pkg/services/transaction_categories.go +++ b/pkg/services/transaction_categories.go @@ -1,6 +1,7 @@ package services import ( + "strings" "time" "xorm.io/xorm" @@ -68,6 +69,48 @@ func (s *TransactionCategoryService) GetAllCategoriesByUid(c *core.Context, uid return categories, err } +// GetSubCategoriesByCategoryIds returns sub-category models according to category ids +func (s *TransactionCategoryService) GetSubCategoriesByCategoryIds(c *core.Context, uid int64, categoryIds []int64) ([]*models.TransactionCategory, error) { + if uid <= 0 { + return nil, errs.ErrUserIdInvalid + } + + if len(categoryIds) <= 0 { + return nil, errs.ErrTransactionCategoryIdInvalid + } + + condition := "uid=? AND deleted=?" + conditionParams := make([]any, 0, len(categoryIds)+2) + conditionParams = append(conditionParams, uid) + conditionParams = append(conditionParams, false) + + var categoryIdConditions strings.Builder + + for i := 0; i < len(categoryIds); i++ { + if categoryIds[i] <= 0 { + return nil, errs.ErrTransactionCategoryIdInvalid + } + + if categoryIdConditions.Len() > 0 { + categoryIdConditions.WriteString(",") + } + + categoryIdConditions.WriteString("?") + conditionParams = append(conditionParams, categoryIds[i]) + } + + if categoryIdConditions.Len() > 1 { + condition = condition + " AND parent_category_id IN (" + categoryIdConditions.String() + ")" + } else { + condition = condition + " AND parent_category_id = " + categoryIdConditions.String() + } + + var categories []*models.TransactionCategory + err := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...).OrderBy("display_order asc").Find(&categories) + + return categories, err +} + // GetCategoryByCategoryId returns a transaction category model according to transaction category id func (s *TransactionCategoryService) GetCategoryByCategoryId(c *core.Context, uid int64, categoryId int64) (*models.TransactionCategory, error) { if uid <= 0 { diff --git a/src/lib/services.js b/src/lib/services.js index 54445e05..4ae54d8f 100644 --- a/src/lib/services.js +++ b/src/lib/services.js @@ -282,12 +282,12 @@ export default { getTransactions: ({ maxTime, minTime, count, page, withCount, type, categoryId, accountId, amountFilter, keyword }) => { amountFilter = encodeURIComponent(amountFilter); keyword = encodeURIComponent(keyword); - return axios.get(`v1/transactions/list.json?max_time=${maxTime}&min_time=${minTime}&type=${type}&category_id=${categoryId}&account_id=${accountId}&amount_filter=${amountFilter}&keyword=${keyword}&count=${count}&page=${page}&with_count=${withCount}&trim_account=true&trim_category=true&trim_tag=true`); + return axios.get(`v1/transactions/list.json?max_time=${maxTime}&min_time=${minTime}&type=${type}&category_ids=${categoryId}&account_ids=${accountId}&amount_filter=${amountFilter}&keyword=${keyword}&count=${count}&page=${page}&with_count=${withCount}&trim_account=true&trim_category=true&trim_tag=true`); }, getAllTransactionsByMonth: ({ year, month, type, categoryId, accountId, amountFilter, keyword }) => { amountFilter = encodeURIComponent(amountFilter); keyword = encodeURIComponent(keyword); - return axios.get(`v1/transactions/list/by_month.json?year=${year}&month=${month}&type=${type}&category_id=${categoryId}&account_id=${accountId}&amount_filter=${amountFilter}&keyword=${keyword}&trim_account=true&trim_category=true&trim_tag=true`); + return axios.get(`v1/transactions/list/by_month.json?year=${year}&month=${month}&type=${type}&category_ids=${categoryId}&account_ids=${accountId}&amount_filter=${amountFilter}&keyword=${keyword}&trim_account=true&trim_category=true&trim_tag=true`); }, getTransactionStatistics: ({ startTime, endTime, useTransactionTimezone }) => { const queryParams = [];