This commit is contained in:
Sakurasan
2025-02-01 23:52:55 +08:00
parent 65d6d12972
commit bc223d6530
30 changed files with 2683 additions and 242 deletions

View File

@@ -3,6 +3,7 @@ package store
import (
"fmt"
"log"
"opencatd-open/team/consts"
"opencatd-open/team/model"
"os"
"strings"
@@ -10,13 +11,23 @@ import (
// "gocloud.dev/mysql"
// "gocloud.dev/postgres"
"github.com/glebarez/sqlite"
"github.com/google/wire"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
// "gorm.io/driver/sqlite"
"gorm.io/gorm"
)
var DB *gorm.DB
var DBType consts.DBType
var IsPostgres bool
var DBSet = wire.NewSet(
InitDB,
)
// InitDB 初始化数据库连接
func InitDB() (*gorm.DB, error) {
var db *gorm.DB
@@ -32,21 +43,35 @@ func InitDB() (*gorm.DB, error) {
// 解析DSN来确定数据库类型
if strings.HasPrefix(dsn, "postgres://") {
IsPostgres = true
DBType = consts.DBTypePostgreSQL
db, err = initPostgres(dsn)
} else if strings.HasPrefix(dsn, "mysql://") {
DBType = consts.DBTypeMySQL
db, err = initMySQL(dsn)
} else {
if dsn != "" {
return nil, fmt.Errorf("unsupported database type in DSN: %s", dsn)
}
}
if err != nil {
return nil, err
}
DB = db
if IsPostgres {
err = db.AutoMigrate(&model.User{}, &model.ApiKey_PG{}, &model.Token{}, &model.Session{}, &model.Usage{}, &model.DailyUsage{})
err = db.AutoMigrate(&model.User{}, &model.Token{}, &model.ApiKey_PG{}, &model.Usage{}, &model.DailyUsage{})
if err != nil {
return nil, err
}
} else {
err = db.AutoMigrate(&model.User{}, &model.Token{}, &model.ApiKey{}, &model.Usage{}, &model.DailyUsage{})
if err != nil {
return nil, err
}
}
return nil, fmt.Errorf("unsupported database type in DSN: %s", dsn)
return db, nil
}
// initSQLite 初始化 SQLite 数据库
@@ -55,6 +80,7 @@ func initSQLite() (*gorm.DB, error) {
if err != nil {
return nil, fmt.Errorf("failed to connect to SQLite: %v", err)
}
// db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
return db, nil
}

View File

@@ -42,7 +42,7 @@ func Handleinit(c *gin.Context) {
})
return
}
if user.ID == uint(1) {
if user.ID == 1 {
c.JSON(http.StatusForbidden, gin.H{
"error": "super user already exists, use cli to reset password",
})

View File

@@ -82,7 +82,7 @@ func HandleResetUserToken(c *gin.Context) {
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
if u.ID == uint(1) {
if u.ID == 1 {
rootToken = u.Token
}
c.JSON(http.StatusOK, u)

View File

@@ -183,7 +183,7 @@ func Cost(model string, promptCount, completionCount int) float64 {
cost = (0.00035/1000)*float64(prompt) + (0.00053/1000)*float64(completion)
case "gemini-2.0-flash-exp":
cost = (0.00035/1000)*float64(prompt) + (0.00053/1000)*float64(completion)
case "gemini-2.0-flash-thinking-exp-1219":
case "gemini-2.0-flash-thinking-exp-1219", "gemini-2.0-flash-thinking-exp-01-21":
cost = (0.00035/1000)*float64(prompt) + (0.00053/1000)*float64(completion)
case "learnlm-1.5-pro-experimental", " gemini-exp-1114", "gemini-exp-1121", "gemini-exp-1206":
cost = (0.00035/1000)*float64(prompt) + (0.00053/1000)*float64(completion)