diff --git a/internal/datafs/stdinfs.go b/internal/datafs/stdinfs.go index 46cb030a8..f6a02c512 100644 --- a/internal/datafs/stdinfs.go +++ b/internal/datafs/stdinfs.go @@ -18,7 +18,8 @@ func NewStdinFS(_ *url.URL) (fs.FS, error) { } type stdinFS struct { - ctx context.Context + ctx context.Context + data []byte } //nolint:gochecknoglobals @@ -46,9 +47,15 @@ func (f *stdinFS) Open(name string) (fs.File, error) { } } - stdin := StdinFromContext(f.ctx) + if err := f.readData(); err != nil { + return nil, &fs.PathError{ + Op: "open", + Path: name, + Err: err, + } + } - return &stdinFile{name: name, body: stdin}, nil + return &stdinFile{name: name, body: bytes.NewReader(f.data)}, nil } func (f *stdinFS) ReadFile(name string) ([]byte, error) { @@ -60,9 +67,32 @@ func (f *stdinFS) ReadFile(name string) ([]byte, error) { } } + if err := f.readData(); err != nil { + return nil, &fs.PathError{ + Op: "readFile", + Path: name, + Err: err, + } + } + + return f.data, nil +} + +func (f *stdinFS) readData() error { + if f.data != nil { + return nil + } + stdin := StdinFromContext(f.ctx) - return io.ReadAll(stdin) + b, err := io.ReadAll(stdin) + if err != nil { + return err + } + + f.data = b + + return nil } type stdinFile struct { diff --git a/internal/datafs/stdinfs_test.go b/internal/datafs/stdinfs_test.go index f8c30a066..f5010a470 100644 --- a/internal/datafs/stdinfs_test.go +++ b/internal/datafs/stdinfs_test.go @@ -99,6 +99,57 @@ func TestStdinFS(t *testing.T) { _, err = f.Read(p) require.Error(t, err) require.ErrorIs(t, err, io.EOF) + + t.Run("open/read multiple times", func(t *testing.T) { + ctx := ContextWithStdin(context.Background(), bytes.NewReader(content)) + fsys = fsimpl.WithContextFS(ctx, fsys) + + for i := 0; i < 3; i++ { + f, err := fsys.Open("foo") + require.NoError(t, err) + + b, err := io.ReadAll(f) + require.NoError(t, err) + require.Equal(t, content, b, "read %d failed", i) + } + }) + + t.Run("readFile multiple times", func(t *testing.T) { + ctx := ContextWithStdin(context.Background(), bytes.NewReader(content)) + fsys = fsimpl.WithContextFS(ctx, fsys) + + for i := 0; i < 3; i++ { + b, err := fs.ReadFile(fsys, "foo") + require.NoError(t, err) + require.Equal(t, content, b, "read %d failed", i) + } + }) + + t.Run("open errors", func(t *testing.T) { + ctx := ContextWithStdin(context.Background(), &errorReader{err: fs.ErrPermission}) + + fsys, err := NewStdinFS(u) + require.NoError(t, err) + assert.IsType(t, &stdinFS{}, fsys) + + fsys = fsimpl.WithContextFS(ctx, fsys) + + _, err = fsys.Open("foo") + require.ErrorIs(t, err, fs.ErrPermission) + }) + + t.Run("readFile errors", func(t *testing.T) { + ctx := ContextWithStdin(context.Background(), &errorReader{err: fs.ErrPermission}) + + fsys, err := NewStdinFS(u) + require.NoError(t, err) + assert.IsType(t, &stdinFS{}, fsys) + + fsys = fsimpl.WithContextFS(ctx, fsys) + + _, err = fs.ReadFile(fsys, "foo") + require.ErrorIs(t, err, fs.ErrPermission) + }) } type errorReader struct {