Skip to content

Commit

Permalink
test(errdefs): use more helpers, increase test coverage (#358)
Browse files Browse the repository at this point in the history
Signed-off-by: Gyuho Lee <[email protected]>

---------

Signed-off-by: Gyuho Lee <[email protected]>
  • Loading branch information
gyuho authored Feb 11, 2025
1 parent b573723 commit 1b3c8f8
Show file tree
Hide file tree
Showing 8 changed files with 398 additions and 50 deletions.
3 changes: 1 addition & 2 deletions client/v1/examples/get-states/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"context"
"errors"
"time"

client_v1 "github.com/leptonai/gpud/client/v1"
Expand All @@ -17,7 +16,7 @@ func main() {
defer cancel()
states, err := client_v1.GetStates(ctx, baseURL, client_v1.WithComponent(componentName))
if err != nil {
if errors.Is(err, errdefs.ErrNotFound) {
if errdefs.IsNotFound(err) {
log.Logger.Warnw("component not found", "component", componentName)
return
}
Expand Down
46 changes: 46 additions & 0 deletions client/v1/package_status.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package v1

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/leptonai/gpud/manager/packages"
)

// GetPackageStatus fetches the GPUd package status from the GPUd admin API.
func GetPackageStatus(ctx context.Context, url string, opts ...OpOption) ([]packages.PackageStatus, error) {
op := &Op{}
if err := op.applyOpts(opts); err != nil {
return nil, err
}

httpClient := op.httpClient
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}

resp, err := httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code %v received", resp.StatusCode)
}

rawBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}

var ret []packages.PackageStatus
if err := json.Unmarshal(rawBody, &ret); err != nil {
return nil, err
}
return ret, nil
}
116 changes: 116 additions & 0 deletions client/v1/package_status_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package v1

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/leptonai/gpud/manager/packages"
)

func TestGetStatus(t *testing.T) {
tests := []struct {
name string
serverResponse func(w http.ResponseWriter, r *http.Request)
wantErr bool
expectedData []packages.PackageStatus
}{
{
name: "successful response",
serverResponse: func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)

response := []packages.PackageStatus{
{
Name: "test-package",
IsInstalled: true,
Installing: false,
Progress: 100,
TotalTime: 1 * time.Hour,
Status: true,
TargetVersion: "1.0.0",
CurrentVersion: "1.0.0",
ScriptPath: "/path/to/script",
Dependency: [][]string{{"dep1", "1.0.0"}},
},
}
assert.NoError(t, json.NewEncoder(w).Encode(response))
},
wantErr: false,
expectedData: []packages.PackageStatus{
{
Name: "test-package",
IsInstalled: true,
Installing: false,
Progress: 100,
TotalTime: 1 * time.Hour,
Status: true,
TargetVersion: "1.0.0",
CurrentVersion: "1.0.0",
ScriptPath: "/path/to/script",
Dependency: [][]string{{"dep1", "1.0.0"}},
},
},
},
{
name: "server returns error status",
serverResponse: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
},
wantErr: true,
expectedData: nil,
},
{
name: "invalid JSON response",
serverResponse: func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("invalid json"))
assert.NoError(t, err)
},
wantErr: true,
expectedData: nil,
},
{
name: "context canceled",
serverResponse: func(w http.ResponseWriter, r *http.Request) {
// Simulate a slow response that will be canceled
time.Sleep(100 * time.Millisecond)
assert.NoError(t, json.NewEncoder(w).Encode([]packages.PackageStatus{}))
},
wantErr: true,
expectedData: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create test server
server := httptest.NewTLSServer(http.HandlerFunc(tt.serverResponse))
defer server.Close()

var ctx context.Context
var cancel context.CancelFunc
if tt.name == "context canceled" {
ctx, cancel = context.WithTimeout(context.Background(), 50*time.Millisecond)
} else {
ctx, cancel = context.WithTimeout(context.Background(), time.Second)
}
defer cancel()

// Call GetStatus with the test server's URL
status, err := GetPackageStatus(ctx, server.URL)

if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, status)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedData, status)
}
})
}
}
45 changes: 7 additions & 38 deletions cmd/gpud/command/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,18 @@ package command

import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"

"github.com/urfave/cli"

client "github.com/leptonai/gpud/client/v1"
"github.com/leptonai/gpud/config"
"github.com/leptonai/gpud/errdefs"
"github.com/leptonai/gpud/internal/server"
"github.com/leptonai/gpud/log"
"github.com/leptonai/gpud/manager/packages"
"github.com/leptonai/gpud/pkg/systemd"

"github.com/urfave/cli"
)

func cmdStatus(cliContext *cli.Context) error {
Expand Down Expand Up @@ -56,7 +52,9 @@ func cmdStatus(cliContext *cli.Context) error {
fmt.Printf("%s successfully checked gpud health\n", checkMark)

for {
packageStatus, err := getStatus()
cctx, ccancel := context.WithTimeout(rootCtx, 15*time.Second)
packageStatus, err := client.GetPackageStatus(cctx, fmt.Sprintf("https://localhost:%d%s", config.DefaultGPUdPort, server.URLPathAdminPackages))
ccancel()
if err != nil {
fmt.Printf("%s failed to get package status: %v\n", warningSign, err)
return err
Expand Down Expand Up @@ -122,7 +120,7 @@ func checkNvidiaInfoComponent() error {
defer cancel()
states, err := client.GetStates(ctx, baseURL, client.WithComponent(componentName))
if err != nil {
if errors.Is(err, errdefs.ErrNotFound) {
if errdefs.IsNotFound(err) {
log.Logger.Warnw("component not found", "component", componentName)
return nil
}
Expand All @@ -141,32 +139,3 @@ func checkNvidiaInfoComponent() error {

return nil
}

func getStatus() ([]packages.PackageStatus, error) {
httpClient := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
req, err := http.NewRequest("GET", fmt.Sprintf("https://localhost:%d/admin/packages", config.DefaultGPUdPort), nil)
if err != nil {
return nil, err
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code %v received", resp.StatusCode)
}
rawBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var ret []packages.PackageStatus
if err := json.Unmarshal(rawBody, &ret); err != nil {
return nil, err
}
return ret, nil
}
Loading

0 comments on commit 1b3c8f8

Please sign in to comment.