Skip to content

Commit

Permalink
Add possibility to transmit directories
Browse files Browse the repository at this point in the history
  • Loading branch information
dennis-tra committed Mar 21, 2021
1 parent 256b1f0 commit ecf1363
Show file tree
Hide file tree
Showing 22 changed files with 516 additions and 205 deletions.
33 changes: 15 additions & 18 deletions pkg/crypt/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ import (
// StreamEncrypter implements the Reader interface to be used in
// streaming scenarios.
type StreamEncrypter struct {
src io.Reader
dest io.Writer
block cipher.Block
stream cipher.Stream
mac hash.Hash
iv []byte
}

// NewStreamEncrypter initializes a stream encrypter.
func NewStreamEncrypter(key []byte, src io.Reader) (*StreamEncrypter, error) {
func NewStreamEncrypter(key []byte, dest io.Writer) (*StreamEncrypter, error) {
// Create a new AES cipher
block, err := aes.NewCipher(key)
if err != nil {
Expand All @@ -39,14 +39,26 @@ func NewStreamEncrypter(key []byte, src io.Reader) (*StreamEncrypter, error) {
}

return &StreamEncrypter{
src: src,
dest: dest,
block: block,
stream: cipher.NewCTR(block, iv),
mac: hmac.New(sha256.New, key),
iv: iv,
}, nil
}

// Write writes bytes encrypted to the writer interface.
func (s *StreamEncrypter) Write(p []byte) (int, error) {
buf := make([]byte, len(p)) // Could we get rid of this allocation?
s.stream.XORKeyStream(buf, p)
n, writeErr := s.dest.Write(buf)
if err := writeHash(s.mac, buf[:n]); err != nil {
return n, err
}

return n, writeErr
}

func (s *StreamEncrypter) InitializationVector() []byte {
return s.iv
}
Expand All @@ -57,21 +69,6 @@ func (s *StreamEncrypter) Hash() []byte {
return s.mac.Sum(nil)
}

// Read reads bytes from the underlying reader and encrypts them.
func (s *StreamEncrypter) Read(p []byte) (int, error) {
n, readErr := s.src.Read(p)
if n == 0 {
return 0, io.EOF
}

s.stream.XORKeyStream(p[:n], p[:n])
if err := writeHash(s.mac, p[:n]); err != nil {
return n, err
}

return n, readErr
}

// StreamDecrypter is a decrypter for a stream of data with authentication
type StreamDecrypter struct {
src io.Reader
Expand Down
12 changes: 5 additions & 7 deletions pkg/crypt/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,15 @@ func TestStreamEncrypterDecrypter(t *testing.T) {
key, err := DeriveKey(pw, salt)
assert.Nil(t, err)

payload := []byte("some text")
src := bytes.NewReader(payload)
se, err := NewStreamEncrypter(key, src)
var buf bytes.Buffer
se, err := NewStreamEncrypter(key, &buf)
assert.Nil(t, err)
assert.NotNil(t, se)

encrypted, err := ioutil.ReadAll(se)
assert.Nil(t, err)
assert.NotNil(t, encrypted)
payload := []byte("some text")
_, _ = se.Write(payload)

sd, err := NewStreamDecrypter(key, se.InitializationVector(), bytes.NewReader(encrypted))
sd, err := NewStreamDecrypter(key, se.InitializationVector(), &buf)
assert.Nil(t, err)
assert.NotNil(t, sd)

Expand Down
26 changes: 7 additions & 19 deletions pkg/node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ func (n *Node) Read(s network.Stream, buf p2p.HeaderMessage) error {
return err
}

log.Debugf("Reading message from %s", s.Conn().RemotePeer().String())
log.Debugf("Reading message from %s\n", s.Conn().RemotePeer().String())
// Decrypt the data with the PAKE session key if it is found
sKey, found := n.GetSessionKey(s.Conn().RemotePeer())
if found {
Expand Down Expand Up @@ -278,34 +278,22 @@ func (n *Node) Read(s network.Stream, buf p2p.HeaderMessage) error {
// WriteBytes writes the given bytes to the destination writer and
// prefixes it with a uvarint indicating the length of the data.
func (n *Node) WriteBytes(w io.Writer, data []byte) (int, error) {
hdr := varint.ToUvarint(uint64(len(data)))
nhdr, err := w.Write(hdr)
if err != nil {
return nhdr, err
}

ndata, err := w.Write(data)
if err != nil {
return ndata, err
}

return nhdr + ndata, nil
size := varint.ToUvarint(uint64(len(data)))
return w.Write(append(size, data...))
}

// ReadBytes reads an uvarint from the source reader to know how
// much data is following.
func (n *Node) ReadBytes(r io.Reader) ([]byte, error) {
l, err := varint.ReadUvarint(bufio.NewReader(r))
br := bufio.NewReader(r) // init byte reader
l, err := varint.ReadUvarint(br)
if err != nil {
return nil, err
}

buf := make([]byte, l)
if _, err = r.Read(buf); err != nil {
return nil, err
}

return buf, nil
_, err = br.Read(buf)
return buf, err
}

// ResetOnShutdown resets the given stream if the node receives a shutdown
Expand Down
2 changes: 1 addition & 1 deletion pkg/node/pake.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (p *PakeProtocol) AddAuthenticatedPeer(peerID peer.ID, key []byte) {
// passed a password authenticated key exchange.
func (p *PakeProtocol) IsAuthenticated(peerID peer.ID) bool {
_, found := p.authedPeers.Load(peerID)
log.Debugf("Is peer %s authenticated: %s\n", peerID, found)
log.Debugf("Is peer %s authenticated: %v\n", peerID, found)
return found
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/node/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (p *PushProtocol) onPushRequest(s network.Stream) {
log.Infoln(err)
return
}
log.Debugln("Received push request", req.Filename, req.Size)
log.Debugln("Received push request", req.Name, req.Size)

p.lk.RLock()
defer p.lk.RUnlock()
Expand All @@ -81,15 +81,15 @@ func (p *PushProtocol) onPushRequest(s network.Stream) {
}
}

func (p *PushProtocol) SendPushRequest(ctx context.Context, peerID peer.ID, filename string, size int64) (bool, error) {
func (p *PushProtocol) SendPushRequest(ctx context.Context, peerID peer.ID, filename string, size int64, isDir bool) (bool, error) {
s, err := p.node.NewStream(ctx, peerID, ProtocolPushRequest)
if err != nil {
return false, err
}
defer s.Close()

log.Debugln("Sending push request", filename, size)
if err = p.node.Send(s, p2p.NewPushRequest(filename, size)); err != nil {
if err = p.node.Send(s, p2p.NewPushRequest(filename, size, isDir)); err != nil {
return false, err
}

Expand Down
108 changes: 93 additions & 15 deletions pkg/node/transfer.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
package node

import (
"archive/tar"
"context"
"fmt"
"io"
"os"
"path/filepath"
"sync"

"github.com/dennis-tra/pcp/pkg/crypt"
"github.com/pkg/errors"

"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"

"github.com/dennis-tra/pcp/internal/log"
"github.com/dennis-tra/pcp/pkg/crypt"
"github.com/dennis-tra/pcp/pkg/progress"
)

// pattern: /protocol-name/request-or-response-message/version
const (
ProtocolTransfer = "/pcp/transfer/0.1.0"
ProtocolTransfer = "/pcp/transfer/0.2.0"
)

// TransferProtocol encapsulates data necessary to fulfill its protocol.
Expand All @@ -27,8 +32,8 @@ type TransferProtocol struct {
}

type TransferHandler interface {
HandleTransfer(r io.Reader)
GetLimit() int64
HandleFile(*tar.Header, io.Reader)
Done()
}

func (t *TransferProtocol) RegisterTransferHandler(th TransferHandler) {
Expand Down Expand Up @@ -60,14 +65,16 @@ func (t *TransferProtocol) onTransfer(s network.Stream) {
// Get PAKE session key for stream decryption
sKey, found := t.node.GetSessionKey(s.Conn().RemotePeer())
if !found {
log.Warningln("Received transfer from unauthenticated peer")
log.Warningln("Received transfer from unauthenticated peer:", s.Conn().RemotePeer())
s.Reset() // Tell peer to go away
return
}

// Read initialization vector from stream. This is sent first from our peer.
iv, err := t.node.ReadBytes(s)
if err != nil {
log.Warningln("Could not read stream initialization vector", err)
s.Reset() // Stream is probably broken anyways
return
}

Expand All @@ -79,17 +86,26 @@ func (t *TransferProtocol) onTransfer(s network.Stream) {
t.lk.RUnlock()
}()

// Only read as much as we expect to avoid stuffing.
lr := io.LimitReader(s, t.th.GetLimit())

// Decrypt the stream
sd, err := crypt.NewStreamDecrypter(sKey, iv, lr)
sd, err := crypt.NewStreamDecrypter(sKey, iv, s)
if err != nil {
log.Warningln("Could not instantiate stream decrypter", err)
return
}

t.th.HandleTransfer(sd)
// Drain tar archive
tr := tar.NewReader(sd)
for {
hdr, err := tr.Next()
if err == io.EOF {
break // End of archive
} else if err != nil {
log.Warningln("Error reading next tar element", err)
return
}
t.th.HandleFile(hdr, tr)
}
defer t.th.Done()

// Read file hash from the stream and check if it matches
hash, err := t.node.ReadBytes(s)
Expand All @@ -108,23 +124,29 @@ func (t *TransferProtocol) onTransfer(s network.Stream) {
// Transfer can be called to transfer the given payload to the given peer. The PushRequest is used for displaying
// the progress to the user. This function returns when the bytes where transmitted and we have received an
// acknowledgment.
func (t *TransferProtocol) Transfer(ctx context.Context, peerID peer.ID, progress io.Writer, src io.Reader) error {
func (t *TransferProtocol) Transfer(ctx context.Context, peerID peer.ID, basePath string) error {
// Open a new stream to our peer.
s, err := t.node.NewStream(ctx, peerID, ProtocolTransfer)
if err != nil {
return err
}

defer s.Close()
defer t.node.ResetOnShutdown(s)()

base, err := os.Stat(basePath)
if err != nil {
return err
}

// Get PAKE session key for stream encryption
sKey, found := t.node.GetSessionKey(peerID)
if !found {
return fmt.Errorf("session key not found to encrypt data transfer")
}

// Initialize new stream encrypter
se, err := crypt.NewStreamEncrypter(sKey, src)
se, err := crypt.NewStreamEncrypter(sKey, s)
if err != nil {
return err
}
Expand All @@ -137,17 +159,73 @@ func (t *TransferProtocol) Transfer(ctx context.Context, peerID peer.ID, progres
return err
}

// The actual file transfer.
_, err = io.Copy(io.MultiWriter(s, progress), se)
tw := tar.NewWriter(se)
err = filepath.Walk(basePath, func(path string, info os.FileInfo, err error) error {
log.Debugln("Preparing file for transmission:", path)
if err != nil {
log.Debugln("Error walking file:", err)
return err
}

hdr, err := tar.FileInfoHeader(info, "")
if err != nil {
return errors.Wrapf(err, "error writing tar file info header %s: %s", path, err)
}

// To preserve directory structure in the tar ball.
hdr.Name, err = relPath(basePath, base.IsDir(), path)
if err != nil {
return errors.Wrapf(err, "error building relative path: %s (%v) %s", basePath, base.IsDir(), path)
}

if err = tw.WriteHeader(hdr); err != nil {
return errors.Wrap(err, "error writing tar header")
}

// Continue as all information was written above with WriteHeader.
if info.IsDir() {
return nil
}

f, err := os.Open(path)
if err != nil {
return errors.Wrapf(err, "error opening file for taring at: %s", path)
}
defer f.Close()

bar := progress.DefaultBytes(info.Size(), info.Name())
if _, err = io.Copy(io.MultiWriter(tw, bar), f); err != nil {
return err
}

return nil
})
if err != nil {
return err
}

if err = tw.Close(); err != nil {
log.Debugln("Error closing tar ball", err)
}

// Send the hash of all sent data, so our recipient can check the data.
_, err = t.node.WriteBytes(s, se.Hash())
if err != nil {
return err
return errors.Wrap(err, "error writing final hash to stream")
}

return t.node.WaitForEOF(s)
}

// relPath builds the path structure for the tar archive - this will be the structure as it is received.
func relPath(basePath string, baseIsDir bool, targetPath string) (string, error) {
if baseIsDir {
rel, err := filepath.Rel(basePath, targetPath)
if err != nil {
return "", err
}
return filepath.Clean(filepath.Join(filepath.Base(basePath), rel)), nil
} else {
return filepath.Base(basePath), nil
}
}
Loading

0 comments on commit ecf1363

Please sign in to comment.