Files
backend-go/services/companies_service.go

445 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package services
import (
"errors"
"fmt"
"sort"
"time"
"gorm.io/gorm"
"git.beifan.cn/trace-system/backend-go/database"
"git.beifan.cn/trace-system/backend-go/models"
)
// CompaniesService is the company management service: listing, detail
// lookup, create/update/delete, revocation and statistics for companies and
// their serials. The struct holds no state; every method works through the
// package-level database.DB handle.
type CompaniesService struct{}
// FindAll returns one page of companies, most recently updated first, with
// each company's Serials association preloaded.
//
// page is 1-based; values below 1 are treated as 1. limit values below 1
// fall back to 20. A non-empty search keeps only companies whose name
// contains the search text (SQL LIKE '%search%').
//
// It returns the companies on the requested page, the total number of
// matching companies, and the number of pages (0 when nothing matches).
func (s *CompaniesService) FindAll(page int, limit int, search string) ([]models.Company, int, int, error) {
	const defaultLimit = 20
	if page <= 0 {
		page = 1
	}
	if limit <= 0 {
		limit = defaultLimit
	}

	query := database.DB.Model(&models.Company{})
	if len(search) > 0 {
		query = query.Where("company_name LIKE ?", "%"+search+"%")
	}

	var matched int64
	if err := query.Count(&matched).Error; err != nil {
		return nil, 0, 0, errors.New("查询企业总数失败")
	}

	var companies []models.Company
	skip := (page - 1) * limit
	if err := query.Preload("Serials").Order("updated_at DESC").Offset(skip).Limit(limit).Find(&companies).Error; err != nil {
		return nil, 0, 0, errors.New("查询企业列表失败")
	}

	count := int(matched)
	pages := 0
	if count > 0 {
		pages = (count + limit - 1) / limit
	}
	return companies, count, pages, nil
}
// FindOne returns the detail view of a single company, looked up by exact
// name.
//
// page (1-based, default 1) and limit (default 20) paginate the company's
// serials, newest first. The status counters and the monthly histogram are
// computed over ALL of the company's serials, not just the requested page.
//
// The result map contains:
//   - companyName, status ("active" or "disabled")
//   - serialCount, activeCount, disabledCount, expiredCount
//   - firstCreated / lastCreated: the company's CreatedAt / UpdatedAt
//   - serials: the requested page; each item has serialNumber, validUntil,
//     isActive, createdAt and createdBy (creating user's name, "" if unknown)
//   - monthlyStats: {month "YYYY-MM", count} for each of the last 12 calendar
//     months (oldest first, local time) that has at least one serial
//   - pagination: page, limit, total, totalPages
//
// Errors: "企业不存在" when no company has that name; "查询企业失败" or
// "查询企业序列号失败" when a query fails for another reason.
func (s *CompaniesService) FindOne(companyName string, page int, limit int) (map[string]any, error) {
	if page < 1 {
		page = 1
	}
	if limit < 1 {
		limit = 20
	}

	var company models.Company
	if err := database.DB.Where("company_name = ?", companyName).First(&company).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("企业不存在")
		}
		return nil, errors.New("查询企业失败")
	}

	// Every serial is loaded (not just one page) because the counters and the
	// monthly histogram below need the full set.
	var allSerials []models.Serial
	if err := database.DB.Preload("User").Where("company_name = ?", companyName).Order("created_at DESC").Find(&allSerials).Error; err != nil {
		return nil, errors.New("查询企业序列号失败")
	}

	now := time.Now()
	serialCount := len(allSerials)
	activeCount, disabledCount, expiredCount := 0, 0, 0
	for _, serial := range allSerials {
		switch {
		case !serial.IsActive:
			// A disabled serial counts as disabled even if it is also expired.
			disabledCount++
		case serial.ValidUntil != nil && serial.ValidUntil.Before(now):
			expiredCount++
		default:
			activeCount++
		}
	}

	// Overflow-safe bounds: huge page/limit values from the caller must not
	// produce negative slice indices (which would panic).
	start, end := companyPageBounds(serialCount, page, limit)
	serialItems := make([]map[string]any, 0, end-start)
	for _, serial := range allSerials[start:end] {
		createdBy := ""
		if serial.User != nil {
			createdBy = serial.User.Name
		}
		serialItems = append(serialItems, map[string]any{
			"serialNumber": serial.SerialNumber,
			"validUntil":   serial.ValidUntil,
			"isActive":     serial.IsActive,
			"createdAt":    serial.CreatedAt,
			"createdBy":    createdBy,
		})
	}

	months := companyStatsMonths(now)
	monthCounts := make(map[string]int, len(months))
	for _, month := range months {
		monthCounts[month] = 0
	}
	for _, serial := range allSerials {
		// The month keys are built in time.Local. Convert CreatedAt to the same
		// location so serials near a month boundary land in the right bucket,
		// whatever location the database driver returned.
		key := serial.CreatedAt.In(time.Local).Format("2006-01")
		if _, ok := monthCounts[key]; ok {
			monthCounts[key]++
		}
	}
	monthlyStats := make([]map[string]any, 0)
	for _, month := range months {
		if count := monthCounts[month]; count > 0 {
			monthlyStats = append(monthlyStats, map[string]any{"month": month, "count": count})
		}
	}

	status := "disabled"
	if company.IsActive {
		status = "active"
	}

	return map[string]any{
		"companyName":   company.CompanyName,
		"serialCount":   serialCount,
		"activeCount":   activeCount,
		"disabledCount": disabledCount,
		"expiredCount":  expiredCount,
		"firstCreated":  company.CreatedAt,
		"lastCreated":   company.UpdatedAt,
		"status":        status,
		"serials":       serialItems,
		"monthlyStats":  monthlyStats,
		"pagination": map[string]any{
			"page":       page,
			"limit":      limit,
			"total":      serialCount,
			"totalPages": companyPageCount(serialCount, limit),
		},
	}, nil
}

