diff --git a/go.sum b/go.sum index 64f5de4..406e622 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,7 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJ github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= @@ -21,6 +22,7 @@ github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10Rvpy94aQh3LT/ShoCpkHw= diff --git a/pay/notify/paid.go b/pay/notify/paid.go index 8f67bdf..63755fc 100644 --- a/pay/notify/paid.go +++ b/pay/notify/paid.go @@ -85,8 +85,15 @@ func (notify *Notify) PaidVerifySign(notifyRes PaidResult) bool { // STEP3, 在键值对的最后加上key=API_KEY signStrings = signStrings + "key=" + notify.Key - // STEP4, 进行MD5签名并且将所有字符转为大写. - sign := util.MD5Sum(signStrings) + // STEP4, 根据SignType计算出签名 + var signType string + if notifyRes.SignType != nil { + signType = *notifyRes.SignType + } + sign, err := util.CalculateSign(signStrings, signType, notify.Key) + if err != nil { + return false + } if sign != *notifyRes.Sign { return false } diff --git a/pay/order/pay.go b/pay/order/pay.go index b1336c2..b478aaa 100644 --- a/pay/order/pay.go +++ b/pay/order/pay.go @@ -1,13 +1,8 @@ package order import ( - "crypto/hmac" - "crypto/md5" - "crypto/sha256" - "encoding/hex" "encoding/xml" "errors" - "hash" "strconv" "strings" "time" @@ -96,13 +91,14 @@ type payRequest struct { LimitPay string `xml:"limit_pay,omitempty"` // OpenID string `xml:"openid,omitempty"` // 用户标识 SceneInfo string `xml:"scene_info,omitempty"` // 场景信息 + + XMLName struct{} `xml:"xml"` } // BridgeConfig get js bridge config func (o *Order) BridgeConfig(p *Params) (cfg Config, err error) { var ( buffer strings.Builder - h hash.Hash timestamp = strconv.FormatInt(time.Now().Unix(), 10) ) order, err := o.PrePayOrder(p) @@ -121,14 +117,13 @@ func (o *Order) BridgeConfig(p *Params) (cfg Config, err error) { buffer.WriteString(timestamp) buffer.WriteString("&key=") buffer.WriteString(o.Key) - if p.SignType == "MD5" { - h = md5.New() - } else { - h = hmac.New(sha256.New, []byte(o.Key)) + + sign, err := util.CalculateSign(buffer.String(), p.SignType, o.Key) + if err != nil { + return } - h.Write([]byte(buffer.String())) // 签名 - cfg.PaySign = strings.ToUpper(hex.EncodeToString(h.Sum(nil))) + cfg.PaySign = sign cfg.NonceStr = order.NonceStr cfg.Timestamp = timestamp cfg.PrePayID = order.PrePayID @@ -143,13 +138,13 @@ func (o *Order) PrePayOrder(p *Params) (payOrder PreOrder, err error) { notifyURL := o.NotifyURL // 签名类型 if p.SignType == "" { - p.SignType = "MD5" + p.SignType = util.SignTypeMD5 } // 通知地址 if p.NotifyURL != "" { notifyURL = p.NotifyURL } - param := make(map[string]interface{}) + param := make(map[string]string) param["appid"] = o.AppID param["body"] = p.Body param["mch_id"] = o.MchID @@ -165,9 +160,10 @@ func (o *Order) PrePayOrder(p *Params) (payOrder PreOrder, err error) { param["goods_tag"] = p.GoodsTag param["notify_url"] = notifyURL - bizKey := "&key=" + o.Key - str := util.OrderParam(param, bizKey) - sign := util.MD5Sum(str) + sign, err := util.ParamSign(param, o.Key) + if err != nil { + return + } request := payRequest{ AppID: o.AppID, MchID: o.MchID, @@ -202,7 +198,7 @@ func (o *Order) PrePayOrder(p *Params) (payOrder PreOrder, err error) { err = errors.New(payOrder.ErrCode + payOrder.ErrCodeDes) return } - err = errors.New("[msg : xmlUnmarshalError] [rawReturn : " + string(rawRet) + "] [params : " + str + "] [sign : " + sign + "]") + err = errors.New("[msg : xmlUnmarshalError] [rawReturn : " + string(rawRet) + "] [sign : " + sign + "]") return } diff --git a/pay/refund/refund.go b/pay/refund/refund.go index 1220d5c..4ac3e7a 100644 --- a/pay/refund/refund.go +++ b/pay/refund/refund.go @@ -73,7 +73,7 @@ type Response struct { //Refund 退款申请 func (refund *Refund) Refund(p *Params) (rsp Response, err error) { nonceStr := util.RandomStr(32) - param := make(map[string]interface{}) + param := make(map[string]string) param["appid"] = refund.AppID param["mch_id"] = refund.MchID param["nonce_str"] = nonceStr @@ -81,18 +81,20 @@ func (refund *Refund) Refund(p *Params) (rsp Response, err error) { param["refund_desc"] = p.RefundDesc param["refund_fee"] = p.RefundFee param["total_fee"] = p.TotalFee - param["sign_type"] = "MD5" + param["sign_type"] = util.SignTypeMD5 param["transaction_id"] = p.TransactionID - bizKey := "&key=" + refund.Key - str := util.OrderParam(param, bizKey) - sign := util.MD5Sum(str) + sign, err := util.ParamSign(param, refund.Key) + if err != nil { + return + } + request := request{ AppID: refund.AppID, MchID: refund.MchID, NonceStr: nonceStr, Sign: sign, - SignType: "MD5", + SignType: util.SignTypeMD5, TransactionID: p.TransactionID, OutRefundNo: p.OutRefundNo, TotalFee: p.TotalFee, @@ -115,7 +117,6 @@ func (refund *Refund) Refund(p *Params) (rsp Response, err error) { err = fmt.Errorf("refund error, errcode=%s,errmsg=%s", rsp.ErrCode, rsp.ErrCodeDes) return } - err = fmt.Errorf("[msg : xmlUnmarshalError] [rawReturn : %s] [params : %s] [sign : %s]", - string(rawRet), str, sign) + err = fmt.Errorf("[msg : xmlUnmarshalError] [rawReturn : %s] [sign : %s]", string(rawRet), sign) return } diff --git a/util/crypto.go b/util/crypto.go index ce21e36..2374fb6 100644 --- a/util/crypto.go +++ b/util/crypto.go @@ -1,14 +1,23 @@ package util import ( - "bufio" - "bytes" "crypto/aes" "crypto/cipher" + "crypto/hmac" "crypto/md5" + "crypto/sha256" "encoding/base64" "encoding/hex" + "errors" "fmt" + "hash" + "strings" +) + +// 微信签名算法方式 +const ( + SignTypeMD5 = `MD5` + SignTypeHMACSHA256 = `HMAC-SHA256` ) //EncryptMsg 加密消息 @@ -186,14 +195,35 @@ func decodeNetworkByteOrder(orderBytes []byte) (n uint32) { uint32(orderBytes[3]) } -// MD5Sum 计算 32 位长度的 MD5 sum -func MD5Sum(txt string) (sum string) { - h := md5.New() - buf := bufio.NewWriterSize(h, 128) - buf.WriteString(txt) - buf.Flush() - sign := make([]byte, hex.EncodedLen(h.Size())) - hex.Encode(sign, h.Sum(nil)) - sum = string(bytes.ToUpper(sign)) - return +// CalculateSign 计算签名 +func CalculateSign(content, signType, key string) (string, error) { + var h hash.Hash + if signType == SignTypeMD5 { + h = md5.New() + } else { + h = hmac.New(sha256.New, []byte(key)) + } + + if _, err := h.Write([]byte(content)); err != nil { + return ``, err + } + return strings.ToUpper(hex.EncodeToString(h.Sum(nil))), nil +} + +// ParamSign 计算所传参数的签名 +func ParamSign(p map[string]string, key string) (string, error) { + bizKey := "&key=" + key + str := OrderParam(p, bizKey) + + var signType string + switch p["sign_type"] { + case SignTypeMD5, SignTypeHMACSHA256: + signType = p["sign_type"] + case ``: + signType = SignTypeMD5 + default: + return ``, errors.New(`invalid sign_type`) + } + + return CalculateSign(str, signType, key) } diff --git a/util/param.go b/util/param.go index aac0038..5625784 100644 --- a/util/param.go +++ b/util/param.go @@ -3,65 +3,31 @@ package util import ( "bytes" "sort" - "strconv" ) // OrderParam order params -func OrderParam(source interface{}, bizKey string) (returnStr string) { - switch v := source.(type) { - case map[string]string: - keys := make([]string, 0, len(v)) - for k := range v { - if k == "sign" { - continue - } - keys = append(keys, k) +func OrderParam(p map[string]string, bizKey string) (returnStr string) { + keys := make([]string, 0, len(p)) + for k := range p { + if k == "sign" { + continue } - sort.Strings(keys) - var buf bytes.Buffer - for _, k := range keys { - if v[k] == "" { - continue - } - if buf.Len() > 0 { - buf.WriteByte('&') - } - buf.WriteString(k) - buf.WriteByte('=') - buf.WriteString(v[k]) - } - buf.WriteString(bizKey) - returnStr = buf.String() - case map[string]interface{}: - keys := make([]string, 0, len(v)) - for k := range v { - if k == "sign" { - continue - } - keys = append(keys, k) - } - sort.Strings(keys) - var buf bytes.Buffer - for _, k := range keys { - if v[k] == "" { - continue - } - if buf.Len() > 0 { - buf.WriteByte('&') - } - buf.WriteString(k) - buf.WriteByte('=') - switch vv := v[k].(type) { - case string: - buf.WriteString(vv) - case int: - buf.WriteString(strconv.FormatInt(int64(vv), 10)) - default: - panic("params type not supported") - } - } - buf.WriteString(bizKey) - returnStr = buf.String() + keys = append(keys, k) } + sort.Strings(keys) + var buf bytes.Buffer + for _, k := range keys { + if p[k] == "" { + continue + } + if buf.Len() > 0 { + buf.WriteByte('&') + } + buf.WriteString(k) + buf.WriteByte('=') + buf.WriteString(p[k]) + } + buf.WriteString(bizKey) + returnStr = buf.String() return }