diff --git a/internal/exchange/client_flow.go b/internal/exchange/client_flow.go index 1edc03bbf..e5d68e8a1 100644 --- a/internal/exchange/client_flow.go +++ b/internal/exchange/client_flow.go @@ -7,14 +7,12 @@ import ( "math/big" "go.uber.org/zap" - - "github.com/gotd/td/internal/proto" - "golang.org/x/xerrors" "github.com/gotd/td/bin" "github.com/gotd/td/internal/crypto" "github.com/gotd/td/internal/mt" + "github.com/gotd/td/internal/proto" ) // Run runs client-side flow. @@ -41,6 +39,7 @@ func (c ClientExchange) Run(ctx context.Context) (ClientExchangeResult, error) { if res.Nonce != nonce { return ClientExchangeResult{}, xerrors.New("ResPQ nonce mismatch") } + serverNonce := res.ServerNonce // Selecting first public key that match fingerprint. var selectedPubKey *rsa.PublicKey @@ -89,7 +88,7 @@ Loop: Pq: res.Pq, Nonce: nonce, NewNonce: newNonce, - ServerNonce: res.ServerNonce, + ServerNonce: serverNonce, P: pBytes, Q: qBytes, } @@ -105,7 +104,7 @@ Loop: } reqDHParams := &mt.ReqDHParamsRequest{ Nonce: nonce, - ServerNonce: res.ServerNonce, + ServerNonce: serverNonce, P: pBytes, Q: qBytes, PublicKeyFingerprint: crypto.RSAFingerprint(selectedPubKey), @@ -138,8 +137,11 @@ Loop: if p.Nonce != nonce { return ClientExchangeResult{}, xerrors.New("ServerDHParamsOk nonce mismatch") } + if p.ServerNonce != serverNonce { + return ClientExchangeResult{}, xerrors.New("ServerDHParamsOk server nonce mismatch") + } - key, iv := crypto.TempAESKeys(newNonce.BigInt(), res.ServerNonce.BigInt()) + key, iv := crypto.TempAESKeys(newNonce.BigInt(), serverNonce.BigInt()) // Decrypting inner data. data, err := crypto.DecryptExchangeAnswer(p.EncryptedAnswer, key, iv) if err != nil { @@ -151,6 +153,12 @@ Loop: if err := innerData.Decode(b); err != nil { return ClientExchangeResult{}, err } + if innerData.Nonce != nonce { + return ClientExchangeResult{}, xerrors.New("ServerDHInnerData nonce mismatch") + } + if innerData.ServerNonce != serverNonce { + return ClientExchangeResult{}, xerrors.New("ServerDHInnerData server nonce mismatch") + } dhPrime := big.NewInt(0).SetBytes(innerData.DhPrime) g := big.NewInt(int64(innerData.G)) @@ -215,6 +223,13 @@ Loop: } switch v := dhSetRes.(type) { case *mt.DhGenOk: // dh_gen_ok#3bcbf734 + if v.Nonce != nonce { + return ClientExchangeResult{}, xerrors.New("DhGenOk nonce mismatch") + } + if v.ServerNonce != serverNonce { + return ClientExchangeResult{}, xerrors.New("DhGenOk server nonce mismatch") + } + var key crypto.Key authKey.FillBytes(key[:]) authKeyID := key.ID() diff --git a/internal/exchange/client_flow_test.go b/internal/exchange/client_flow_test.go new file mode 100644 index 000000000..174753128 --- /dev/null +++ b/internal/exchange/client_flow_test.go @@ -0,0 +1,47 @@ +package exchange + +import ( + "context" + "crypto/rsa" + "math/rand" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + + "github.com/gotd/td/internal/tdsync" + "github.com/gotd/td/transport" +) + +func TestExchangeTimeout(t *testing.T) { + a := require.New(t) + + reader := rand.New(rand.NewSource(1)) + key, err := rsa.GenerateKey(reader, 2048) + a.NoError(err) + log := zaptest.NewLogger(t) + + i := transport.Intermediate(nil) + client, _ := i.Pipe() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + grp := tdsync.NewCancellableGroup(ctx) + grp.Go(func(groupCtx context.Context) error { + _, err := NewExchanger(client). + WithLogger(log.Named("client")). + WithRand(reader). + WithTimeout(1 * time.Second). + Client([]*rsa.PublicKey{&key.PublicKey}). + Run(groupCtx) + return err + }) + + err = grp.Wait() + if err, ok := err.(net.Error); !ok || !err.Timeout() { + require.NoError(t, err) + } +} diff --git a/internal/exchange/flow_test.go b/internal/exchange/flow_test.go index ed508eef4..505cad983 100644 --- a/internal/exchange/flow_test.go +++ b/internal/exchange/flow_test.go @@ -5,7 +5,6 @@ import ( "crypto/rsa" "fmt" "math/rand" - "net" "testing" "time" @@ -53,37 +52,6 @@ func TestExchange(t *testing.T) { require.NoError(t, grp.Wait()) } -func TestExchangeTimeout(t *testing.T) { - a := require.New(t) - - reader := rand.New(rand.NewSource(1)) - key, err := rsa.GenerateKey(reader, 2048) - a.NoError(err) - log := zaptest.NewLogger(t) - - i := transport.Intermediate(nil) - client, _ := i.Pipe() - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - grp := tdsync.NewCancellableGroup(ctx) - grp.Go(func(groupCtx context.Context) error { - _, err := NewExchanger(client). - WithLogger(log.Named("client")). - WithRand(reader). - WithTimeout(1 * time.Second). - Client([]*rsa.PublicKey{&key.PublicKey}). - Run(groupCtx) - return err - }) - - err = grp.Wait() - if err, ok := err.(net.Error); !ok || !err.Timeout() { - require.NoError(t, err) - } -} - func TestExchangeCorpus(t *testing.T) { k := testutil.RSAPrivateKey()