diff --git a/cache/redis.go b/cache/redis.go index bcaa8b8..682a76c 100644 --- a/cache/redis.go +++ b/cache/redis.go @@ -2,9 +2,13 @@ package cache import ( "context" + "fmt" + "net" + "os" "time" "github.com/go-redis/redis/v8" + "golang.org/x/crypto/ssh" ) // Redis .redis cache @@ -22,6 +26,7 @@ type RedisOpts struct { MaxIdle int `yml:"max_idle" json:"max_idle"` MaxActive int `yml:"max_active" json:"max_active"` IdleTimeout int `yml:"idle_timeout" json:"idle_timeout"` // second + Dialer func(ctx context.Context, network, addr string) (net.Conn, error) } // NewRedis 实例化 @@ -33,6 +38,20 @@ func NewRedis(ctx context.Context, opts *RedisOpts) *Redis { Password: opts.Password, IdleTimeout: time.Second * time.Duration(opts.IdleTimeout), MinIdleConns: opts.MaxIdle, + Dialer: opts.Dialer, + }) + return &Redis{ctx: ctx, conn: conn} +} + +// NewRedisOverSSH 实例化(通过 SSH 代理连接 Redis ) +func NewRedisOverSSH(ctx context.Context, opts *RedisOpts, overSSH *OverSSH) *Redis { + conn := redis.NewUniversalClient(&redis.UniversalOptions{ + Addrs: []string{opts.Host}, + DB: opts.Database, + Password: opts.Password, + IdleTimeout: time.Second * time.Duration(opts.IdleTimeout), + MinIdleConns: opts.MaxIdle, + Dialer: overSSH.MakeDialer(), }) return &Redis{ctx: ctx, conn: conn} } @@ -92,3 +111,80 @@ func (r *Redis) Delete(key string) error { func (r *Redis) DeleteContext(ctx context.Context, key string) error { return r.conn.Del(ctx, key).Err() } + +// SSHAuthMethod SSH认证方式 +type SSHAuthMethod uint8 + +const ( + // PubKeyAuth SSH公钥方式认证 + PubKeyAuth SSHAuthMethod = 1 + // PwdAuth SSH密码方式认证 + PwdAuth SSHAuthMethod = 2 +) + +// OverSSH SSH 代理配置 +type OverSSH struct { + Host string `yml:"host" json:"host"` + Port int `yml:"port" json:"port"` + AuthMethod SSHAuthMethod `yml:"auth_method" json:"auth_method"` + Username string `yml:"username" json:"username"` + Password string `yml:"password" json:"password"` + KeyFile string `yml:"key_file" json:"key_file"` +} + +// DialWithPassword 返回密码方式认证的 SSH 客户端 +func (s *OverSSH) DialWithPassword() (*ssh.Client, error) { + return ssh.Dial( + "tcp", + fmt.Sprintf("%s:%d", s.Host, s.Port), + &ssh.ClientConfig{ + User: s.Username, + Auth: []ssh.AuthMethod{ + ssh.Password(s.Password), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }, + ) +} + +// DialWithKeyFile 返回公钥方式认证的 SSH 客户端 +func (s *OverSSH) DialWithKeyFile() (*ssh.Client, error) { + k, err := os.ReadFile(s.KeyFile) + if err != nil { + return nil, err + } + signer, err := ssh.ParsePrivateKey(k) + if err != nil { + return nil, err + } + + return ssh.Dial( + "tcp", + fmt.Sprintf("%s:%d", s.Host, s.Port), + &ssh.ClientConfig{ + User: s.Username, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }, + ) +} + +// MakeDialer 创建 SSH 代理拨号器 +func (s *OverSSH) MakeDialer() func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + var err error + var sshclient *ssh.Client + switch s.AuthMethod { + case PwdAuth: + sshclient, err = s.DialWithPassword() + case PubKeyAuth: + sshclient, err = s.DialWithKeyFile() + } + if err != nil { + return nil, err + } + return sshclient.Dial(network, addr) + } +} diff --git a/go.sum b/go.sum index 64deda9..56f0abe 100644 --- a/go.sum +++ b/go.sum @@ -111,6 +111,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=