package dao

import (
	"context"
	"errors"
	"fmt"
	"time"

	"community/internal/cache"
	"community/internal/model"

	cacheBase "github.com/zhufuyi/sponge/pkg/cache"
	"github.com/zhufuyi/sponge/pkg/mysql/query"
	"github.com/zhufuyi/sponge/pkg/utils"
	"golang.org/x/sync/singleflight"
	"gorm.io/gorm"
)

var _ PostDao = (*postDao)(nil)

// PostDao defines the data-access operations for model.Post.
//
// Methods without the ByTx suffix run on the dao's own *gorm.DB; methods with
// the ByTx suffix run on the *gorm.DB handed in by the caller.
type PostDao interface {
	Create(ctx context.Context, table *model.Post) error
	DeleteByID(ctx context.Context, id uint64) error
	DeleteByIDs(ctx context.Context, ids []uint64) error
	UpdateByID(ctx context.Context, table *model.Post) error
	GetByID(ctx context.Context, id uint64) (*model.Post, error)
	GetByCondition(ctx context.Context, condition *query.Conditions) (*model.Post, error)
	GetByIDs(ctx context.Context, ids []uint64) (map[uint64]*model.Post, error)
	GetByColumns(ctx context.Context, params *query.Params) ([]*model.Post, int64, error)
	IncrViewCount(ctx context.Context, id uint64) error
	IncrShareCount(ctx context.Context, id uint64) error

	CreateByTx(ctx context.Context, tx *gorm.DB, post *model.Post) (uint64, error)
	DeleteByTx(ctx context.Context, tx *gorm.DB, id uint64, delFlag int) error
	UpdateByTx(ctx context.Context, tx *gorm.DB, table *model.Post) error
	IncrCommentCountByTx(ctx context.Context, tx *gorm.DB, id uint64) error
	DecrCommentCountByTx(ctx context.Context, tx *gorm.DB, id uint64) error
	IncrLikeCountByTx(ctx context.Context, tx *gorm.DB, id uint64) error
	DecrLikeCountByTx(ctx context.Context, tx *gorm.DB, id uint64) error
	IncrCollectCountByTx(ctx context.Context, tx *gorm.DB, id uint64) error
	DecrCollectCountByTx(ctx context.Context, tx *gorm.DB, id uint64) error
}

// postDao implements PostDao on top of a gorm database handle and a post cache.
type postDao struct {
	db    *gorm.DB
	cache cache.PostCache
	sfg   *singleflight.Group // collapses concurrent database loads of the same id
}

// NewPostDao creates a PostDao backed by db and xCache.
func NewPostDao(db *gorm.DB, xCache cache.PostCache) PostDao {
	return &postDao{
		db:    db,
		cache: xCache,
		sfg:   new(singleflight.Group),
	}
}

// Create inserts a record; the generated id is written back to table.
func (d *postDao) Create(ctx context.Context, table *model.Post) error {
	err := d.db.WithContext(ctx).Create(table).Error
	// Best effort: drop whatever the cache holds for this id. A cache failure
	// must not turn a database result into an error.
	_ = d.cache.Del(ctx, table.ID)
	return err
}

// DeleteByID deletes a record by id and then drops its cache entry.
func (d *postDao) DeleteByID(ctx context.Context, id uint64) error {
	err := d.db.WithContext(ctx).Where("id = ?", id).Delete(&model.Post{}).Error
	if err != nil {
		return err
	}

	// Best effort: the row is already deleted, so a cache error is ignored.
	_ = d.cache.Del(ctx, id)
	return nil
}

// DeleteByIDs deletes records by a batch of ids and then drops their cache
// entries. An empty ids slice is a no-op.
func (d *postDao) DeleteByIDs(ctx context.Context, ids []uint64) error {
	if len(ids) == 0 {
		return nil
	}

	err := d.db.WithContext(ctx).Where("id IN (?)", ids).Delete(&model.Post{}).Error
	if err != nil {
		return err
	}

	// Best effort: the rows are already deleted, so cache errors are ignored.
	for _, id := range ids {
		_ = d.cache.Del(ctx, id)
	}
	return nil
}

// UpdateByID updates the non-zero fields of table on the row whose id is
// table.ID, then drops the cache entry for that id.
func (d *postDao) UpdateByID(ctx context.Context, table *model.Post) error {
	err := d.updateDataByID(ctx, d.db, table)

	// Best effort, and done even when the update failed.
	_ = d.cache.Del(ctx, table.ID)
	return err
}

