Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add TxPipeline to rueidiscompat #605

Merged
merged 3 commits into from
Aug 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 85 additions & 5 deletions rueidiscompat/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,20 @@ const (
BitCountIndexBit = "BIT"
)

var Nil = rueidis.Nil

type Cmdable interface {
CoreCmdable
Cache(ttl time.Duration) CacheCompat

Subscribe(ctx context.Context, channels ...string) PubSub
PSubscribe(ctx context.Context, patterns ...string) PubSub
SSubscribe(ctx context.Context, channels ...string) PubSub

Watch(ctx context.Context, fn func(Tx) error, keys ...string) error
}

type CoreCmdable interface {
Command(ctx context.Context) *CommandsInfoCmd
CommandList(ctx context.Context, filter FilterBy) *StringSliceCmd
CommandGetKeys(ctx context.Context, commands ...any) *StringSliceCmd
Expand Down Expand Up @@ -127,11 +138,13 @@ type Cmdable interface {
BitPos(ctx context.Context, key string, bit int64, pos ...int64) *IntCmd
BitPosSpan(ctx context.Context, key string, bit int64, start, end int64, span string) *IntCmd
BitField(ctx context.Context, key string, args ...any) *IntSliceCmd
// TODO BitFieldRO(ctx context.Context, key string, values ...interface{}) *IntSliceCmd

Scan(ctx context.Context, cursor uint64, match string, count int64) *ScanCmd
ScanType(ctx context.Context, cursor uint64, match string, count int64, keyType string) *ScanCmd
SScan(ctx context.Context, key string, cursor uint64, match string, count int64) *ScanCmd
HScan(ctx context.Context, key string, cursor uint64, match string, count int64) *ScanCmd
// TODO HScanNoValues(ctx context.Context, key string, cursor uint64, match string, count int64) *ScanCmd
ZScan(ctx context.Context, key string, cursor uint64, match string, count int64) *ScanCmd

HDel(ctx context.Context, key string, fields ...string) *IntCmd
Expand All @@ -149,6 +162,19 @@ type Cmdable interface {
HVals(ctx context.Context, key string) *StringSliceCmd
HRandField(ctx context.Context, key string, count int64) *StringSliceCmd
HRandFieldWithValues(ctx context.Context, key string, count int64) *KeyValueSliceCmd
// TODO HExpire(ctx context.Context, key string, expiration time.Duration, fields ...string) *IntSliceCmd
// TODO HExpireWithArgs(ctx context.Context, key string, expiration time.Duration, expirationArgs HExpireArgs, fields ...string) *IntSliceCmd
// TODO HPExpire(ctx context.Context, key string, expiration time.Duration, fields ...string) *IntSliceCmd
// TODO HPExpireWithArgs(ctx context.Context, key string, expiration time.Duration, expirationArgs HExpireArgs, fields ...string) *IntSliceCmd
// TODO HExpireAt(ctx context.Context, key string, tm time.Time, fields ...string) *IntSliceCmd
// TODO HExpireAtWithArgs(ctx context.Context, key string, tm time.Time, expirationArgs HExpireArgs, fields ...string) *IntSliceCmd
// TODO HPExpireAt(ctx context.Context, key string, tm time.Time, fields ...string) *IntSliceCmd
// TODO HPExpireAtWithArgs(ctx context.Context, key string, tm time.Time, expirationArgs HExpireArgs, fields ...string) *IntSliceCmd
// TODO HPersist(ctx context.Context, key string, fields ...string) *IntSliceCmd
// TODO HExpireTime(ctx context.Context, key string, fields ...string) *IntSliceCmd
// TODO HPExpireTime(ctx context.Context, key string, fields ...string) *IntSliceCmd
// TODO HTTL(ctx context.Context, key string, fields ...string) *IntSliceCmd
// TODO HPTTL(ctx context.Context, key string, fields ...string) *IntSliceCmd

BLPop(ctx context.Context, timeout time.Duration, keys ...string) *StringSliceCmd
BLMPop(ctx context.Context, timeout time.Duration, direction string, count int64, keys ...string) *KeyValuesCmd
Expand Down Expand Up @@ -375,6 +401,8 @@ type Cmdable interface {
ClusterFailover(ctx context.Context) *StatusCmd
ClusterAddSlots(ctx context.Context, slots ...int64) *StatusCmd
ClusterAddSlotsRange(ctx context.Context, min, max int64) *StatusCmd
// TODO ReadOnly(ctx context.Context) *StatusCmd
// TODO ReadWrite(ctx context.Context) *StatusCmd

GeoAdd(ctx context.Context, key string, geoLocation ...GeoLocation) *IntCmd
GeoPos(ctx context.Context, key string, members ...string) *GeoPosCmd
Expand All @@ -389,13 +417,48 @@ type Cmdable interface {
GeoHash(ctx context.Context, key string, members ...string) *StringSliceCmd

ACLDryRun(ctx context.Context, username string, command ...any) *StringCmd
// TODO ACLLog(ctx context.Context, count int64) *ACLLogCmd
// TODO ACLLogReset(ctx context.Context) *StatusCmd

// TODO ModuleLoadex(ctx context.Context, conf *ModuleLoadexConfig) *StringCmd
GearsCmdable
ProbabilisticCmdable
TimeseriesCmdable
JSONCmdable
}
// TODO SearchCmdable
}

// TODO SearchCmdable
//type SearchCmdable interface {
// FT_List(ctx context.Context) *StringSliceCmd
// FTAggregate(ctx context.Context, index string, query string) *MapStringInterfaceCmd
// FTAggregateWithArgs(ctx context.Context, index string, query string, options *FTAggregateOptions) *AggregateCmd
// FTAliasAdd(ctx context.Context, index string, alias string) *StatusCmd
// FTAliasDel(ctx context.Context, alias string) *StatusCmd
// FTAliasUpdate(ctx context.Context, index string, alias string) *StatusCmd
// FTAlter(ctx context.Context, index string, skipInitalScan bool, definition []interface{}) *StatusCmd
// FTConfigGet(ctx context.Context, option string) *MapMapStringInterfaceCmd
// FTConfigSet(ctx context.Context, option string, value interface{}) *StatusCmd
// FTCreate(ctx context.Context, index string, options *FTCreateOptions, schema ...*FieldSchema) *StatusCmd
// FTCursorDel(ctx context.Context, index string, cursorId int) *StatusCmd
// FTCursorRead(ctx context.Context, index string, cursorId int, count int) *MapStringInterfaceCmd
// FTDictAdd(ctx context.Context, dict string, term ...interface{}) *IntCmd
// FTDictDel(ctx context.Context, dict string, term ...interface{}) *IntCmd
// FTDictDump(ctx context.Context, dict string) *StringSliceCmd
// FTDropIndex(ctx context.Context, index string) *StatusCmd
// FTDropIndexWithArgs(ctx context.Context, index string, options *FTDropIndexOptions) *StatusCmd
// FTExplain(ctx context.Context, index string, query string) *StringCmd
// FTExplainWithArgs(ctx context.Context, index string, query string, options *FTExplainOptions) *StringCmd
// FTInfo(ctx context.Context, index string) *FTInfoCmd
// FTSpellCheck(ctx context.Context, index string, query string) *FTSpellCheckCmd
// FTSpellCheckWithArgs(ctx context.Context, index string, query string, options *FTSpellCheckOptions) *FTSpellCheckCmd
// FTSearch(ctx context.Context, index string, query string) *FTSearchCmd
// FTSearchWithArgs(ctx context.Context, index string, query string, options *FTSearchOptions) *FTSearchCmd
// FTSynDump(ctx context.Context, index string) *FTSynDumpCmd
// FTSynUpdate(ctx context.Context, index string, synGroupId interface{}, terms []interface{}) *StatusCmd
// FTSynUpdateWithArgs(ctx context.Context, index string, synGroupId interface{}, options *FTSynUpdateOptions, terms []interface{}) *StatusCmd
// FTTagVals(ctx context.Context, index string, field string) *StringSliceCmd
//}

// https://github.com/redis/go-redis/blob/af4872cbd0de349855ce3f0978929c2f56eb995f/probabilistic.go#L10
type ProbabilisticCmdable interface {
Expand Down Expand Up @@ -470,12 +533,11 @@ type ProbabilisticCmdable interface {
TDigestRevRank(ctx context.Context, key string, values ...float64) *IntSliceCmd
TDigestTrimmedMean(ctx context.Context, key string, lowCutQuantile, highCutQuantile float64) *FloatCmd

Subscribe(ctx context.Context, channels ...string) PubSub
PSubscribe(ctx context.Context, patterns ...string) PubSub
SSubscribe(ctx context.Context, channels ...string) PubSub

Pipeline() Pipeliner
Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error)

TxPipeline() Pipeliner
TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error)
}

// Align with go-redis
Expand Down Expand Up @@ -4629,6 +4691,24 @@ func (c *Compat) Pipeline() Pipeliner {
return newPipeline(c.client)
}

func (c *Compat) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return newTxPipeline(c.client).Pipelined(ctx, fn)
}

func (c *Compat) TxPipeline() Pipeliner {
return newTxPipeline(c.client)
}

func (c *Compat) Watch(ctx context.Context, fn func(Tx) error, keys ...string) error {
dc, cancel := c.client.Dedicate()
defer cancel()
tx := newTx(dc, cancel)
if err := tx.Watch(ctx, keys...).Err(); err != nil {
return err
}
return fn(newTx(dc, cancel))
}

func (c CacheCompat) BitCount(ctx context.Context, key string, bitCount *BitCount) *IntCmd {
var resp rueidis.RedisResult
if bitCount == nil {
Expand Down
20 changes: 2 additions & 18 deletions rueidiscompat/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ import (
// To avoid this: it is good idea to use reasonable bigger read/write timeouts
// depends on your batch size and/or use TxPipeline.
type Pipeliner interface {
Cmdable
CoreCmdable

// Len is to obtain the number of commands in the pipeline that have not yet been executed.
Len() int
Expand Down Expand Up @@ -96,10 +96,6 @@ type Pipeline struct {
rets []Cmder
}

func (c *Pipeline) Cache(ttl time.Duration) CacheCompat {
return c.comp.Cache(ttl)
}

func (c *Pipeline) Command(ctx context.Context) *CommandsInfoCmd {
ret := c.comp.Command(ctx)
c.rets = append(c.rets, ret)
Expand Down Expand Up @@ -2434,18 +2430,6 @@ func (c *Pipeline) TDigestTrimmedMean(ctx context.Context, key string, lowCutQua
return ret
}

func (c *Pipeline) Subscribe(ctx context.Context, channels ...string) PubSub {
return c.comp.Subscribe(ctx, channels...)
}

func (c *Pipeline) PSubscribe(ctx context.Context, patterns ...string) PubSub {
return c.comp.PSubscribe(ctx, patterns...)
}

func (c *Pipeline) SSubscribe(ctx context.Context, channels ...string) PubSub {
return c.comp.SSubscribe(ctx, channels...)
}

func (c *Pipeline) TSAdd(ctx context.Context, key string, timestamp interface{}, value float64) *IntCmd {
ret := c.comp.TSAdd(ctx, key, timestamp, value)
c.rets = append(c.rets, ret)
Expand Down Expand Up @@ -2788,7 +2772,7 @@ func (c *Pipeline) Len() int {
}

// Do queues the custom command for later execution.
func (c *Pipeline) Do(ctx context.Context, args ...interface{}) *Cmd {
func (c *Pipeline) Do(_ context.Context, args ...interface{}) *Cmd {
ret := &Cmd{}
if len(args) == 0 {
ret.SetErr(errors.New("redis: please enter the command to be executed"))
Expand Down
145 changes: 145 additions & 0 deletions rueidiscompat/tx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package rueidiscompat

import (
"context"
"errors"
"time"
"unsafe"

"github.com/redis/rueidis"
)

var TxFailedErr = errors.New("redis: transaction failed")

var _ Pipeliner = (*TxPipeline)(nil)

type rePipeline = Pipeline

func newTxPipeline(real rueidis.Client) *TxPipeline {
return &TxPipeline{rePipeline: newPipeline(real)}
}

type TxPipeline struct {
*rePipeline
}

func (c *TxPipeline) Exec(ctx context.Context) ([]Cmder, error) {
p := c.comp.client.(*proxy)
if len(p.cmds) == 0 {
return nil, nil
}

rets := c.rets
cmds := p.cmds
c.rets = nil
p.cmds = nil

cmds = append(cmds, c.comp.client.B().Multi().Build(), c.comp.client.B().Exec().Build())
for i := len(cmds) - 2; i >= 1; i-- {
j := i - 1
cmds[j], cmds[i] = cmds[i], cmds[j]
}

resp := p.DoMulti(ctx, cmds...)
results, err := resp[len(resp)-1].ToArray()
if rueidis.IsRedisNil(err) {
err = TxFailedErr
}
for i, r := range results {
rets[i].from(*(*rueidis.RedisResult)(unsafe.Pointer(&proxyresult{
err: resp[i+1].NonRedisError(),
val: r,
})))
}
return rets, err
}

func (c *TxPipeline) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
if err := fn(c); err != nil {
return nil, err
}
return c.Exec(ctx)
}

func (c *TxPipeline) Pipeline() Pipeliner {
return c
}

func (c *TxPipeline) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipelined(ctx, fn)
}

func (c *TxPipeline) TxPipeline() Pipeliner {
return c
}

var _ rueidis.Client = (*txproxy)(nil)

type txproxy struct {
rueidis.CoreClient
}

func (p *txproxy) DoCache(_ context.Context, _ rueidis.Cacheable, _ time.Duration) (resp rueidis.RedisResult) {
panic("not implemented")
}

func (p *txproxy) DoMultiCache(_ context.Context, _ ...rueidis.CacheableTTL) (resp []rueidis.RedisResult) {
panic("not implemented")
}

func (p *txproxy) DoStream(_ context.Context, _ rueidis.Completed) rueidis.RedisResultStream {
panic("not implemented")
}

func (p *txproxy) DoMultiStream(_ context.Context, _ ...rueidis.Completed) rueidis.MultiRedisResultStream {
panic("not implemented")
}

func (p *txproxy) Dedicated(_ func(rueidis.DedicatedClient) error) (err error) {
panic("not implemented")
}

func (p *txproxy) Dedicate() (client rueidis.DedicatedClient, cancel func()) {
panic("not implemented")
}

func (p *txproxy) Nodes() map[string]rueidis.Client {
panic("not implemented")
}

type Tx interface {
CoreCmdable
Watch(ctx context.Context, keys ...string) *StatusCmd
Unwatch(ctx context.Context, keys ...string) *StatusCmd
Close(ctx context.Context) error
}

func newTx(client rueidis.DedicatedClient, cancel func()) *tx {
return &tx{CoreCmdable: NewAdapter(&txproxy{CoreClient: client}), cancel: cancel}
}

type tx struct {
CoreCmdable
cancel func()
}

func (t *tx) Watch(ctx context.Context, keys ...string) *StatusCmd {
ret := &StatusCmd{}
if len(keys) != 0 {
client := t.CoreCmdable.(*Compat).client
ret.from(client.Do(ctx, client.B().Watch().Key(keys...).Build()))
}
return ret
}

func (t *tx) Unwatch(ctx context.Context, _ ...string) *StatusCmd {
ret := &StatusCmd{}
client := t.CoreCmdable.(*Compat).client
ret.from(client.Do(ctx, client.B().Unwatch().Build()))
return ret
}

func (t *tx) Close(_ context.Context) error {
t.cancel()
return nil
}
Loading