diff --git a/opencat.go b/opencat.go index 43d1eba..ffe1fe0 100644 --- a/opencat.go +++ b/opencat.go @@ -71,6 +71,8 @@ func main() { // 获取当前用户信息 group.GET("/me", router.HandleMe) + group.GET("/me/usages", router.HandleMeUsage) + // 获取所有Key group.GET("/keys", router.HandleKeys) diff --git a/router/router.go b/router/router.go index d1427c7..688845f 100644 --- a/router/router.go +++ b/router/router.go @@ -203,6 +203,40 @@ func HandleMe(c *gin.Context) { c.JSON(http.StatusOK, resJSON) } +func HandleMeUsage(c *gin.Context) { + token := c.GetHeader("Authorization") + fromStr := c.Query("from") + toStr := c.Query("to") + getMonthStartAndEnd := func() (start, end string) { + loc, _ := time.LoadLocation("Local") + now := time.Now().In(loc) + + year, month, _ := now.Date() + + startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc) + endOfMonth := startOfMonth.AddDate(0, 1, 0) + + start = startOfMonth.Format("2006-01-02") + end = endOfMonth.Format("2006-01-02") + return + } + if fromStr == "" || toStr == "" { + fromStr, toStr = getMonthStartAndEnd() + } + user, err := store.GetUserByToken(token) + if err != nil { + c.AbortWithError(http.StatusForbidden, err) + return + } + usage, err := store.QueryUserUsage(to.String(user.ID), fromStr, toStr) + if err != nil { + c.AbortWithError(http.StatusForbidden, err) + return + } + + c.JSON(200, usage) +} + func HandleKeys(c *gin.Context) { keys, err := store.GetAllKeys() if err != nil { @@ -439,7 +473,12 @@ func HandleProy(c *gin.Context) { case "openai": fallthrough default: - req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) + if onekey.EndPoint != "" { + req, err = http.NewRequest(c.Request.Method, onekey.EndPoint+c.Request.RequestURI, &body) + } else { + req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) + } + req.Header = c.Request.Header req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key)) } diff --git a/store/usage.go b/store/usage.go index d0c3bc3..f5aa543 100644 --- a/store/usage.go +++ b/store/usage.go @@ -68,6 +68,21 @@ func QueryUsage(from, to string) ([]CalcUsage, error) { return results, nil } +func QueryUserUsage(userid, from, to string) (*CalcUsage, error) { + var results = new(CalcUsage) + err := usage.Model(&DailyUsage{}).Select(`user_id, + --SUM(prompt_units) AS prompt_units, + -- SUM(completion_units) AS completion_units, + SUM(total_unit) AS total_unit, + printf('%.6f', SUM(cost)) AS cost`). + Where("user_id = ? AND date >= ? AND date < ?", userid, from, to). + Find(&results).Error + if err != nil { + return nil, err + } + return results, nil +} + type Tokens struct { UserID int PromptCount int