198 lines
5.7 KiB
Go
198 lines
5.7 KiB
Go
package config
|
||
|
||
import (
|
||
"fmt"
|
||
"os"
|
||
"strings"
|
||
|
||
"github.com/joho/godotenv"
|
||
"github.com/spf13/viper"
|
||
)
|
||
|
||
// ServerConfig 服务器配置
|
||
type ServerConfig struct {
|
||
Port string `mapstructure:"port"`
|
||
Environment string `mapstructure:"environment"`
|
||
}
|
||
|
||
// DatabaseConfig 数据库配置
|
||
type DatabaseConfig struct {
|
||
Driver string `mapstructure:"driver"`
|
||
SQLite SQLiteConfig `mapstructure:"sqlite"`
|
||
Postgres PostgresConfig `mapstructure:"postgres"`
|
||
}
|
||
|
||
// SQLiteConfig SQLite 配置
|
||
type SQLiteConfig struct {
|
||
Path string `mapstructure:"path"`
|
||
}
|
||
|
||
// PostgresConfig PostgreSQL 配置
|
||
type PostgresConfig struct {
|
||
Host string `mapstructure:"host"`
|
||
Port string `mapstructure:"port"`
|
||
User string `mapstructure:"user"`
|
||
Password string `mapstructure:"password"`
|
||
DBName string `mapstructure:"dbname"`
|
||
SSLMode string `mapstructure:"sslmode"`
|
||
}
|
||
|
||
// JWTConfig JWT 配置
|
||
type JWTConfig struct {
|
||
Secret string `mapstructure:"secret"`
|
||
Expire int `mapstructure:"expire"`
|
||
}
|
||
|
||
// AppConfig 应用程序配置
|
||
type AppConfig struct {
|
||
Server ServerConfig `mapstructure:"server"`
|
||
Database DatabaseConfig `mapstructure:"database"`
|
||
JWT JWTConfig `mapstructure:"jwt"`
|
||
}
|
||
|
||
// 全局配置变量
|
||
var appConfig AppConfig
|
||
|
||
// LoadConfig 加载配置
|
||
// 优先级:环境变量 > .env 文件 > config.yaml > 默认值
|
||
func LoadConfig() {
|
||
// 设置默认值(最先设置,优先级最低)
|
||
setDefaults()
|
||
|
||
// 1. 加载 config.yaml 配置文件
|
||
viper.SetConfigName("config")
|
||
viper.SetConfigType("yaml")
|
||
viper.AddConfigPath(".")
|
||
viper.AddConfigPath("./config")
|
||
if err := viper.ReadInConfig(); err != nil {
|
||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||
fmt.Println("提示: 未找到 config.yaml,使用默认配置")
|
||
} else {
|
||
fmt.Printf("错误: 读取配置文件失败: %v\n", err)
|
||
}
|
||
} else {
|
||
fmt.Printf("已加载配置文件: %s\n", viper.ConfigFileUsed())
|
||
}
|
||
|
||
// 2. 加载 .env 文件(如果存在)
|
||
// .env 文件中的变量会被加载到环境变量中
|
||
if err := godotenv.Load(); err != nil {
|
||
// .env 文件是可选的,不存在不报错
|
||
if os.IsNotExist(err) {
|
||
fmt.Println("提示: 未找到 .env 文件,将使用环境变量或默认值")
|
||
} else {
|
||
fmt.Printf("警告: 加载 .env 文件失败: %v\n", err)
|
||
}
|
||
} else {
|
||
fmt.Println("已加载 .env 文件")
|
||
}
|
||
|
||
// 3. 设置环境变量支持
|
||
// 设置环境变量前缀为 APP_
|
||
viper.SetEnvPrefix("APP")
|
||
// 设置环境变量键名替换规则(将 _ 替换为 .)
|
||
// 例如:APP_SERVER_PORT 会映射到 server.port
|
||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||
// 启用自动环境变量读取
|
||
viper.AutomaticEnv()
|
||
|
||
// 4. 绑定特定的环境变量(确保嵌套结构体能正确映射)
|
||
bindEnvVariables()
|
||
|
||
// 解析配置到结构体
|
||
if err := viper.Unmarshal(&appConfig); err != nil {
|
||
fmt.Printf("错误: 配置解析失败: %v\n", err)
|
||
}
|
||
|
||
// 验证配置
|
||
validateConfig()
|
||
|
||
// 打印配置信息(开发环境)
|
||
printConfig()
|
||
}
|
||
|
||
// setDefaults 设置默认配置值
|
||
func setDefaults() {
|
||
// Server 默认值
|
||
viper.SetDefault("server.port", "3000")
|
||
viper.SetDefault("server.environment", "development")
|
||
|
||
// Database 默认值
|
||
viper.SetDefault("database.driver", "sqlite")
|
||
viper.SetDefault("database.sqlite.path", "./data/database.sqlite")
|
||
viper.SetDefault("database.postgres.host", "localhost")
|
||
viper.SetDefault("database.postgres.port", "5432")
|
||
viper.SetDefault("database.postgres.user", "trace")
|
||
viper.SetDefault("database.postgres.password", "trace123")
|
||
viper.SetDefault("database.postgres.dbname", "trace")
|
||
viper.SetDefault("database.postgres.sslmode", "disable")
|
||
|
||
// JWT 默认值
|
||
viper.SetDefault("jwt.secret", "your-secret-key-here-change-in-production")
|
||
viper.SetDefault("jwt.expire", 7200)
|
||
}
|
||
|
||
// bindEnvVariables 绑定环境变量
|
||
// 支持以下格式:
|
||
// - APP_SERVER_PORT=8080 (映射到 server.port)
|
||
// - APP_DATABASE_DRIVER=postgres (映射到 database.driver)
|
||
// - APP_DATABASE_POSTGRES_HOST=db.example.com (映射到 database.postgres.host)
|
||
func bindEnvVariables() {
|
||
// 服务器配置
|
||
viper.BindEnv("server.port")
|
||
viper.BindEnv("server.environment")
|
||
|
||
// 数据库配置
|
||
viper.BindEnv("database.driver")
|
||
viper.BindEnv("database.sqlite.path")
|
||
viper.BindEnv("database.postgres.host")
|
||
viper.BindEnv("database.postgres.port")
|
||
viper.BindEnv("database.postgres.user")
|
||
viper.BindEnv("database.postgres.password")
|
||
viper.BindEnv("database.postgres.dbname")
|
||
viper.BindEnv("database.postgres.sslmode")
|
||
|
||
// JWT 配置
|
||
viper.BindEnv("jwt.secret")
|
||
viper.BindEnv("jwt.expire")
|
||
}
|
||
|
||
// validateConfig 验证配置
|
||
func validateConfig() {
|
||
// 验证 JWT 密钥
|
||
if appConfig.JWT.Secret == "your-secret-key-here-change-in-production" {
|
||
if appConfig.Server.Environment == "production" {
|
||
fmt.Println("警告: 生产环境使用了默认 JWT 密钥,请设置 APP_JWT_SECRET 环境变量")
|
||
} else {
|
||
fmt.Println("提示: 使用默认 JWT 密钥(仅适用于开发环境)")
|
||
}
|
||
}
|
||
|
||
// 验证端口
|
||
if appConfig.Server.Port == "" {
|
||
appConfig.Server.Port = "3000"
|
||
}
|
||
}
|
||
|
||
// printConfig 打印配置信息(仅开发环境)
|
||
func printConfig() {
|
||
fmt.Printf("配置加载完成:\n")
|
||
fmt.Printf(" 环境: %s\n", appConfig.Server.Environment)
|
||
fmt.Printf(" 端口: %s\n", appConfig.Server.Port)
|
||
fmt.Printf(" 数据库驱动: %s\n", appConfig.Database.Driver)
|
||
|
||
if appConfig.Database.Driver == "sqlite" {
|
||
fmt.Printf(" SQLite 路径: %s\n", appConfig.Database.SQLite.Path)
|
||
} else if appConfig.Database.Driver == "postgres" {
|
||
fmt.Printf(" PostgreSQL: %s:%s/%s\n",
|
||
appConfig.Database.Postgres.Host,
|
||
appConfig.Database.Postgres.Port,
|
||
appConfig.Database.Postgres.DBName)
|
||
}
|
||
}
|
||
|
||
// GetAppConfig 获取应用程序配置
|
||
func GetAppConfig() *AppConfig {
|
||
return &appConfig
|
||
}
|