diff --git a/rueidiscompat/adapter.go b/rueidiscompat/adapter.go index 39dffec2..323cc531 100644 --- a/rueidiscompat/adapter.go +++ b/rueidiscompat/adapter.go @@ -50,8 +50,17 @@ const ( ) type Cmdable interface { + CoreCmdable Cache(ttl time.Duration) CacheCompat + Subscribe(ctx context.Context, channels ...string) PubSub + PSubscribe(ctx context.Context, patterns ...string) PubSub + SSubscribe(ctx context.Context, channels ...string) PubSub + + Watch(ctx context.Context, fn func(Tx) error, keys ...string) error +} + +type CoreCmdable interface { Command(ctx context.Context) *CommandsInfoCmd CommandList(ctx context.Context, filter FilterBy) *StringSliceCmd CommandGetKeys(ctx context.Context, commands ...any) *StringSliceCmd @@ -470,12 +479,11 @@ type ProbabilisticCmdable interface { TDigestRevRank(ctx context.Context, key string, values ...float64) *IntSliceCmd TDigestTrimmedMean(ctx context.Context, key string, lowCutQuantile, highCutQuantile float64) *FloatCmd - Subscribe(ctx context.Context, channels ...string) PubSub - PSubscribe(ctx context.Context, patterns ...string) PubSub - SSubscribe(ctx context.Context, channels ...string) PubSub - Pipeline() Pipeliner Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) + + TxPipeline() Pipeliner + TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) } // Align with go-redis @@ -4629,6 +4637,24 @@ func (c *Compat) Pipeline() Pipeliner { return newPipeline(c.client) } +func (c *Compat) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { + return newTxPipeline(c.client).Pipelined(ctx, fn) +} + +func (c *Compat) TxPipeline() Pipeliner { + return newTxPipeline(c.client) +} + +func (c *Compat) Watch(ctx context.Context, fn func(Tx) error, keys ...string) error { + dc, cancel := c.client.Dedicate() + defer cancel() + tx := newTx(dc, cancel) + if err := tx.Watch(ctx, keys...).Err(); err != nil { + return err + } + return fn(newTx(dc, cancel)) +} + func (c CacheCompat) BitCount(ctx context.Context, key string, bitCount *BitCount) *IntCmd { var resp rueidis.RedisResult if bitCount == nil { diff --git a/rueidiscompat/pipeline.go b/rueidiscompat/pipeline.go index cb37c90b..32d2891b 100644 --- a/rueidiscompat/pipeline.go +++ b/rueidiscompat/pipeline.go @@ -49,7 +49,7 @@ import ( // To avoid this: it is good idea to use reasonable bigger read/write timeouts // depends on your batch size and/or use TxPipeline. type Pipeliner interface { - Cmdable + CoreCmdable // Len is to obtain the number of commands in the pipeline that have not yet been executed. Len() int @@ -96,10 +96,6 @@ type Pipeline struct { rets []Cmder } -func (c *Pipeline) Cache(ttl time.Duration) CacheCompat { - return c.comp.Cache(ttl) -} - func (c *Pipeline) Command(ctx context.Context) *CommandsInfoCmd { ret := c.comp.Command(ctx) c.rets = append(c.rets, ret) @@ -2434,18 +2430,6 @@ func (c *Pipeline) TDigestTrimmedMean(ctx context.Context, key string, lowCutQua return ret } -func (c *Pipeline) Subscribe(ctx context.Context, channels ...string) PubSub { - return c.comp.Subscribe(ctx, channels...) -} - -func (c *Pipeline) PSubscribe(ctx context.Context, patterns ...string) PubSub { - return c.comp.PSubscribe(ctx, patterns...) -} - -func (c *Pipeline) SSubscribe(ctx context.Context, channels ...string) PubSub { - return c.comp.SSubscribe(ctx, channels...) -} - func (c *Pipeline) TSAdd(ctx context.Context, key string, timestamp interface{}, value float64) *IntCmd { ret := c.comp.TSAdd(ctx, key, timestamp, value) c.rets = append(c.rets, ret) @@ -2788,7 +2772,7 @@ func (c *Pipeline) Len() int { } // Do queues the custom command for later execution. -func (c *Pipeline) Do(ctx context.Context, args ...interface{}) *Cmd { +func (c *Pipeline) Do(_ context.Context, args ...interface{}) *Cmd { ret := &Cmd{} if len(args) == 0 { ret.SetErr(errors.New("redis: please enter the command to be executed")) diff --git a/rueidiscompat/tx.go b/rueidiscompat/tx.go new file mode 100644 index 00000000..89e62401 --- /dev/null +++ b/rueidiscompat/tx.go @@ -0,0 +1,145 @@ +package rueidiscompat + +import ( + "context" + "errors" + "time" + "unsafe" + + "github.com/redis/rueidis" +) + +var TxFailedErr = errors.New("redis: transaction failed") + +var _ Pipeliner = (*TxPipeline)(nil) + +type rePipeline = Pipeline + +func newTxPipeline(real rueidis.Client) *TxPipeline { + return &TxPipeline{rePipeline: newPipeline(real)} +} + +type TxPipeline struct { + *rePipeline +} + +func (c *TxPipeline) Exec(ctx context.Context) ([]Cmder, error) { + p := c.comp.client.(*proxy) + if len(p.cmds) == 0 { + return nil, nil + } + + rets := c.rets + cmds := p.cmds + c.rets = nil + p.cmds = nil + + cmds = append(cmds, c.comp.client.B().Multi().Build(), c.comp.client.B().Exec().Build()) + for i := len(cmds) - 2; i >= 1; i-- { + j := i - 1 + cmds[j], cmds[i] = cmds[i], cmds[j] + } + + resp := p.DoMulti(ctx, cmds...) + results, err := resp[len(resp)-1].ToArray() + if rueidis.IsRedisNil(err) { + err = TxFailedErr + } + for i, r := range results { + rets[i].from(*(*rueidis.RedisResult)(unsafe.Pointer(&proxyresult{ + err: resp[i+1].NonRedisError(), + val: r, + }))) + } + return rets, err +} + +func (c *TxPipeline) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { + if err := fn(c); err != nil { + return nil, err + } + return c.Exec(ctx) +} + +func (c *TxPipeline) Pipeline() Pipeliner { + return c +} + +func (c *TxPipeline) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { + return c.Pipelined(ctx, fn) +} + +func (c *TxPipeline) TxPipeline() Pipeliner { + return c +} + +var _ rueidis.Client = (*txproxy)(nil) + +type txproxy struct { + rueidis.CoreClient +} + +func (p *txproxy) DoCache(_ context.Context, _ rueidis.Cacheable, _ time.Duration) (resp rueidis.RedisResult) { + panic("not implemented") +} + +func (p *txproxy) DoMultiCache(_ context.Context, _ ...rueidis.CacheableTTL) (resp []rueidis.RedisResult) { + panic("not implemented") +} + +func (p *txproxy) DoStream(_ context.Context, _ rueidis.Completed) rueidis.RedisResultStream { + panic("not implemented") +} + +func (p *txproxy) DoMultiStream(_ context.Context, _ ...rueidis.Completed) rueidis.MultiRedisResultStream { + panic("not implemented") +} + +func (p *txproxy) Dedicated(_ func(rueidis.DedicatedClient) error) (err error) { + panic("not implemented") +} + +func (p *txproxy) Dedicate() (client rueidis.DedicatedClient, cancel func()) { + panic("not implemented") +} + +func (p *txproxy) Nodes() map[string]rueidis.Client { + panic("not implemented") +} + +type Tx interface { + CoreCmdable + Watch(ctx context.Context, keys ...string) *StatusCmd + Unwatch(ctx context.Context, keys ...string) *StatusCmd + Close(ctx context.Context) error +} + +func newTx(client rueidis.DedicatedClient, cancel func()) *tx { + return &tx{CoreCmdable: NewAdapter(&txproxy{CoreClient: client}), cancel: cancel} +} + +type tx struct { + CoreCmdable + cancel func() +} + +func (t *tx) Watch(ctx context.Context, keys ...string) *StatusCmd { + ret := &StatusCmd{} + if len(keys) != 0 { + client := t.CoreCmdable.(*Compat).client + ret.from(client.Do(ctx, client.B().Watch().Key(keys...).Build())) + } + return ret +} + +func (t *tx) Unwatch(ctx context.Context, _ ...string) *StatusCmd { + ret := &StatusCmd{} + client := t.CoreCmdable.(*Compat).client + ret.from(client.Do(ctx, client.B().Unwatch().Build())) + return ret +} + +func (t *tx) Close(_ context.Context) error { + t.cancel() + return nil +} diff --git a/rueidiscompat/tx_test.go b/rueidiscompat/tx_test.go new file mode 100644 index 00000000..ef349c3c --- /dev/null +++ b/rueidiscompat/tx_test.go @@ -0,0 +1,170 @@ +// Copyright (c) 2013 The github.com/go-redis/redis Authors. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package rueidiscompat + +import ( + "errors" + "math/rand/v2" + "strconv" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/redis/rueidis" +) + +var _ = Describe("RESP3 TxPipeline Commands", func() { + testAdapterTxPipeline(true) +}) + +var _ = Describe("RESP2 TxPipeline Commands", func() { + testAdapterTxPipeline(false) +}) + +func testAdapterTxPipeline(resp3 bool) { + var adapter Cmdable + + BeforeEach(func() { + if resp3 { + adapter = adapterresp3 + } else { + adapter = adapterresp2 + } + Expect(adapter.FlushDB(ctx).Err()).NotTo(HaveOccurred()) + Expect(adapter.FlushAll(ctx).Err()).NotTo(HaveOccurred()) + }) + + It("should TxPipelined", func() { + var echo, ping *StringCmd + rets, err := adapter.TxPipelined(ctx, func(pipe Pipeliner) error { + echo = pipe.Echo(ctx, "hello") + ping = pipe.Ping(ctx) + Expect(echo.Err()).To(MatchError(placeholder.err)) + Expect(ping.Err()).To(MatchError(placeholder.err)) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + Expect(rets).To(HaveLen(2)) + Expect(rets[0]).To(Equal(echo)) + Expect(rets[1]).To(Equal(ping)) + Expect(echo.Err()).NotTo(HaveOccurred()) + Expect(echo.Val()).To(Equal("hello")) + Expect(ping.Err()).NotTo(HaveOccurred()) + Expect(ping.Val()).To(Equal("PONG")) + }) + + It("should TxPipeline", func() { + pipe := adapter.TxPipeline() + echo := pipe.Echo(ctx, "hello") + ping := pipe.Ping(ctx) + Expect(echo.Err()).To(MatchError(placeholder.err)) + Expect(ping.Err()).To(MatchError(placeholder.err)) + Expect(pipe.Len()).To(Equal(2)) + + rets, err := pipe.Exec(ctx) + Expect(err).NotTo(HaveOccurred()) + Expect(pipe.Len()).To(Equal(0)) + Expect(rets).To(HaveLen(2)) + Expect(rets[0]).To(Equal(echo)) + Expect(rets[1]).To(Equal(ping)) + Expect(echo.Err()).NotTo(HaveOccurred()) + Expect(echo.Val()).To(Equal("hello")) + Expect(ping.Err()).NotTo(HaveOccurred()) + Expect(ping.Val()).To(Equal("PONG")) + }) + + It("should Discard", func() { + pipe := adapter.TxPipeline() + echo := pipe.Echo(ctx, "hello") + ping := pipe.Ping(ctx) + + pipe.Discard() + Expect(pipe.Len()).To(Equal(0)) + + rets, err := pipe.Exec(ctx) + Expect(err).NotTo(HaveOccurred()) + Expect(rets).To(HaveLen(0)) + + Expect(echo.Err()).To(MatchError(placeholder.err)) + Expect(ping.Err()).To(MatchError(placeholder.err)) + }) + + It("should Watch", func() { + k1 := strconv.Itoa(rand.Int()) + k2 := strconv.Itoa(rand.Int()) + err := adapter.Watch(ctx, func(t Tx) error { + if t.Get(ctx, k1).Err() != rueidis.Nil { + return errors.New("unclean") + } + if t.Get(ctx, k2).Err() != rueidis.Nil { + return errors.New("unclean") + } + _, err := t.TxPipelined(ctx, func(pipe Pipeliner) error { + pipe.Set(ctx, k1, k1, 0) + pipe.Set(ctx, k2, k2, 0) + return nil + }) + return err + }, k1, k2) + Expect(err).NotTo(HaveOccurred()) + Expect(adapter.Get(ctx, k1).Val()).To(Equal(k1)) + Expect(adapter.Get(ctx, k2).Val()).To(Equal(k2)) + }) + + It("should Watch Abort", func() { + k1 := strconv.Itoa(rand.Int()) + ch := make(chan error) + go func() { + ch <- adapter.Watch(ctx, func(t Tx) error { + ch <- nil + <-ch + _, err := t.TxPipelined(ctx, func(pipe Pipeliner) error { + pipe.Del(ctx, k1) + return nil + }) + return err + }, k1) + }() + <-ch + Expect(adapter.Set(ctx, k1, k1, 0).Err()).NotTo(HaveOccurred()) + ch <- nil + Expect(<-ch).To(MatchError(TxFailedErr)) + }) + + It("should Unwatch and Close", func() { + k1 := strconv.Itoa(rand.Int()) + err := adapter.Watch(ctx, func(t Tx) error { + Expect(t.Unwatch(ctx).Err()).NotTo(HaveOccurred()) + Expect(t.Close(ctx)).NotTo(HaveOccurred()) + _, err := t.TxPipelined(ctx, func(pipe Pipeliner) error { + pipe.Del(ctx, k1) + return nil + }) + return err + }, k1) + Expect(err).To(MatchError(rueidis.ErrDedicatedClientRecycled)) + }) +}