package store

import (
	"fmt"
	"log"
	"os"
	"strings"

	"opencatd-open/internal/model"
	"opencatd-open/pkg/config"

	// "gocloud.dev/mysql"
	// "gocloud.dev/postgres"
	"github.com/glebarez/sqlite"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"

	// "gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

const (
	// sqliteDir is the directory that holds the default SQLite database file.
	sqliteDir = "db"
	// sqlitePath is the SQLite database used when no DSN is configured.
	sqlitePath = "./db/openteam.db"
)

// DB is the shared database handle. It is assigned by InitDB only after the
// connection is open and schema migration has succeeded.
var DB *gorm.DB

// var DBType consts.DBType

// IsPostgres reports whether the most recent InitDB call selected the
// PostgreSQL driver. InitDB uses it to migrate model.ApiKey_PG instead of
// model.ApiKey.
var IsPostgres bool

// GetDB returns the handle stored in DB (nil until InitDB succeeds).
func GetDB() *gorm.DB {
	return DB
}

// InitDB opens the database described by cfg.DSN, runs AutoMigrate for the
// application models, and stores the resulting handle in DB.
//
// The driver is selected from the DSN:
//   - ""                                   -> SQLite at ./db/openteam.db
//   - "postgres://..." / "postgresql://..." -> PostgreSQL; sets cfg.DB_Type = "postgres"
//   - "mysql://..."                        -> MySQL; sets cfg.DB_Type = "mysql"
//
// Any other DSN yields an error. cfg.DB_Type is not modified for SQLite.
// On error the previously stored DB handle is left unchanged.
func InitDB(cfg *config.Config) (*gorm.DB, error) {
	dsn := cfg.DSN

	var (
		db  *gorm.DB
		err error
	)

	// Reset so a repeated InitDB call never keeps a stale PostgreSQL flag.
	IsPostgres = false

	switch {
	case dsn == "":
		log.Println("No DSN provided, using SQLite as default")
		db, err = initSQLite()
	case strings.HasPrefix(dsn, "postgres://"), strings.HasPrefix(dsn, "postgresql://"):
		IsPostgres = true
		cfg.DB_Type = "postgres"
		db, err = initPostgres(dsn)
	case strings.HasPrefix(dsn, "mysql://"):
		cfg.DB_Type = "mysql"
		db, err = initMySQL(dsn)
	default:
		// Only the scheme is reported: the full DSN may embed a password.
		return nil, fmt.Errorf("unsupported database type in DSN (scheme %q)", dsnScheme(dsn))
	}
	if err != nil {
		return nil, err
	}

	if err := db.AutoMigrate(migrationModels(IsPostgres)...); err != nil {
		return nil, fmt.Errorf("migrating schema: %w", err)
	}

	DB = db
	return db, nil
}

// migrationModels returns the models passed to AutoMigrate, in migration
// order. PostgreSQL uses model.ApiKey_PG in place of model.ApiKey.
func migrationModels(isPostgres bool) []interface{} {
	var apiKey interface{} = &model.ApiKey{}
	if isPostgres {
		apiKey = &model.ApiKey_PG{}
	}
	return []interface{}{
		&model.User{},
		&model.Token{},
		apiKey,
		&model.Usage{},
		&model.DailyUsage{},
		&model.Passkey{},
	}
}

// dsnScheme returns the text before "://" in dsn, or a placeholder when the
// DSN has no URL scheme. It is used so error messages never echo credentials.
func dsnScheme(dsn string) string {
	scheme, _, found := strings.Cut(dsn, "://")
	if !found {
		return "<none>"
	}
	return scheme
}

// initSQLite opens the SQLite database at sqlitePath, first ensuring that its
// parent directory exists.
func initSQLite() (*gorm.DB, error) {
	// MkdirAll succeeds without doing anything when the directory already exists.
	if err := os.MkdirAll(sqliteDir, 0o755); err != nil {
		return nil, fmt.Errorf("creating SQLite directory %q: %w", sqliteDir, err)
	}

	db, err := gorm.Open(sqlite.Open(sqlitePath), &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("failed to connect to SQLite: %w", err)
	}
	// In-memory alternative: gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
	return db, nil
}

// initPostgres opens a PostgreSQL connection using dsn as given.
func initPostgres(dsn string) (*gorm.DB, error) {
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err)
	}
	return db, nil
}

// initMySQL opens a MySQL connection. The "mysql://" prefix is removed first
// because the MySQL driver's DSN format does not include it.
func initMySQL(dsn string) (*gorm.DB, error) {
	dsn = strings.TrimPrefix(dsn, "mysql://")
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("failed to connect to MySQL: %w", err)
	}
	return db, nil
}