Initial commit
This commit is contained in:
119
services/auth_service.go
Normal file
119
services/auth_service.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"git.beifan.cn/trace-system/backend-go/config"
|
||||
"git.beifan.cn/trace-system/backend-go/database"
|
||||
"git.beifan.cn/trace-system/backend-go/models"
|
||||
)
|
||||
|
||||
// AuthService 认证服务
|
||||
type AuthService struct{}
|
||||
|
||||
// ValidateUser 验证用户身份
|
||||
func (s *AuthService) ValidateUser(username string, password string) (*models.User, error) {
|
||||
var user models.User
|
||||
result := database.DB.Where("username = ?", username).First(&user)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("验证用户失败: %w", errors.New("用户名或密码错误"))
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("密码验证失败: %w", errors.New("用户名或密码错误"))
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GenerateToken 生成 JWT 令牌
|
||||
func (s *AuthService) GenerateToken(user *models.User) (string, error) {
|
||||
cfg := config.GetAppConfig()
|
||||
claims := jwt.MapClaims{
|
||||
"userId": user.ID,
|
||||
"username": user.Username,
|
||||
"role": user.Role,
|
||||
"exp": time.Now().Add(time.Second * time.Duration(cfg.JWT.Expire)).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(cfg.JWT.Secret))
|
||||
}
|
||||
|
||||
// GetProfile 获取用户信息
|
||||
func (s *AuthService) GetProfile(userId uint) (*models.UserDTO, error) {
|
||||
var user models.User
|
||||
result := database.DB.First(&user, userId)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询用户失败: %w", errors.New("用户不存在"))
|
||||
}
|
||||
|
||||
return &models.UserDTO{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
CreatedAt: user.CreatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (s *AuthService) ChangePassword(userId uint, currentPassword string, newPassword string) error {
|
||||
var user models.User
|
||||
result := database.DB.First(&user, userId)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("查询用户失败: %w", errors.New("用户不存在"))
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(currentPassword))
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码验证失败: %w", errors.New("当前密码错误"))
|
||||
}
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码加密失败: %w", err)
|
||||
}
|
||||
|
||||
user.Password = string(hashedPassword)
|
||||
result = database.DB.Save(&user)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("保存用户失败: %w", errors.New("密码修改失败"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateProfile 更新用户信息
|
||||
func (s *AuthService) UpdateProfile(userId uint, name string, email string) (*models.UserDTO, error) {
|
||||
var user models.User
|
||||
result := database.DB.First(&user, userId)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询用户失败: %w", errors.New("用户不存在"))
|
||||
}
|
||||
|
||||
user.Name = name
|
||||
user.Email = email
|
||||
|
||||
result = database.DB.Save(&user)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("保存用户失败: %w", errors.New("个人信息更新失败"))
|
||||
}
|
||||
|
||||
return &models.UserDTO{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
CreatedAt: user.CreatedAt,
|
||||
}, nil
|
||||
}
|
||||
112
services/companies_service.go
Normal file
112
services/companies_service.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"git.beifan.cn/trace-system/backend-go/database"
|
||||
"git.beifan.cn/trace-system/backend-go/models"
|
||||
)
|
||||
|
||||
// CompaniesService 企业管理服务
|
||||
type CompaniesService struct{}
|
||||
|
||||
// FindAll 获取所有企业列表
|
||||
func (s *CompaniesService) FindAll(page int, limit int, search string) ([]models.Company, int, int, error) {
|
||||
var companies []models.Company
|
||||
var total int64
|
||||
|
||||
offset := (page - 1) * limit
|
||||
db := database.DB
|
||||
|
||||
// 搜索条件
|
||||
if search != "" {
|
||||
db = db.Where("company_name LIKE ?", "%"+search+"%")
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
db.Count(&total)
|
||||
|
||||
// 分页查询
|
||||
result := db.Order("created_at DESC").Offset(offset).Limit(limit).Find(&companies)
|
||||
if result.Error != nil {
|
||||
return nil, 0, 0, errors.New("查询企业列表失败")
|
||||
}
|
||||
|
||||
totalPages := (int(total) + limit - 1) / limit
|
||||
|
||||
return companies, int(total), totalPages, nil
|
||||
}
|
||||
|
||||
// Create 创建企业
|
||||
func (s *CompaniesService) Create(companyName string) (*models.Company, error) {
|
||||
// 检查企业是否已存在
|
||||
var existingCompany models.Company
|
||||
result := database.DB.Where("company_name = ?", companyName).First(&existingCompany)
|
||||
if result.Error == nil {
|
||||
return nil, errors.New("企业名称已存在")
|
||||
}
|
||||
|
||||
company := models.Company{
|
||||
CompanyName: companyName,
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
result = database.DB.Create(&company)
|
||||
if result.Error != nil {
|
||||
return nil, errors.New("创建企业失败")
|
||||
}
|
||||
|
||||
return &company, nil
|
||||
}
|
||||
|
||||
// Update 更新企业信息
|
||||
func (s *CompaniesService) Update(companyName string, newCompanyName string, isActive bool) (*models.Company, error) {
|
||||
var company models.Company
|
||||
result := database.DB.Where("company_name = ?", companyName).First(&company)
|
||||
if result.Error != nil {
|
||||
return nil, errors.New("企业不存在")
|
||||
}
|
||||
|
||||
// 如果企业名称已变更,检查新名称是否已存在
|
||||
if newCompanyName != companyName {
|
||||
var existingCompany models.Company
|
||||
checkResult := database.DB.Where("company_name = ?", newCompanyName).First(&existingCompany)
|
||||
if checkResult.Error == nil {
|
||||
return nil, errors.New("企业名称已存在")
|
||||
}
|
||||
|
||||
company.CompanyName = newCompanyName
|
||||
}
|
||||
|
||||
company.IsActive = isActive
|
||||
|
||||
result = database.DB.Save(&company)
|
||||
if result.Error != nil {
|
||||
return nil, errors.New("更新企业信息失败")
|
||||
}
|
||||
|
||||
return &company, nil
|
||||
}
|
||||
|
||||
// Delete 删除企业
|
||||
func (s *CompaniesService) Delete(companyName string) error {
|
||||
var company models.Company
|
||||
result := database.DB.Where("company_name = ?", companyName).First(&company)
|
||||
if result.Error != nil {
|
||||
return errors.New("企业不存在")
|
||||
}
|
||||
|
||||
// 检查企业是否有关联的序列号
|
||||
var serialCount int64
|
||||
database.DB.Model(&models.Serial{}).Where("company_name = ?", companyName).Count(&serialCount)
|
||||
if serialCount > 0 {
|
||||
return errors.New("企业下还有序列号,无法删除")
|
||||
}
|
||||
|
||||
result = database.DB.Delete(&company)
|
||||
if result.Error != nil {
|
||||
return errors.New("删除企业失败")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
BIN
services/data/database.sqlite
Normal file
BIN
services/data/database.sqlite
Normal file
Binary file not shown.
269
services/serials_service.go
Normal file
269
services/serials_service.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
qr "github.com/yeqown/go-qrcode/v2"
|
||||
"github.com/yeqown/go-qrcode/writer/standard"
|
||||
|
||||
"git.beifan.cn/trace-system/backend-go/database"
|
||||
"git.beifan.cn/trace-system/backend-go/models"
|
||||
)
|
||||
|
||||
// SerialsService 序列号服务
|
||||
type SerialsService struct{}
|
||||
|
||||
// Generate 生成序列号
|
||||
func (s *SerialsService) Generate(
|
||||
companyName string,
|
||||
quantity int,
|
||||
validDays int,
|
||||
userId uint,
|
||||
prefix ...string,
|
||||
) ([]models.Serial, error) {
|
||||
var serials []models.Serial
|
||||
validUntil := time.Now().AddDate(0, 0, validDays)
|
||||
|
||||
// 检查公司是否存在,不存在则创建
|
||||
var company models.Company
|
||||
result := database.DB.Where("company_name = ?", companyName).First(&company)
|
||||
if result.Error != nil {
|
||||
company = models.Company{
|
||||
CompanyName: companyName,
|
||||
IsActive: true,
|
||||
}
|
||||
result = database.DB.Create(&company)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("创建公司失败: %w", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
// 生成序列号前缀
|
||||
var serialPrefix string
|
||||
if len(prefix) > 0 && prefix[0] != "" {
|
||||
serialPrefix = strings.ToUpper(strings.ReplaceAll(prefix[0], "[^A-Z0-9]", ""))
|
||||
} else {
|
||||
serialPrefix = fmt.Sprintf("BF%d", time.Now().Year()%100)
|
||||
}
|
||||
|
||||
// 预生成所有序列号
|
||||
serialNumbers := make(map[string]bool)
|
||||
for i := 0; i < quantity; {
|
||||
randomBytes := make([]byte, 3)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return nil, fmt.Errorf("生成随机数失败: %w", err)
|
||||
}
|
||||
randomPart := hex.EncodeToString(randomBytes)[:6]
|
||||
serialNumber := fmt.Sprintf("%s%s", serialPrefix, randomPart)
|
||||
|
||||
if serialNumbers[serialNumber] {
|
||||
continue
|
||||
}
|
||||
|
||||
var existingSerial models.Serial
|
||||
checkResult := database.DB.Where("serial_number = ?", serialNumber).First(&existingSerial)
|
||||
if checkResult.Error != nil {
|
||||
serialNumbers[serialNumber] = true
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
for serialNumber := range serialNumbers {
|
||||
serial := models.Serial{
|
||||
SerialNumber: strings.ToUpper(serialNumber),
|
||||
CompanyName: companyName,
|
||||
ValidUntil: &validUntil,
|
||||
CreatedBy: &userId,
|
||||
IsActive: true,
|
||||
}
|
||||
serials = append(serials, serial)
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
result = database.DB.Create(&serials)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("保存序列号失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return serials, nil
|
||||
}
|
||||
|
||||
// GenerateQRCode 生成二维码
|
||||
func (s *SerialsService) GenerateQRCode(
|
||||
serialNumber string,
|
||||
baseUrl string,
|
||||
requestHost string,
|
||||
protocol string,
|
||||
) (string, string, error) {
|
||||
var serial models.Serial
|
||||
result := database.DB.Preload("User").Where("serial_number = ?", strings.ToUpper(serialNumber)).First(&serial)
|
||||
if result.Error != nil {
|
||||
return "", "", fmt.Errorf("查询序列号失败: %w", errors.New("序列号不存在"))
|
||||
}
|
||||
|
||||
if !serial.IsActive {
|
||||
return "", "", fmt.Errorf("序列号状态无效: %w", errors.New("序列号已被禁用"))
|
||||
}
|
||||
|
||||
if serial.ValidUntil != nil && serial.ValidUntil.Before(time.Now()) {
|
||||
return "", "", fmt.Errorf("序列号已过期")
|
||||
}
|
||||
|
||||
// 确定查询 URL
|
||||
if baseUrl == "" {
|
||||
baseUrl = fmt.Sprintf("%s://%s/query.html", protocol, requestHost)
|
||||
}
|
||||
|
||||
var queryUrl string
|
||||
if strings.Contains(baseUrl, "?") {
|
||||
queryUrl = fmt.Sprintf("%s&serial=%s", baseUrl, serial.SerialNumber)
|
||||
} else {
|
||||
queryUrl = fmt.Sprintf("%s?serial=%s", baseUrl, serial.SerialNumber)
|
||||
}
|
||||
|
||||
// 生成二维码到临时文件
|
||||
filePath := fmt.Sprintf("temp_qr_%s.png", uuid.New().String())
|
||||
writer, err := standard.New(filePath, standard.WithQRWidth(6))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("二维码写入器创建失败: %w", err)
|
||||
}
|
||||
|
||||
qrc, errCode := qr.New(queryUrl)
|
||||
if errCode != nil {
|
||||
os.Remove(filePath)
|
||||
return "", "", fmt.Errorf("二维码创建失败: %w", errCode)
|
||||
}
|
||||
|
||||
if errSave := qrc.Save(writer); errSave != nil {
|
||||
os.Remove(filePath)
|
||||
return "", "", fmt.Errorf("二维码保存失败: %w", errSave)
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
fileContent, errRead := os.ReadFile(filePath)
|
||||
if errRead != nil {
|
||||
os.Remove(filePath)
|
||||
return "", "", fmt.Errorf("二维码文件读取失败: %w", errRead)
|
||||
}
|
||||
|
||||
// 删除临时文件
|
||||
os.Remove(filePath)
|
||||
|
||||
// 转换为 base64
|
||||
qrCodeBase64 := fmt.Sprintf("data:image/png;base64,%s", base64.StdEncoding.EncodeToString(fileContent))
|
||||
return qrCodeBase64, queryUrl, nil
|
||||
}
|
||||
|
||||
// Query 查询序列号信息
|
||||
func (s *SerialsService) Query(serialNumber string) (*models.Serial, error) {
|
||||
var serial models.Serial
|
||||
result := database.DB.Preload("User").Where("serial_number = ?", strings.ToUpper(serialNumber)).First(&serial)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询序列号失败: %w", errors.New("序列号不存在"))
|
||||
}
|
||||
|
||||
if serial.ValidUntil != nil && serial.ValidUntil.Before(time.Now()) {
|
||||
return nil, fmt.Errorf("序列号已过期")
|
||||
}
|
||||
|
||||
return &serial, nil
|
||||
}
|
||||
|
||||
// FindAll 获取序列号列表
|
||||
func (s *SerialsService) FindAll(page int, limit int, search string) ([]models.Serial, int, int, error) {
|
||||
var serials []models.Serial
|
||||
var total int64
|
||||
|
||||
offset := (page - 1) * limit
|
||||
db := database.DB.Preload("User")
|
||||
|
||||
// 搜索条件
|
||||
if search != "" {
|
||||
db = db.Where("serial_number LIKE ? OR company_name LIKE ?", "%"+search+"%", "%"+search+"%")
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
countQuery := db.Model(&models.Serial{})
|
||||
if search != "" {
|
||||
countQuery = countQuery.Where("serial_number LIKE ? OR company_name LIKE ?", "%"+search+"%", "%"+search+"%")
|
||||
}
|
||||
countQuery.Count(&total)
|
||||
|
||||
// 分页查询
|
||||
result := db.Model(&models.Serial{}).Order("created_at DESC").Offset(offset).Limit(limit).Find(&serials)
|
||||
if result.Error != nil {
|
||||
return nil, 0, 0, fmt.Errorf("查询序列号列表失败: %w", result.Error)
|
||||
}
|
||||
|
||||
totalPages := (int(total) + limit - 1) / limit
|
||||
|
||||
return serials, int(total), totalPages, nil
|
||||
}
|
||||
|
||||
// Update 更新序列号信息
|
||||
func (s *SerialsService) Update(serialNumber string, updateData models.UpdateSerialDTO) (*models.Serial, error) {
|
||||
var serial models.Serial
|
||||
result := database.DB.Where("serial_number = ?", strings.ToUpper(serialNumber)).First(&serial)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询序列号失败: %w", errors.New("序列号不存在"))
|
||||
}
|
||||
|
||||
if updateData.CompanyName != "" {
|
||||
// 检查公司是否存在
|
||||
var company models.Company
|
||||
companyResult := database.DB.Where("company_name = ?", updateData.CompanyName).First(&company)
|
||||
if companyResult.Error != nil {
|
||||
company = models.Company{
|
||||
CompanyName: updateData.CompanyName,
|
||||
IsActive: true,
|
||||
}
|
||||
database.DB.Create(&company)
|
||||
}
|
||||
|
||||
serial.CompanyName = updateData.CompanyName
|
||||
}
|
||||
|
||||
if updateData.ValidUntil != nil {
|
||||
serial.ValidUntil = updateData.ValidUntil
|
||||
}
|
||||
|
||||
if updateData.IsActive != nil {
|
||||
serial.IsActive = *updateData.IsActive
|
||||
}
|
||||
|
||||
result = database.DB.Save(&serial)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("更新序列号失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return &serial, nil
|
||||
}
|
||||
|
||||
// Revoke 吊销序列号
|
||||
func (s *SerialsService) Revoke(serialNumber string) error {
|
||||
var serial models.Serial
|
||||
result := database.DB.Where("serial_number = ?", strings.ToUpper(serialNumber)).First(&serial)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("查询序列号失败: %w", errors.New("序列号不存在"))
|
||||
}
|
||||
|
||||
if !serial.IsActive {
|
||||
return fmt.Errorf("序列号状态无效: %w", errors.New("序列号已被吊销"))
|
||||
}
|
||||
|
||||
serial.IsActive = false
|
||||
result = database.DB.Save(&serial)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("吊销序列号失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
413
services/services_test.go
Normal file
413
services/services_test.go
Normal file
@@ -0,0 +1,413 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"git.beifan.cn/trace-system/backend-go/config"
|
||||
"git.beifan.cn/trace-system/backend-go/database"
|
||||
"git.beifan.cn/trace-system/backend-go/logger"
|
||||
"git.beifan.cn/trace-system/backend-go/models"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
config.LoadConfig()
|
||||
|
||||
if err := logger.InitializeLogger("test"); err != nil {
|
||||
fmt.Printf("日志系统初始化失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer logger.Sync()
|
||||
|
||||
database.InitDB()
|
||||
database.AutoMigrate()
|
||||
|
||||
database.DB.Unscoped().Where("1 = 1").Delete(&models.User{})
|
||||
database.DB.Unscoped().Where("1 = 1").Delete(&models.Company{})
|
||||
database.DB.Unscoped().Where("1 = 1").Delete(&models.Serial{})
|
||||
|
||||
exitCode := m.Run()
|
||||
|
||||
database.DB.Unscoped().Where("1 = 1").Delete(&models.User{})
|
||||
database.DB.Unscoped().Where("1 = 1").Delete(&models.Company{})
|
||||
database.DB.Unscoped().Where("1 = 1").Delete(&models.Serial{})
|
||||
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func TestAuthService_ValidateUser_Success(t *testing.T) {
|
||||
var user models.User
|
||||
password, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user = models.User{
|
||||
Username: "testuser",
|
||||
Password: string(password),
|
||||
Name: "测试用户",
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
}
|
||||
database.DB.Create(&user)
|
||||
|
||||
authService := AuthService{}
|
||||
result, err := authService.ValidateUser("testuser", "password123")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, "testuser", result.Username)
|
||||
|
||||
database.DB.Unscoped().Delete(&user)
|
||||
}
|
||||
|
||||
func TestAuthService_ValidateUser_WrongPassword(t *testing.T) {
|
||||
var user models.User
|
||||
password, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user = models.User{
|
||||
Username: "testuser2",
|
||||
Password: string(password),
|
||||
Name: "测试用户2",
|
||||
Email: "test2@example.com",
|
||||
Role: "user",
|
||||
}
|
||||
database.DB.Create(&user)
|
||||
|
||||
authService := AuthService{}
|
||||
_, err := authService.ValidateUser("testuser2", "wrongpassword")
|
||||
|
||||
assert.Error(t, err)
|
||||
|
||||
database.DB.Unscoped().Delete(&user)
|
||||
}
|
||||
|
||||
func TestAuthService_ValidateUser_UserNotFound(t *testing.T) {
|
||||
authService := AuthService{}
|
||||
_, err := authService.ValidateUser("nonexistent", "password")
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthService_GenerateToken_Success(t *testing.T) {
|
||||
user := &models.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Role: "user",
|
||||
}
|
||||
|
||||
authService := AuthService{}
|
||||
token, err := authService.GenerateToken(user)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
assert.NotEmpty(t, "Bearer "+token)
|
||||
}
|
||||
|
||||
func TestAuthService_GetProfile_Success(t *testing.T) {
|
||||
var user models.User
|
||||
password, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user = models.User{
|
||||
Username: "testuser3",
|
||||
Password: string(password),
|
||||
Name: "测试用户3",
|
||||
Email: "test3@example.com",
|
||||
Role: "user",
|
||||
}
|
||||
database.DB.Create(&user)
|
||||
|
||||
authService := AuthService{}
|
||||
profile, err := authService.GetProfile(user.ID)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, profile)
|
||||
assert.Equal(t, "testuser3", profile.Username)
|
||||
assert.Equal(t, "测试用户3", profile.Name)
|
||||
|
||||
database.DB.Unscoped().Delete(&user)
|
||||
}
|
||||
|
||||
func TestAuthService_ChangePassword_Success(t *testing.T) {
|
||||
var user models.User
|
||||
password, _ := bcrypt.GenerateFromPassword([]byte("oldpassword"), bcrypt.DefaultCost)
|
||||
user = models.User{
|
||||
Username: "testuser4",
|
||||
Password: string(password),
|
||||
Name: "测试用户4",
|
||||
Email: "test4@example.com",
|
||||
Role: "user",
|
||||
}
|
||||
database.DB.Create(&user)
|
||||
|
||||
authService := AuthService{}
|
||||
err := authService.ChangePassword(user.ID, "oldpassword", "newpassword")
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
var updatedUser models.User
|
||||
database.DB.First(&updatedUser, user.ID)
|
||||
err = bcrypt.CompareHashAndPassword([]byte(updatedUser.Password), []byte("newpassword"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
database.DB.Unscoped().Delete(&user)
|
||||
}
|
||||
|
||||
func TestAuthService_ChangePassword_WrongCurrentPassword(t *testing.T) {
|
||||
var user models.User
|
||||
password, _ := bcrypt.GenerateFromPassword([]byte("oldpassword"), bcrypt.DefaultCost)
|
||||
user = models.User{
|
||||
Username: "testuser5",
|
||||
Password: string(password),
|
||||
Name: "测试用户5",
|
||||
Email: "test5@example.com",
|
||||
Role: "user",
|
||||
}
|
||||
database.DB.Create(&user)
|
||||
|
||||
authService := AuthService{}
|
||||
err := authService.ChangePassword(user.ID, "wrongpassword", "newpassword")
|
||||
|
||||
assert.Error(t, err)
|
||||
|
||||
database.DB.Unscoped().Delete(&user)
|
||||
}
|
||||
|
||||
func TestAuthService_UpdateProfile_Success(t *testing.T) {
|
||||
var user models.User
|
||||
password, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user = models.User{
|
||||
Username: "testuser6",
|
||||
Password: string(password),
|
||||
Name: "测试用户6",
|
||||
Email: "test6@example.com",
|
||||
Role: "user",
|
||||
}
|
||||
database.DB.Create(&user)
|
||||
|
||||
authService := AuthService{}
|
||||
profile, err := authService.UpdateProfile(user.ID, "新名称", "newemail@example.com")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, profile)
|
||||
assert.Equal(t, "新名称", profile.Name)
|
||||
assert.Equal(t, "newemail@example.com", profile.Email)
|
||||
|
||||
database.DB.Unscoped().Delete(&user)
|
||||
}
|
||||
|
||||
func TestSerialsService_Generate_Success(t *testing.T) {
|
||||
var user models.User
|
||||
password, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user = models.User{
|
||||
Username: "adminuser",
|
||||
Password: string(password),
|
||||
Name: "管理员",
|
||||
Email: "admin@example.com",
|
||||
Role: "admin",
|
||||
}
|
||||
database.DB.Create(&user)
|
||||
|
||||
serialService := SerialsService{}
|
||||
serials, err := serialService.Generate("TestCompany", 5, 30, user.ID)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, serials, 5)
|
||||
assert.Equal(t, "TestCompany", serials[0].CompanyName)
|
||||
assert.True(t, serials[0].IsActive)
|
||||
|
||||
for _, serial := range serials {
|
||||
database.DB.Unscoped().Delete(&serial)
|
||||
}
|
||||
database.DB.Unscoped().Where("company_name = ?", "TestCompany").Delete(&models.Company{})
|
||||
database.DB.Unscoped().Delete(&user)
|
||||
}
|
||||
|
||||
func TestSerialsService_Generate_WithPrefix(t *testing.T) {
|
||||
var user models.User
|
||||
password, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user = models.User{
|
||||
Username: "adminuser2",
|
||||
Password: string(password),
|
||||
Name: "管理员2",
|
||||
Email: "admin2@example.com",
|
||||
Role: "admin",
|
||||
}
|
||||
database.DB.Create(&user)
|
||||
|
||||
serialService := SerialsService{}
|
||||
serials, err := serialService.Generate("TestCompany2", 3, 30, user.ID, "TEST")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, serials, 3)
|
||||
assert.True(t, len(serials[0].SerialNumber) > 0)
|
||||
assert.Contains(t, serials[0].SerialNumber, "TEST")
|
||||
|
||||
for _, serial := range serials {
|
||||
database.DB.Unscoped().Delete(&serial)
|
||||
}
|
||||
database.DB.Unscoped().Where("company_name = ?", "TestCompany2").Delete(&models.Company{})
|
||||
database.DB.Unscoped().Delete(&user)
|
||||
}
|
||||
|
||||
func TestSerialsService_Query_QuerySuccess(t *testing.T) {
|
||||
var user models.User
|
||||
password, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user = models.User{
|
||||
Username: "adminuser3",
|
||||
Password: string(password),
|
||||
Name: "管理员3",
|
||||
Email: "admin3@example.com",
|
||||
Role: "admin",
|
||||
}
|
||||
database.DB.Create(&user)
|
||||
|
||||
serialService := SerialsService{}
|
||||
serials, _ := serialService.Generate("TestCompany3", 1, 30, user.ID, "QR")
|
||||
|
||||
serialNumber := strings.ToUpper(serials[0].SerialNumber)
|
||||
result, err := serialService.Query(serialNumber)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, serialNumber, strings.ToUpper(result.SerialNumber))
|
||||
assert.True(t, result.IsActive)
|
||||
|
||||
for _, serial := range serials {
|
||||
database.DB.Unscoped().Delete(&serial)
|
||||
}
|
||||
database.DB.Unscoped().Where("company_name = ?", "TestCompany3").Delete(&models.Company{})
|
||||
database.DB.Unscoped().Delete(&user)
|
||||
}
|
||||
|
||||
func TestSerialsService_Query_SerialNotFound(t *testing.T) {
|
||||
serialService := SerialsService{}
|
||||
_, err := serialService.Query("NONEXISTENT")
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSerialsService_FindAll_Success(t *testing.T) {
|
||||
var user models.User
|
||||
password, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user = models.User{
|
||||
Username: "adminuser4",
|
||||
Password: string(password),
|
||||
Name: "管理员4",
|
||||
Email: "admin4@example.com",
|
||||
Role: "admin",
|
||||
}
|
||||
database.DB.Create(&user)
|
||||
|
||||
serialService := SerialsService{}
|
||||
serials, _ := serialService.Generate("TestCompany4", 10, 30, user.ID, "LIST")
|
||||
|
||||
result, total, totalPages, err := serialService.FindAll(1, 5, "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 5)
|
||||
assert.GreaterOrEqual(t, total, 10)
|
||||
assert.Greater(t, totalPages, 0)
|
||||
|
||||
for _, serial := range serials {
|
||||
database.DB.Unscoped().Delete(&serial)
|
||||
}
|
||||
database.DB.Unscoped().Where("company_name = ?", "TestCompany4").Delete(&models.Company{})
|
||||
database.DB.Unscoped().Delete(&user)
|
||||
}
|
||||
|
||||
func TestSerialsService_FindAll_WithSearch(t *testing.T) {
|
||||
var user models.User
|
||||
password, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user = models.User{
|
||||
Username: "adminuser5",
|
||||
Password: string(password),
|
||||
Name: "管理员5",
|
||||
Email: "admin5@example.com",
|
||||
Role: "admin",
|
||||
}
|
||||
database.DB.Create(&user)
|
||||
|
||||
serialService := SerialsService{}
|
||||
serials, _ := serialService.Generate("SearchCompany", 5, 30, user.ID, "SEARCH")
|
||||
|
||||
result, _, _, err := serialService.FindAll(1, 10, "SearchCompany")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, len(result), 0)
|
||||
assert.Equal(t, "SearchCompany", result[0].CompanyName)
|
||||
|
||||
for _, serial := range serials {
|
||||
database.DB.Unscoped().Delete(&serial)
|
||||
}
|
||||
database.DB.Unscoped().Where("company_name = ?", "SearchCompany").Delete(&models.Company{})
|
||||
database.DB.Unscoped().Delete(&user)
|
||||
}
|
||||
|
||||
func TestSerialsService_Revoke_Success(t *testing.T) {
|
||||
var user models.User
|
||||
password, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user = models.User{
|
||||
Username: "adminuser6",
|
||||
Password: string(password),
|
||||
Name: "管理员6",
|
||||
Email: "admin6@example.com",
|
||||
Role: "admin",
|
||||
}
|
||||
database.DB.Create(&user)
|
||||
|
||||
serialService := SerialsService{}
|
||||
serials, _ := serialService.Generate("RevokeCompany", 1, 30, user.ID, "REVOKE")
|
||||
|
||||
serialNumber := serials[0].SerialNumber
|
||||
err := serialService.Revoke(serialNumber)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
var revokedSerial models.Serial
|
||||
database.DB.Where("serial_number = ?", serialNumber).First(&revokedSerial)
|
||||
assert.False(t, revokedSerial.IsActive)
|
||||
|
||||
for _, serial := range serials {
|
||||
database.DB.Unscoped().Delete(&serial)
|
||||
}
|
||||
database.DB.Unscoped().Where("company_name = ?", "RevokeCompany").Delete(&models.Company{})
|
||||
database.DB.Unscoped().Delete(&user)
|
||||
}
|
||||
|
||||
func TestSerialsService_Update_Success(t *testing.T) {
|
||||
var user models.User
|
||||
password, _ := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
|
||||
user = models.User{
|
||||
Username: "adminuser7",
|
||||
Password: string(password),
|
||||
Name: "管理员7",
|
||||
Email: "admin7@example.com",
|
||||
Role: "admin",
|
||||
}
|
||||
database.DB.Create(&user)
|
||||
|
||||
serialService := SerialsService{}
|
||||
serials, _ := serialService.Generate("UpdateCompany", 1, 30, user.ID, "UPDATE")
|
||||
|
||||
serialNumber := serials[0].SerialNumber
|
||||
newValidUntil := time.Now().AddDate(0, 0, 60)
|
||||
isActive := false
|
||||
|
||||
updateData := models.UpdateSerialDTO{
|
||||
ValidUntil: &newValidUntil,
|
||||
IsActive: &isActive,
|
||||
}
|
||||
|
||||
result, err := serialService.Update(serialNumber, updateData)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.False(t, result.IsActive)
|
||||
|
||||
for _, serial := range serials {
|
||||
database.DB.Unscoped().Delete(&serial)
|
||||
}
|
||||
database.DB.Unscoped().Where("company_name = ?", "UpdateCompany").Delete(&models.Company{})
|
||||
database.DB.Unscoped().Delete(&user)
|
||||
}
|
||||
Reference in New Issue
Block a user