diff --git a/pay/pay.go b/pay/pay.go index 0b2c601..dd9fafd 100644 --- a/pay/pay.go +++ b/pay/pay.go @@ -1,9 +1,11 @@ package pay import ( + "bytes" "encoding/xml" "errors" - "fmt" + "sort" + "strconv" "github.com/silenceper/wechat/context" "github.com/silenceper/wechat/util" @@ -24,6 +26,7 @@ type Params struct { Body string OutTradeNo string OpenID string + TradeType string } // Config 是传出用于 jsdk 用的参数 @@ -86,9 +89,20 @@ func NewPay(ctx *context.Context) *Pay { // PrePayOrder return data for invoke wechat payment func (pcf *Pay) PrePayOrder(p *Params) (payOrder PreOrder, err error) { nonceStr := util.RandomStr(32) - tradeType := "JSAPI" - template := "appid=%s&body=%s&mch_id=%s&nonce_str=%s¬ify_url=%s&openid=%s&out_trade_no=%s&spbill_create_ip=%s&total_fee=%s&trade_type=%s&key=%s" - str := fmt.Sprintf(template, pcf.AppID, p.Body, pcf.PayMchID, nonceStr, pcf.PayNotifyURL, p.OpenID, p.OutTradeNo, p.CreateIP, p.TotalFee, tradeType, pcf.PayKey) + param := make(map[string]interface{}) + param["appid"] = pcf.AppID + param["body"] = p.Body + param["mch_id"] = pcf.PayMchID + param["nonce_str"] =nonceStr + param["notify_url"] =pcf.PayNotifyURL + param["out_trade_no"] =p.OutTradeNo + param["spbill_create_ip"] =p.CreateIP + param["total_fee"] =p.TotalFee + param["trade_type"] =p.TradeType + param["openid"] = p.OpenID + + bizKey := "&key="+pcf.PayKey + str := orderParam(param,bizKey) sign := util.MD5Sum(str) request := payRequest{ AppID: pcf.AppID, @@ -100,7 +114,7 @@ func (pcf *Pay) PrePayOrder(p *Params) (payOrder PreOrder, err error) { TotalFee: p.TotalFee, SpbillCreateIP: p.CreateIP, NotifyURL: pcf.PayNotifyURL, - TradeType: tradeType, + TradeType: p.TradeType, OpenID: p.OpenID, } rawRet, err := util.PostXML(payGateway, request) @@ -136,3 +150,63 @@ func (pcf *Pay) PrePayID(p *Params) (prePayID string, err error) { prePayID = order.PrePayID return } + +// order params +func orderParam(source interface{}, bizKey string) (returnStr string) { + switch v := source.(type) { + case map[string]string: + keys := make([]string, 0, len(v)) + for k := range v { + if k == "sign" { + continue + } + keys = append(keys, k) + } + sort.Strings(keys) + var buf bytes.Buffer + for _, k := range keys { + if v[k] == "" { + continue + } + if buf.Len() > 0 { + buf.WriteByte('&') + } + buf.WriteString(k) + buf.WriteByte('=') + buf.WriteString(v[k]) + } + buf.WriteString(bizKey) + returnStr = buf.String() + case map[string]interface{}: + keys := make([]string, 0, len(v)) + for k := range v { + if k == "sign" { + continue + } + keys = append(keys, k) + } + sort.Strings(keys) + var buf bytes.Buffer + for _, k := range keys { + if v[k] == "" { + continue + } + if buf.Len() > 0 { + buf.WriteByte('&') + } + buf.WriteString(k) + buf.WriteByte('=') + switch vv := v[k].(type) { + case string: + buf.WriteString(vv) + case int: + buf.WriteString(strconv.FormatInt(int64(vv), 10)) + default: + panic("params type not supported") + } + } + buf.WriteString(bizKey) + returnStr = buf.String() + } + return +}