// updateDataByID writes every non-zero field of table to the row identified by
// table.ID using db. Fields holding their zero value are skipped, so this
// method cannot be used to reset a column to 0 or "".
func (d *postDao) updateDataByID(ctx context.Context, db *gorm.DB, table *model.Post) error {
	if table.ID == 0 {
		return errors.New("id cannot be 0")
	}

	update := map[string]interface{}{}

	if table.PostType != 0 {
		update["post_type"] = table.PostType
	}
	if table.UserID != 0 {
		update["user_id"] = table.UserID
	}
	if table.Title != "" {
		update["title"] = table.Title
	}
	if table.Content != "" {
		update["content"] = table.Content
	}
	if table.ViewCount != 0 {
		update["view_count"] = table.ViewCount
	}
	if table.LikeCount != 0 {
		update["like_count"] = table.LikeCount
	}
	if table.CommentCount != 0 {
		update["comment_count"] = table.CommentCount
	}
	if table.CollectCount != 0 {
		update["collect_count"] = table.CollectCount
	}
	if table.ShareCount != 0 {
		update["share_count"] = table.ShareCount
	}
	if table.Longitude != 0 {
		update["longitude"] = table.Longitude
	}
	if table.Latitude != 0 {
		update["latitude"] = table.Latitude
	}
	if table.Position != "" {
		update["position"] = table.Position
	}
	if table.Visible != 0 {
		update["visible"] = table.Visible
	}
	if table.DelFlag != 0 {
		update["del_flag"] = table.DelFlag
	}

	return db.WithContext(ctx).Model(table).Updates(update).Error
}

// GetByID returns the record with the given id, reading the cache first.
//
// Cache outcomes:
//   - hit: the cached record is returned;
//   - model.ErrCacheNotFound: the record is loaded from the database and cached;
//   - cacheBase.ErrPlaceholder: model.ErrRecordNotFound is returned without
//     touching the database;
//   - any other error: returned as is (fail fast, the database is not queried).
func (d *postDao) GetByID(ctx context.Context, id uint64) (*model.Post, error) {
	record, err := d.cache.Get(ctx, id)
	switch {
	case err == nil:
		return record, nil
	case errors.Is(err, model.ErrCacheNotFound):
		return d.loadByID(ctx, id)
	case errors.Is(err, cacheBase.ErrPlaceholder):
		return nil, model.ErrRecordNotFound
	default:
		return nil, err
	}
}

// loadByID reads the record from the database and stores it in the cache.
// Concurrent calls for the same id share one database query via singleflight.
// When the row does not exist a not-found placeholder is cached and
// model.ErrRecordNotFound is returned.
func (d *postDao) loadByID(ctx context.Context, id uint64) (*model.Post, error) {
	val, err, _ := d.sfg.Do(utils.Uint64ToStr(id), func() (interface{}, error) {
		table := &model.Post{}
		dbErr := d.db.WithContext(ctx).Where("id = ?", id).First(table).Error
		if dbErr != nil {
			if !errors.Is(dbErr, model.ErrRecordNotFound) {
				return nil, dbErr
			}
			// Cache a placeholder so repeated lookups of a missing id do not
			// keep reaching the database (cache penetration).
			if cacheErr := d.cache.SetCacheWithNotFound(ctx, id); cacheErr != nil {
				return nil, cacheErr
			}
			return nil, model.ErrRecordNotFound
		}

		if cacheErr := d.cache.Set(ctx, id, table, cache.PostExpireTime); cacheErr != nil {
			return nil, fmt.Errorf("cache.Set error: %w, id=%d", cacheErr, id)
		}
		return table, nil
	})
	if err != nil {
		return nil, err
	}

	table, ok := val.(*model.Post)
	if !ok {
		return nil, model.ErrRecordNotFound
	}
	return table, nil
}

// GetByCondition returns the first record matching the given conditions.
// It queries the database directly; the cache is not used.
//
// query conditions:
//
//	name: column name
//	exp: expressions, which default is "=", support =, !=, >, >=, <, <=, like, in
//	value: column value, if exp=in, multiple values are separated by commas
//	logic: logical type, defaults to and when value is null, only &(and), ||(or)
//
// example: find a male aged 20
//
//	condition = &query.Conditions{
//	    Columns: []query.Column{
//	        {
//	            Name:  "age",
//	            Value: 20,
//	        },
//	        {
//	            Name:  "gender",
//	            Value: "male",
//	        },
//	    },
//	}
func (d *postDao) GetByCondition(ctx context.Context, c *query.Conditions) (*model.Post, error) {
	queryStr, args, err := c.ConvertToGorm()
	if err != nil {
		return nil, err
	}

	table := &model.Post{}
	err = d.db.WithContext(ctx).Where(queryStr, args...).First(table).Error
	if err != nil {
		return nil, err
	}
	return table, nil
}