// companyPageBounds returns the half-open [start, end) range of the 1-based
// page over total items, clamped to [0, total]. page and limit must be >= 1.
// It never multiplies past total, so it cannot overflow int.
func companyPageBounds(total, page, limit int) (int, int) {
	skip := page - 1
	if skip > total/limit {
		// The page starts past the last item.
		return total, total
	}
	start := skip * limit // <= total, so no overflow
	if limit > total-start {
		return start, total
	}
	return start, start + limit
}

// companyPageCount returns ceil(total/limit), or 0 when total <= 0. It is
// written as (total-1)/limit+1 so that a very large limit cannot overflow.
func companyPageCount(total, limit int) int {
	if total <= 0 {
		return 0
	}
	return (total-1)/limit + 1
}

// companyStatsMonths returns the "YYYY-MM" keys of the 12 calendar months
// ending with now's month, oldest first, computed in time.Local.
func companyStatsMonths(now time.Time) []string {
	months := make([]string, 0, 12)
	for i := 11; i >= 0; i-- {
		// time.Date normalises month values <= 0 into the previous year.
		first := time.Date(now.Year(), now.Month()-time.Month(i), 1, 0, 0, 0, 0, time.Local)
		months = append(months, first.Format("2006-01"))
	}
	return months
}
// Create inserts a new, active company with the given name and returns it.
//
// Errors:
//   - "企业名称已存在" when a company with that name already exists.
//   - "查询企业失败" when the existence check fails for any reason other than
//     "record not found". Such failures previously fell through to the insert.
//   - "创建企业失败" when the insert fails.
//
// NOTE(review): the existence check and the insert are separate statements,
// so two concurrent calls can both pass the check. Only a unique constraint
// on company_name (not visible here) would prevent a duplicate.
func (s *CompaniesService) Create(companyName string) (*models.Company, error) {
	var existing models.Company
	err := database.DB.Where("company_name = ?", companyName).First(&existing).Error
	switch {
	case err == nil:
		return nil, errors.New("企业名称已存在")
	case !errors.Is(err, gorm.ErrRecordNotFound):
		return nil, errors.New("查询企业失败")
	}

	company := models.Company{
		CompanyName: companyName,
		IsActive:    true,
	}
	if err := database.DB.Create(&company).Error; err != nil {
		return nil, errors.New("创建企业失败")
	}
	return &company, nil
}
// Update changes a company's name and/or active flag.
//
// An empty newCompanyName keeps the current name. When the name changes,
// the company_name column of the company's serials and employee serials is
// rewritten in the same transaction as the company row. A nil isActive
// leaves the active flag unchanged.
//
// Errors:
//   - "企业不存在" when companyName does not exist.
//   - "企业名称已存在" when the new name is already taken.
//   - "查询企业失败" when either lookup fails for another reason.
//   - A wrapped error (e.g. "更新企业信息失败: ...") when a write inside the
//     transaction fails.
func (s *CompaniesService) Update(companyName string, newCompanyName string, isActive *bool) (*models.Company, error) {
	var company models.Company
	if err := database.DB.Where("company_name = ?", companyName).First(&company).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("企业不存在")
		}
		return nil, errors.New("查询企业失败")
	}

	if newCompanyName == "" {
		newCompanyName = companyName
	}
	renaming := newCompanyName != companyName

	if renaming {
		// NOTE(review): this uniqueness check runs outside the transaction, so a
		// concurrent create/rename can still race it.
		var existing models.Company
		err := database.DB.Where("company_name = ?", newCompanyName).First(&existing).Error
		if err == nil {
			return nil, errors.New("企业名称已存在")
		}
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			// A failed lookup must not be mistaken for "name is free".
			return nil, errors.New("查询企业失败")
		}
	}

	err := database.DB.Transaction(func(tx *gorm.DB) error {
		if renaming {
			// Serials reference their company by name, so they follow the rename.
			if err := tx.Model(&models.Serial{}).Where("company_name = ?", companyName).Update("company_name", newCompanyName).Error; err != nil {
				return fmt.Errorf("更新企业赋码企业名称失败: %w", err)
			}
			if err := tx.Model(&models.EmployeeSerial{}).Where("company_name = ?", companyName).Update("company_name", newCompanyName).Error; err != nil {
				return fmt.Errorf("更新员工赋码企业名称失败: %w", err)
			}
			company.CompanyName = newCompanyName
		}
		if isActive != nil {
			company.IsActive = *isActive
		}
		if err := tx.Save(&company).Error; err != nil {
			return fmt.Errorf("更新企业信息失败: %w", err)
		}
		return nil
	})
	if err != nil {
		// Return the error itself, not errors.New(err.Error()). The message is
		// unchanged, but callers can now unwrap the underlying cause.
		return nil, err
	}
	return &company, nil
}
// Delete removes a company together with all of its company serials and
// employee serials, inside a single transaction.
//
// NOTE(review): whether these are soft or hard deletes depends on the model
// definitions, which are not visible here.
//
// Errors:
//   - "企业不存在" when no company has that name.
//   - "查询企业失败" when the lookup fails for another reason.
//   - "删除企业失败" when any delete in the transaction fails.
func (s *CompaniesService) Delete(companyName string) error {
	var company models.Company
	if err := database.DB.Where("company_name = ?", companyName).First(&company).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("企业不存在")
		}
		return errors.New("查询企业失败")
	}

	err := database.DB.Transaction(func(tx *gorm.DB) error {
		if err := tx.Where("company_name = ?", companyName).Delete(&models.Serial{}).Error; err != nil {
			return err
		}
		if err := tx.Where("company_name = ?", companyName).Delete(&models.EmployeeSerial{}).Error; err != nil {
			return err
		}
		return tx.Delete(&company).Error
	})
	if err != nil {
		return errors.New("删除企业失败")
	}
	return nil
}
// DeleteSerial deletes one company serial, identified by serial number, that
// belongs to the given company.
//
// Errors:
//   - "序列号不存在或不属于该企业" when no serial with that number exists for
//     the company.
//   - "查询序列号失败" when the lookup fails for another reason.
//   - "删除序列号失败" when the delete fails.
func (s *CompaniesService) DeleteSerial(companyName string, serialNumber string) error {
	var serial models.Serial
	if err := database.DB.Where("serial_number = ? AND company_name = ?", serialNumber, companyName).First(&serial).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("序列号不存在或不属于该企业")
		}
		return errors.New("查询序列号失败")
	}
	if err := database.DB.Delete(&serial).Error; err != nil {
		return errors.New("删除序列号失败")
	}
	return nil
}
// Revoke deactivates a company and every company serial and employee serial
// that belongs to it (sets is_active = false), inside a single transaction.
//
// Errors:
//   - "企业不存在" when no company has that name.
//   - "查询企业失败" when the lookup fails for another reason.
//   - "吊销企业失败" when any update in the transaction fails.
func (s *CompaniesService) Revoke(companyName string) error {
	var company models.Company
	if err := database.DB.Where("company_name = ?", companyName).First(&company).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("企业不存在")
		}
		return errors.New("查询企业失败")
	}

	err := database.DB.Transaction(func(tx *gorm.DB) error {
		if err := tx.Model(&models.Serial{}).Where("company_name = ?", companyName).Update("is_active", false).Error; err != nil {
			return err
		}
		if err := tx.Model(&models.EmployeeSerial{}).Where("company_name = ?", companyName).Update("is_active", false).Error; err != nil {
			return err
		}
		return tx.Model(&company).Update("is_active", false).Error
	})
	if err != nil {
		return errors.New("吊销企业失败")
	}
	return nil
}
// GetStats returns dashboard statistics over all companies and serials. The
// response shape is kept compatible with the former Node implementation.
//
// Result keys:
//   - overview:
//     - totalCompanies
//     - totalSerials (company serials)
//     - totalEmployeeSerials
//     - activeSerials: company serials that are active and whose ValidUntil
//       is unset or in the future
//     - inactiveSerials: totalSerials - activeSerials
//   - monthlyStats: one entry per month, among the last 12 calendar months
//     (oldest first, local time), that has at least one company serial.
//     Each entry has month ("YYYY-MM"), company_count (distinct companies
//     issued a serial that month) and serial_count.
//   - recentCompanies: up to 10 companies, most recently updated first.
//   - recentSerials: the 10 newest serials across company and employee
//     serials, newest first, each tagged with type "company" or "employee".
//
// NOTE(review): all companies, serials and employee serials are loaded into
// memory on every call. Consider SQL aggregates if these tables grow large.
func (s *CompaniesService) GetStats() (map[string]any, error) {
	const recentLimit = 10
	now := time.Now()

	var companies []models.Company
	if err := database.DB.Order("updated_at DESC").Find(&companies).Error; err != nil {
		return nil, errors.New("查询企业统计失败")
	}
	var serials []models.Serial
	if err := database.DB.Order("created_at DESC").Find(&serials).Error; err != nil {
		return nil, errors.New("查询序列号统计失败")
	}
	var employeeSerials []models.EmployeeSerial
	if err := database.DB.Order("created_at DESC").Find(&employeeSerials).Error; err != nil {
		return nil, errors.New("查询员工序列号统计失败")
	}

	serialCount := len(serials)
	activeCount := 0
	for _, serial := range serials {
		if serial.IsActive && (serial.ValidUntil == nil || serial.ValidUntil.After(now)) {
			activeCount++
		}
	}

	// Build the 12 month buckets (oldest first) once, then assign each
	// serial to its bucket in a single pass.
	type monthBucket struct {
		month     string
		serials   int
		companies map[string]struct{}
	}
	buckets := make([]monthBucket, 12)
	bucketByMonth := make(map[string]int, len(buckets))
	for i := range buckets {
		// time.Date normalises month values <= 0 into the previous year.
		first := time.Date(now.Year(), now.Month()-time.Month(len(buckets)-1-i), 1, 0, 0, 0, 0, time.Local)
		month := first.Format("2006-01")
		buckets[i] = monthBucket{month: month, companies: map[string]struct{}{}}
		bucketByMonth[month] = i
	}
	for _, serial := range serials {
		// The bucket keys are built in time.Local. Convert CreatedAt to the same
		// location so serials near a month boundary are counted in the right
		// month, whatever location the database driver returned.
		idx, ok := bucketByMonth[serial.CreatedAt.In(time.Local).Format("2006-01")]
		if !ok {
			continue
		}
		buckets[idx].serials++
		buckets[idx].companies[serial.CompanyName] = struct{}{}
	}
	monthlyItems := make([]map[string]any, 0)
	for _, b := range buckets {
		if b.serials == 0 {
			continue
		}
		monthlyItems = append(monthlyItems, map[string]any{
			"month":         b.month,
			"company_count": len(b.companies),
			"serial_count":  b.serials,
		})
	}

	recentCompanies := make([]map[string]any, 0)
	for i := 0; i < len(companies) && i < recentLimit; i++ {
		company := companies[i]
		status := "disabled"
		if company.IsActive {
			status = "active"
		}
		recentCompanies = append(recentCompanies, map[string]any{
			"companyName": company.CompanyName,
			"lastCreated": company.UpdatedAt,
			"status":      status,
		})
	}

	// Both serial lists come back ordered by created_at DESC. The overall
	// newest recentLimit entries are therefore among the first recentLimit of
	// each list, so there is no need to build an item for every serial.
	type recentEntry struct {
		createdAt time.Time
		item      map[string]any
	}
	entries := make([]recentEntry, 0, 2*recentLimit)
	// Company serials.
	for i := 0; i < len(serials) && i < recentLimit; i++ {
		serial := serials[i]
		entries = append(entries, recentEntry{
			createdAt: serial.CreatedAt,
			item: map[string]any{
				"serialNumber": serial.SerialNumber,
				"companyName":  serial.CompanyName,
				"isActive":     serial.IsActive,
				"createdAt":    serial.CreatedAt,
				"type":         "company",
			},
		})
	}
	// Employee serials.
	for i := 0; i < len(employeeSerials) && i < recentLimit; i++ {
		serial := employeeSerials[i]
		entries = append(entries, recentEntry{
			createdAt: serial.CreatedAt,
			item: map[string]any{
				"serialNumber": serial.SerialNumber,
				"companyName":  serial.CompanyName,
				"isActive":     serial.IsActive,
				"createdAt":    serial.CreatedAt,
				"type":         "employee",
				"position":     serial.Position,
				"employeeName": serial.EmployeeName,
			},
		})
	}
	// Newest first. The stable sort keeps company serials ahead of employee
	// serials on equal timestamps, making the output deterministic.
	sort.SliceStable(entries, func(i, j int) bool {
		return entries[i].createdAt.After(entries[j].createdAt)
	})
	if len(entries) > recentLimit {
		entries = entries[:recentLimit]
	}
	recentSerials := make([]map[string]any, 0, len(entries))
	for _, e := range entries {
		recentSerials = append(recentSerials, e.item)
	}

	return map[string]any{
		"overview": map[string]any{
			"totalCompanies":       len(companies),
			"totalSerials":         serialCount,
			"totalEmployeeSerials": len(employeeSerials),
			"activeSerials":        activeCount,
			"inactiveSerials":      serialCount - activeCount,
		},
		"monthlyStats":    monthlyItems,
		"recentCompanies": recentCompanies,
		"recentSerials":   recentSerials,
	}, nil
}
// GetStatsOverview returns aggregate counts of companies, company serials and
// employee serials. Each category is split into active rows (is_active =
// true) and the remainder: inactive companies and revoked serials. Unlike
// GetStats, serial expiry (ValidUntil) is not taken into account here.
//
// The six count queries run in a fixed order. The first failing query aborts
// the call with that query's error message.
func (s *CompaniesService) GetStatsOverview() (*models.CompanyStatsOverviewDTO, error) {
	stats := &models.CompanyStatsOverviewDTO{}

	queries := []struct {
		model      any
		activeOnly bool
		into       *int64
		failure    string
	}{
		{&models.Company{}, false, &stats.TotalCompanies, "统计企业总数失败"},
		{&models.Company{}, true, &stats.ActiveCompanies, "统计启用企业数量失败"},
		{&models.Serial{}, false, &stats.TotalSerials, "统计企业赋码总数失败"},
		{&models.Serial{}, true, &stats.ActiveSerials, "统计有效企业赋码数量失败"},
		{&models.EmployeeSerial{}, false, &stats.TotalEmployeeSerials, "统计员工赋码总数失败"},
		{&models.EmployeeSerial{}, true, &stats.ActiveEmployeeSerials, "统计有效员工赋码数量失败"},
	}
	for _, q := range queries {
		tx := database.DB.Model(q.model)
		if q.activeOnly {
			tx = tx.Where("is_active = ?", true)
		}
		if err := tx.Count(q.into).Error; err != nil {
			return nil, errors.New(q.failure)
		}
	}

	// Everything that is not active is reported as inactive/revoked.
	stats.InactiveCompanies = stats.TotalCompanies - stats.ActiveCompanies
	stats.RevokedSerials = stats.TotalSerials - stats.ActiveSerials
	stats.RevokedEmployeeSerials = stats.TotalEmployeeSerials - stats.ActiveEmployeeSerials
	return stats, nil
}