diff --git a/api/disabled_parameters.go b/api/disabled_parameters.go new file mode 100644 index 000000000..3a2d0c4c0 --- /dev/null +++ b/api/disabled_parameters.go @@ -0,0 +1,149 @@ +package api + +import ( + "fmt" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/labstack/echo/v4" +) + +// EndpointConfig is a data structure that contains whether the +// endpoint is disabled (with a boolean) as well as a set that +// contains disabled optional parameters. The disabled optional parameter +// set is keyed by the name of the variable +type EndpointConfig struct { + EndpointDisabled bool + DisabledOptionalParameters map[string]bool +} + +// NewEndpointConfig creates a new empty endpoint config +func NewEndpointConfig() *EndpointConfig { + rval := &EndpointConfig{ + EndpointDisabled: false, + DisabledOptionalParameters: make(map[string]bool), + } + + return rval +} + +// DisabledMap a type that holds a map of disabled types +// The key for a disabled map is the handler function name +type DisabledMap struct { + Data map[string]*EndpointConfig +} + +// NewDisabledMap creates a new empty disabled map +func NewDisabledMap() *DisabledMap { + return &DisabledMap{ + Data: make(map[string]*EndpointConfig), + } +} + +// NewDisabledMapFromOA3 Creates a new disabled map from an openapi3 definition +func NewDisabledMapFromOA3(swag *openapi3.Swagger) *DisabledMap { + rval := NewDisabledMap() + for _, item := range swag.Paths { + for _, opItem := range item.Operations() { + + endpointConfig := NewEndpointConfig() + + for _, pref := range opItem.Parameters { + + // TODO how to enable it to be disabled + parameterIsDisabled := false + if !parameterIsDisabled { + // If the parameter is not disabled, then we don't need + // to do anything + continue + } + + if pref.Value.Required { + // If an endpoint config required parameter is disabled, then the whole endpoint is disabled + endpointConfig.EndpointDisabled = true + } else { + // If the optional parameter is disabled, add it to the map + endpointConfig.DisabledOptionalParameters[pref.Value.Name] = true + } + } + + rval.Data[opItem.OperationID] = endpointConfig + + } + + } + + return rval +} + +// ErrVerifyFailedEndpoint an error that signifies that the entire endpoint is disabled +var ErrVerifyFailedEndpoint error = fmt.Errorf("endpoint is disabled") + +// ErrVerifyFailedParameter an error that signifies that a parameter was provided when it was disabled +type ErrVerifyFailedParameter struct { + ParameterName string +} + +func (evfp ErrVerifyFailedParameter) Error() string { + return fmt.Sprintf("provided disabled parameter: %s", evfp.ParameterName) +} + +// DisabledParameterErrorReporter defines an error reporting interface +// for the Verify functions +type DisabledParameterErrorReporter interface { + Errorf(format string, args ...interface{}) +} + +// Verify returns nil if the function can continue (i.e. the parameters are valid and disabled +// parameters are not supplied), otherwise VerifyFailedEndpoint if the endpoint failed and +// VerifyFailedParameter if a disabled parameter was provided. +func Verify(dm *DisabledMap, nameOfHandlerFunc string, ctx echo.Context, log DisabledParameterErrorReporter) error { + + if dm == nil || dm.Data == nil { + return nil + } + + if val, ok := dm.Data[nameOfHandlerFunc]; ok { + return val.verify(ctx, log) + } + + // If the function name wasn't in the map something got messed up.... + log.Errorf("verify function could not find name of handler function in map: %s", nameOfHandlerFunc) + // We want to fail-safe to not stop the indexer + return nil +} + +func (ec *EndpointConfig) verify(ctx echo.Context, log DisabledParameterErrorReporter) error { + + if ec.EndpointDisabled { + return ErrVerifyFailedEndpoint + } + + queryParams := ctx.QueryParams() + formParams, formErr := ctx.FormParams() + + if formErr != nil { + log.Errorf("retrieving form parameters for verification resulted in an error: %v", formErr) + } + + for paramName := range ec.DisabledOptionalParameters { + + // The optional param is disabled, check that it wasn't supplied... + queryValue := queryParams.Get(paramName) + if queryValue != "" { + // If the query value is non-zero, and it was disabled, we should return false + return ErrVerifyFailedParameter{paramName} + } + + if formErr != nil { + continue + } + + formValue := formParams.Get(paramName) + if formValue != "" { + // If the query value is non-zero, and it was disabled, we should return false + return ErrVerifyFailedParameter{paramName} + } + } + + return nil +} diff --git a/api/disabled_parameters_test.go b/api/disabled_parameters_test.go new file mode 100644 index 000000000..2e76391e8 --- /dev/null +++ b/api/disabled_parameters_test.go @@ -0,0 +1,163 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/labstack/echo/v4" + "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +// TestFailingParam tests that disabled parameters provided via +// the FormParams() and QueryParams() functions of the context are appropriately handled +func TestFailingParam(t *testing.T) { + type testingStruct struct { + name string + setFormValues func(*url.Values) + expectedError error + expectedErrorCount int + mimeType string + } + tests := []testingStruct{ + { + "non-disabled param provided", + func(f *url.Values) { + f.Set("3", "Provided") + }, nil, 0, echo.MIMEApplicationForm, + }, + { + "disabled param provided but empty", + func(f *url.Values) { + f.Set("1", "") + }, nil, 0, echo.MIMEApplicationForm, + }, + { + "disabled param provided", + func(f *url.Values) { + f.Set("1", "Provided") + }, ErrVerifyFailedParameter{"1"}, 0, echo.MIMEApplicationForm, + }, + } + + testsPostOnly := []testingStruct{ + { + "Error encountered for Form Params", + func(f *url.Values) { + f.Set("1", "Provided") + }, nil, 1, echo.MIMEMultipartForm, + }, + } + + ctxFactoryGet := func(e *echo.Echo, f *url.Values, t *testingStruct) *echo.Context { + req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil) + ctx := e.NewContext(req, nil) + return &ctx + } + + ctxFactoryPost := func(e *echo.Echo, f *url.Values, t *testingStruct) *echo.Context { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) + req.Header.Add(echo.HeaderContentType, t.mimeType) + ctx := e.NewContext(req, nil) + return &ctx + } + + runner := func(t *testing.T, tstruct *testingStruct, ctxFactory func(*echo.Echo, *url.Values, *testingStruct) *echo.Context) { + dm := NewDisabledMap() + e1 := NewEndpointConfig() + e1.EndpointDisabled = false + e1.DisabledOptionalParameters["1"] = true + + dm.Data["K1"] = e1 + + e := echo.New() + + f := make(url.Values) + tstruct.setFormValues(&f) + + ctx := ctxFactory(e, &f, tstruct) + + logger, hook := test.NewNullLogger() + + err := Verify(dm, "K1", *ctx, logger) + + require.Equal(t, tstruct.expectedError, err) + require.Len(t, hook.AllEntries(), tstruct.expectedErrorCount) + } + + for _, test := range tests { + t.Run("Post-"+test.name, func(t *testing.T) { + runner(t, &test, ctxFactoryPost) + }) + + t.Run("Get-"+test.name, func(t *testing.T) { + runner(t, &test, ctxFactoryGet) + }) + + } + + for _, test := range testsPostOnly { + t.Run("Post-"+test.name, func(t *testing.T) { + runner(t, &test, ctxFactoryPost) + }) + + } +} + +// TestFailingEndpoint tests that an endpoint which has a disabled required parameter +// returns a failed endpoint error +func TestFailingEndpoint(t *testing.T) { + dm := NewDisabledMap() + + e1 := NewEndpointConfig() + e1.EndpointDisabled = true + e1.DisabledOptionalParameters["1"] = true + + dm.Data["K1"] = e1 + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/?", nil) + ctx := e.NewContext(req, nil) + + logger, hook := test.NewNullLogger() + + err := Verify(dm, "K1", ctx, logger) + + require.Equal(t, ErrVerifyFailedEndpoint, err) + + require.Len(t, hook.AllEntries(), 0) +} + +// TestVerifyNonExistentHandler tests that nonexistent endpoint is logged +// but doesn't stop the indexer from functioning +func TestVerifyNonExistentHandler(t *testing.T) { + dm := NewDisabledMap() + + e1 := NewEndpointConfig() + e1.EndpointDisabled = false + e1.DisabledOptionalParameters["1"] = true + + dm.Data["K1"] = e1 + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/?", nil) + ctx := e.NewContext(req, nil) + + logger, hook := test.NewNullLogger() + + err := Verify(dm, "DoesntExist", ctx, logger) + + require.Equal(t, nil, err) + require.Len(t, hook.AllEntries(), 1) + + hook.Reset() + + err = Verify(dm, "K1", ctx, logger) + + require.Equal(t, nil, err) + + require.Len(t, hook.AllEntries(), 0) +} diff --git a/api/handlers.go b/api/handlers.go index 0bd7795b2..0137e68e0 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -38,6 +38,8 @@ type ServerImplementation struct { timeout time.Duration log *log.Logger + + disabledParams *DisabledMap } ///////////////////// @@ -145,9 +147,17 @@ func (si *ServerImplementation) MakeHealthCheck(ctx echo.Context) error { }) } +func (si *ServerImplementation) verifyHandler(operationID string, ctx echo.Context) error { + return Verify(si.disabledParams, operationID, ctx, si.log) +} + // LookupAccountByID queries indexer for a given account. // (GET /v2/accounts/{account-id}) func (si *ServerImplementation) LookupAccountByID(ctx echo.Context, accountID string, params generated.LookupAccountByIDParams) error { + if err := si.verifyHandler("LookupAccountByID", ctx); err != nil { + return badRequest(ctx, err.Error()) + } + addr, errors := decodeAddress(&accountID, "account-id", make([]string, 0)) if len(errors) != 0 { return badRequest(ctx, errors[0]) @@ -183,6 +193,10 @@ func (si *ServerImplementation) LookupAccountByID(ctx echo.Context, accountID st // SearchForAccounts returns accounts matching the provided parameters // (GET /v2/accounts) func (si *ServerImplementation) SearchForAccounts(ctx echo.Context, params generated.SearchForAccountsParams) error { + if err := si.verifyHandler("SearchForAccounts", ctx); err != nil { + return badRequest(ctx, err.Error()) + } + if !si.EnableAddressSearchRoundRewind && params.Round != nil { return badRequest(ctx, errMultiAcctRewind) } @@ -242,6 +256,9 @@ func (si *ServerImplementation) SearchForAccounts(ctx echo.Context, params gener // LookupAccountTransactions looks up transactions associated with a particular account. // (GET /v2/accounts/{account-id}/transactions) func (si *ServerImplementation) LookupAccountTransactions(ctx echo.Context, accountID string, params generated.LookupAccountTransactionsParams) error { + if err := si.verifyHandler("LookupAccountTransactions", ctx); err != nil { + return badRequest(ctx, err.Error()) + } // Check that a valid account was provided _, errors := decodeAddress(strPtr(accountID), "account-id", make([]string, 0)) if len(errors) != 0 { @@ -277,6 +294,9 @@ func (si *ServerImplementation) LookupAccountTransactions(ctx echo.Context, acco // SearchForApplications returns applications for the provided parameters. // (GET /v2/applications) func (si *ServerImplementation) SearchForApplications(ctx echo.Context, params generated.SearchForApplicationsParams) error { + if err := si.verifyHandler("SearchForApplications", ctx); err != nil { + return badRequest(ctx, err.Error()) + } apps, round, err := si.fetchApplications(ctx.Request().Context(), params) if err != nil { return indexerError(ctx, fmt.Errorf("%s: %w", errFailedSearchingApplication, err)) @@ -298,6 +318,9 @@ func (si *ServerImplementation) SearchForApplications(ctx echo.Context, params g // LookupApplicationByID returns one application for the requested ID. // (GET /v2/applications/{application-id}) func (si *ServerImplementation) LookupApplicationByID(ctx echo.Context, applicationID uint64, params generated.LookupApplicationByIDParams) error { + if err := si.verifyHandler("LookupApplicationByID", ctx); err != nil { + return badRequest(ctx, err.Error()) + } p := generated.SearchForApplicationsParams{ ApplicationId: &applicationID, IncludeAll: params.IncludeAll, @@ -326,6 +349,10 @@ func (si *ServerImplementation) LookupApplicationByID(ctx echo.Context, applicat // LookupApplicationLogsByID returns one application logs // (GET /v2/applications/{application-id}/logs) func (si *ServerImplementation) LookupApplicationLogsByID(ctx echo.Context, applicationID uint64, params generated.LookupApplicationLogsByIDParams) error { + if err := si.verifyHandler("LookupApplicationLogsByID", ctx); err != nil { + return badRequest(ctx, err.Error()) + } + searchParams := generated.SearchForTransactionsParams{ AssetId: nil, ApplicationId: uint64Ptr(applicationID), @@ -385,6 +412,10 @@ func (si *ServerImplementation) LookupApplicationLogsByID(ctx echo.Context, appl // LookupAssetByID looks up a particular asset // (GET /v2/assets/{asset-id}) func (si *ServerImplementation) LookupAssetByID(ctx echo.Context, assetID uint64, params generated.LookupAssetByIDParams) error { + if err := si.verifyHandler("LookupAssetByID", ctx); err != nil { + return badRequest(ctx, err.Error()) + } + search := generated.SearchForAssetsParams{ AssetId: uint64Ptr(assetID), Limit: uint64Ptr(1), @@ -417,6 +448,10 @@ func (si *ServerImplementation) LookupAssetByID(ctx echo.Context, assetID uint64 // LookupAssetBalances looks up balances for a particular asset // (GET /v2/assets/{asset-id}/balances) func (si *ServerImplementation) LookupAssetBalances(ctx echo.Context, assetID uint64, params generated.LookupAssetBalancesParams) error { + if err := si.verifyHandler("LookupAssetBalances", ctx); err != nil { + return badRequest(ctx, err.Error()) + } + query := idb.AssetBalanceQuery{ AssetID: assetID, AmountGT: params.CurrencyGreaterThan, @@ -453,6 +488,10 @@ func (si *ServerImplementation) LookupAssetBalances(ctx echo.Context, assetID ui // LookupAssetTransactions looks up transactions associated with a particular asset // (GET /v2/assets/{asset-id}/transactions) func (si *ServerImplementation) LookupAssetTransactions(ctx echo.Context, assetID uint64, params generated.LookupAssetTransactionsParams) error { + if err := si.verifyHandler("LookupAssetTransactions", ctx); err != nil { + return badRequest(ctx, err.Error()) + } + searchParams := generated.SearchForTransactionsParams{ AssetId: uint64Ptr(assetID), ApplicationId: nil, @@ -481,6 +520,10 @@ func (si *ServerImplementation) LookupAssetTransactions(ctx echo.Context, assetI // SearchForAssets returns assets matching the provided parameters // (GET /v2/assets) func (si *ServerImplementation) SearchForAssets(ctx echo.Context, params generated.SearchForAssetsParams) error { + if err := si.verifyHandler("SearchForAssets", ctx); err != nil { + return badRequest(ctx, err.Error()) + } + options, err := assetParamsToAssetQuery(params) if err != nil { return badRequest(ctx, err.Error()) @@ -506,6 +549,10 @@ func (si *ServerImplementation) SearchForAssets(ctx echo.Context, params generat // LookupBlock returns the block for a given round number // (GET /v2/blocks/{round-number}) func (si *ServerImplementation) LookupBlock(ctx echo.Context, roundNumber uint64) error { + if err := si.verifyHandler("LookupBlock", ctx); err != nil { + return badRequest(ctx, err.Error()) + } + blk, err := si.fetchBlock(ctx.Request().Context(), roundNumber) if errors.Is(err, idb.ErrorBlockNotFound) { return notFound(ctx, fmt.Sprintf("%s '%d': %v", errLookingUpBlockForRound, roundNumber, err)) @@ -519,6 +566,10 @@ func (si *ServerImplementation) LookupBlock(ctx echo.Context, roundNumber uint64 // LookupTransaction searches for the requested transaction ID. func (si *ServerImplementation) LookupTransaction(ctx echo.Context, txid string) error { + if err := si.verifyHandler("LookupTransaction", ctx); err != nil { + return badRequest(ctx, err.Error()) + } + filter, err := transactionParamsToTransactionFilter(generated.SearchForTransactionsParams{ Txid: strPtr(txid), }) @@ -556,6 +607,10 @@ func (si *ServerImplementation) LookupTransaction(ctx echo.Context, txid string) // SearchForTransactions returns transactions matching the provided parameters // (GET /v2/transactions) func (si *ServerImplementation) SearchForTransactions(ctx echo.Context, params generated.SearchForTransactionsParams) error { + if err := si.verifyHandler("SearchForTransactions", ctx); err != nil { + return badRequest(ctx, err.Error()) + } + filter, err := transactionParamsToTransactionFilter(params) if err != nil { return badRequest(ctx, err.Error()) diff --git a/api/server.go b/api/server.go index 286054c85..8fcef2b0a 100644 --- a/api/server.go +++ b/api/server.go @@ -77,12 +77,19 @@ func Serve(ctx context.Context, serveAddr string, db idb.IndexerDb, fetcherError middleware = append(middleware, middlewares.MakeAuth("X-Indexer-API-Token", options.Tokens)) } + swag, err := generated.GetSwagger() + + if err != nil { + log.Fatal(err) + } + api := ServerImplementation{ EnableAddressSearchRoundRewind: options.DeveloperMode, db: db, fetcher: fetcherError, timeout: options.handlerTimeout(), log: log, + disabledParams: NewDisabledMapFromOA3(swag), } generated.RegisterHandlers(e, &api, middleware...)