// GetByIDs returns the records for a batch of ids, keyed by id. Ids that do
// not exist are simply absent from the returned map.
//
// Records are read from the cache first. Ids missing from the cache are looked
// up in the database, except those that already carry a not-found placeholder.
// Rows found in the database are cached; ids the database did not return get a
// not-found placeholder so they are not queried again on the next call.
func (d *postDao) GetByIDs(ctx context.Context, ids []uint64) (map[uint64]*model.Post, error) {
	itemMap, err := d.cache.MultiGet(ctx, ids)
	if err != nil {
		return nil, err
	}
	if itemMap == nil {
		// Guard against writing into a nil map below.
		itemMap = make(map[uint64]*model.Post)
	}

	var missedIDs []uint64
	for _, id := range ids {
		if _, ok := itemMap[id]; !ok {
			missedIDs = append(missedIDs, id)
		}
	}
	if len(missedIDs) == 0 {
		return itemMap, nil
	}

	// Drop ids with an active placeholder, i.e. ids known not to exist in mysql.
	// Any other outcome of the cache read (including errors) falls through to
	// the database lookup.
	var realMissedIDs []uint64
	for _, id := range missedIDs {
		if _, getErr := d.cache.Get(ctx, id); errors.Is(getErr, cacheBase.ErrPlaceholder) {
			continue
		}
		realMissedIDs = append(realMissedIDs, id)
	}
	if len(realMissedIDs) == 0 {
		return itemMap, nil
	}

	var missedData []*model.Post
	err = d.db.WithContext(ctx).Where("id IN (?)", realMissedIDs).Find(&missedData).Error
	if err != nil {
		return nil, err
	}

	if len(missedData) > 0 {
		for _, data := range missedData {
			itemMap[data.ID] = data
		}
		err = d.cache.MultiSet(ctx, missedData, cache.PostExpireTime)
		if err != nil {
			return nil, err
		}
	}

	// Best effort: mark every id the database did not return as not found,
	// also when only part of the batch was missing.
	for _, id := range realMissedIDs {
		if _, ok := itemMap[id]; !ok {
			_ = d.cache.SetCacheWithNotFound(ctx, id)
		}
	}

	return itemMap, nil
}

// GetByColumns returns one page of records matching params together with the
// total number of matching rows.
// Note: query performance degrades when table rows are very large because of the use of offset.
//
// When params.Sort is exactly "ignore count" the count query is skipped and
// the returned total is 0. When the count is 0 the page query is skipped and a
// nil slice is returned.
//
// params includes paging parameters and query parameters
// paging parameters (required):
//
//	page: page number, starting from 0
//	size: lines per page
//	sort: sort fields, default is id backwards, you can add - sign before the field to indicate reverse order, no - sign to indicate ascending order, multiple fields separated by comma
//
// query parameters (not required):
//
//	name: column name
//	exp: expressions, which default is "=", support =, !=, >, >=, <, <=, like, in
//	value: column value, if exp=in, multiple values are separated by commas
//	logic: logical type, defaults to and when value is null, only &(and), ||(or)
//
// example: search for a male over 20 years of age
//
//	params = &query.Params{
//	    Page: 0,
//	    Size: 20,
//	    Columns: []query.Column{
//	        {
//	            Name:  "age",
//	            Exp:   ">",
//	            Value: 20,
//	        },
//	        {
//	            Name:  "gender",
//	            Value: "male",
//	        },
//	    },
//	}
func (d *postDao) GetByColumns(ctx context.Context, params *query.Params) ([]*model.Post, int64, error) {
	queryStr, args, err := params.ConvertToGormConditions()
	if err != nil {
		return nil, 0, fmt.Errorf("query params error: %w", err)
	}

	var total int64
	if params.Sort != "ignore count" { // determine if count is required
		err = d.db.WithContext(ctx).Model(&model.Post{}).Select([]string{"id"}).Where(queryStr, args...).Count(&total).Error
		if err != nil {
			return nil, 0, err
		}
		if total == 0 {
			return nil, total, nil
		}
	}

	records := []*model.Post{}
	order, limit, offset := params.ConvertToPage()
	err = d.db.WithContext(ctx).Order(order).Limit(limit).Offset(offset).Where(queryStr, args...).Find(&records).Error
	if err != nil {
		return nil, 0, err
	}

	return records, total, nil
}

// incrCounter adds 1 to the given counter column of the post with the given id
// using db, then drops the cache entry for that id. A cache deletion error is
// returned to the caller. column must be a trusted, hard-coded column name
// because it is spliced into the SQL expression.
func (d *postDao) incrCounter(ctx context.Context, db *gorm.DB, id uint64, column string) error {
	err := db.WithContext(ctx).Model(&model.Post{}).Where("id = ?", id).
		Update(column, gorm.Expr(column+" + ?", 1)).Error
	if err != nil {
		return err
	}
	return d.cache.Del(ctx, id)
}

// decrCounter subtracts 1 from the given counter column of the post with the
// given id using db, but only while the column is greater than 0, then drops
// the cache entry for that id. A cache deletion error is returned to the
// caller. column must be a trusted, hard-coded column name because it is
// spliced into the SQL expression.
func (d *postDao) decrCounter(ctx context.Context, db *gorm.DB, id uint64, column string) error {
	err := db.WithContext(ctx).Model(&model.Post{}).Where("id = ? AND "+column+" > 0", id).
		Update(column, gorm.Expr(column+" - ?", 1)).Error
	if err != nil {
		return err
	}
	return d.cache.Del(ctx, id)
}

// IncrViewCount increments view_count by 1.
func (d *postDao) IncrViewCount(ctx context.Context, id uint64) error {
	return d.incrCounter(ctx, d.db, id, "view_count")
}

// IncrShareCount increments share_count by 1.
func (d *postDao) IncrShareCount(ctx context.Context, id uint64) error {
	return d.incrCounter(ctx, d.db, id, "share_count")
}

// CreateByTx creates a record using the provided transaction and returns the
// id written back to table. The returned id is only meaningful when the error
// is nil.
func (d *postDao) CreateByTx(ctx context.Context, tx *gorm.DB, table *model.Post) (uint64, error) {
	err := tx.WithContext(ctx).Create(table).Error
	return table.ID, err
}

// DeleteByTx soft-deletes a record using the provided transaction: it sets
// del_flag to delFlag and deleted_at to the current time, then drops the cache
// entry for id.
//
// NOTE(review): the cache entry is removed before the caller commits tx, so a
// concurrent read could presumably re-cache the old row — confirm callers
// tolerate this.
func (d *postDao) DeleteByTx(ctx context.Context, tx *gorm.DB, id uint64, delFlag int) error {
	update := map[string]interface{}{
		"del_flag":   delFlag,
		"deleted_at": time.Now(),
	}
	err := tx.WithContext(ctx).Model(&model.Post{}).Where("id = ?", id).Updates(update).Error
	if err != nil {
		return err
	}

	// Best effort: a cache error does not fail the transaction step.
	_ = d.cache.Del(ctx, id)
	return nil
}

// UpdateByTx updates the non-zero fields of table on the row whose id is
// table.ID using the provided transaction, then drops the cache entry.
func (d *postDao) UpdateByTx(ctx context.Context, tx *gorm.DB, table *model.Post) error {
	err := d.updateDataByID(ctx, tx, table)

	// Best effort, and done even when the update failed.
	_ = d.cache.Del(ctx, table.ID)
	return err
}

// IncrCommentCountByTx increments comment_count by 1 using the provided transaction.
func (d *postDao) IncrCommentCountByTx(ctx context.Context, tx *gorm.DB, id uint64) error {
	return d.incrCounter(ctx, tx, id, "comment_count")
}

// DecrCommentCountByTx decrements comment_count by 1 using the provided
// transaction; the count never goes below 0.
func (d *postDao) DecrCommentCountByTx(ctx context.Context, tx *gorm.DB, id uint64) error {
	return d.decrCounter(ctx, tx, id, "comment_count")
}

// IncrLikeCountByTx increments like_count by 1 using the provided transaction.
func (d *postDao) IncrLikeCountByTx(ctx context.Context, tx *gorm.DB, id uint64) error {
	return d.incrCounter(ctx, tx, id, "like_count")
}

// DecrLikeCountByTx decrements like_count by 1 using the provided transaction;
// the count never goes below 0.
func (d *postDao) DecrLikeCountByTx(ctx context.Context, tx *gorm.DB, id uint64) error {
	return d.decrCounter(ctx, tx, id, "like_count")
}

// IncrCollectCountByTx increments collect_count by 1 using the provided transaction.
func (d *postDao) IncrCollectCountByTx(ctx context.Context, tx *gorm.DB, id uint64) error {
	return d.incrCounter(ctx, tx, id, "collect_count")
}

// DecrCollectCountByTx decrements collect_count by 1 using the provided
// transaction; the count never goes below 0.
func (d *postDao) DecrCollectCountByTx(ctx context.Context, tx *gorm.DB, id uint64) error {
	return d.decrCounter(ctx, tx, id, "collect_count")
}