diff --git a/pkg/cache/store/mongodb.go b/pkg/cache/store/mongodb.go index 20988df..1d05601 100644 --- a/pkg/cache/store/mongodb.go +++ b/pkg/cache/store/mongodb.go @@ -3,6 +3,7 @@ package store import ( "context" + "fmt" "sort" "time" @@ -50,7 +51,9 @@ func (db *mongodb) Init(source string) (Store, error) { } // LoadInsertAccount 读取或创建账户 -func (db *mongodb) LoadInsertAccount(ctx context.Context, acct *model.Account) (*model.Account, error) { +func (db *mongodb) LoadInsertAccount(ctx context.Context, + acct *model.Account) (*model.Account, error) { + collection := db.Database(mongoDBName).Collection(collectionAccount) filter := bson.M{"username": config.Conf.BlogApp.Account.Username} @@ -67,8 +70,26 @@ func (db *mongodb) LoadInsertAccount(ctx context.Context, acct *model.Account) ( return acct, err } +// UpdateAccount 更新账户 +func (db *mongodb) UpdateAccount(ctx context.Context, name string, + fields map[string]interface{}) error { + + collection := db.Database(mongoDBName).Collection(collectionAccount) + + filter := bson.M{"username": name} + params := bson.M{} + for k, v := range fields { + params[k] = v + } + update := bson.M{"$set": params} + _, err := collection.UpdateOne(ctx, filter, update) + return err +} + // LoadInsertBlogger 读取或创建博客 -func (db *mongodb) LoadInsertBlogger(ctx context.Context, blogger *model.Blogger) (*model.Blogger, error) { +func (db *mongodb) LoadInsertBlogger(ctx context.Context, + blogger *model.Blogger) (*model.Blogger, error) { + collection := db.Database(mongoDBName).Collection(collectionBlogger) filter := bson.M{} @@ -85,76 +106,20 @@ func (db *mongodb) LoadInsertBlogger(ctx context.Context, blogger *model.Blogger return blogger, err } -// LoadAllArticle 读取所有文章 -func (db *mongodb) LoadAllArticle(ctx context.Context) (model.SortedArticles, error) { - collection := db.Database(mongoDBName).Collection(collectionArticle) +// UpdateBlogger 更新博客 +func (db *mongodb) UpdateBlogger(ctx context.Context, + fields map[string]interface{}) error { - filter := bson.M{"isdraft": false, "deletetime": bson.M{"$eq": time.Time{}}} - cur, err := collection.Find(ctx, filter) - if err != nil { - return nil, err + collection := db.Database(mongoDBName).Collection(collectionBlogger) + + filter := bson.M{} + params := bson.M{} + for k, v := range fields { + params[k] = v } - defer cur.Close(ctx) - - var articles model.SortedArticles - for cur.Next(ctx) { - article := model.Article{} - err = cur.Decode(&article) - if err != nil { - return nil, err - } - articles = append(articles, &article) - } - sort.Sort(articles) - return articles, nil -} - -// LoadTrashArticles 读取回收箱 -func (db *mongodb) LoadTrashArticles(ctx context.Context) (model.SortedArticles, error) { - collection := db.Database(mongoDBName).Collection(collectionArticle) - - filter := bson.M{"deletetime": bson.M{"$ne": time.Time{}}} - cur, err := collection.Find(ctx, filter) - if err != nil { - return nil, err - } - defer cur.Close(ctx) - - var articles model.SortedArticles - for cur.Next(ctx) { - article := model.Article{} - err = cur.Decode(&article) - if err != nil { - return nil, err - } - articles = append(articles, &article) - } - sort.Sort(articles) - return articles, nil -} - -// LoadDraftArticles 读取草稿箱 -func (db *mongodb) LoadDraftArticles(ctx context.Context) (model.SortedArticles, error) { - collection := db.Database(mongoDBName).Collection(collectionArticle) - - filter := bson.M{"isdraft": true} - cur, err := collection.Find(ctx, filter) - if err != nil { - return nil, err - } - defer cur.Close(ctx) - - var articles model.SortedArticles - for cur.Next(ctx) { - article := model.Article{} - err = cur.Decode(&article) - if err != nil { - return nil, err - } - articles = append(articles, &article) - } - sort.Sort(articles) - return articles, nil + update := bson.M{"$set": params} + _, err := collection.UpdateOne(ctx, filter, update) + return err } // InsertSeries 创建专题 @@ -176,15 +141,45 @@ func (db *mongodb) RemoveSeries(ctx context.Context, id int) error { } // UpdateSeries 更新专题 -func (db *mongodb) UpdateSeries(ctx context.Context, series *model.Series) error { +func (db *mongodb) UpdateSeries(ctx context.Context, id int, + fields map[string]interface{}) error { + collection := db.Database(mongoDBName).Collection(collectionSeries) - filter := bson.M{"id": series.ID} - update := bson.M{"$set": bson.M{"desc": series.Desc}} + filter := bson.M{"id": id} + params := bson.M{} + for k, v := range fields { + params[k] = v + } + update := bson.M{"$set": params} _, err := collection.UpdateOne(ctx, filter, update) return err } +// LoadAllSeries 查询所有专题 +func (db *mongodb) LoadAllSeries(ctx context.Context) (model.SortedSeries, error) { + collection := db.Database(mongoDBName).Collection(collectionSeries) + + filter := bson.M{} + cur, err := collection.Find(ctx, filter) + if err != nil { + return nil, err + } + defer cur.Close(ctx) + + var series model.SortedSeries + for cur.Next(ctx) { + obj := model.Series{} + err = cur.Decode(&obj) + if err != nil { + return nil, err + } + series = append(series, &obj) + } + sort.Sort(series) + return series, nil +} + // InsertArticle 创建文章 func (db *mongodb) InsertArticle(ctx context.Context, article *model.Article) error { // 分配ID, 占位至起始id @@ -203,6 +198,15 @@ func (db *mongodb) InsertArticle(ctx context.Context, article *model.Article) er return err } +// RemoveArticle 硬删除文章 +func (db *mongodb) RemoveArticle(ctx context.Context, id int) error { + collection := db.Database(mongoDBName).Collection(collectionArticle) + + filter := bson.M{"id": id} + _, err := collection.DeleteOne(ctx, filter) + return err +} + // DeleteArticle 软删除文章,放入回收箱 func (db *mongodb) DeleteArticle(ctx context.Context, id int) error { collection := db.Database(mongoDBName).Collection(collectionArticle) @@ -213,25 +217,34 @@ func (db *mongodb) DeleteArticle(ctx context.Context, id int) error { return err } -// RemoveArticle 硬删除文章 -func (db *mongodb) RemoveArticle(ctx context.Context, id int) error { - collection := db.Database(mongoDBName).Collection(collectionArticle) - - filter := bson.M{"id": id} - _, err := collection.DeleteOne(ctx, filter) - return err -} - // CleanArticles 清理回收站文章 func (db *mongodb) CleanArticles(ctx context.Context) error { collection := db.Database(mongoDBName).Collection(collectionArticle) exp := time.Now().Add(time.Duration(config.Conf.BlogApp.General.Trash) * time.Hour) + fmt.Println(exp) filter := bson.M{"deletetime": bson.M{"$gt": time.Time{}, "$lt": exp}} _, err := collection.DeleteMany(ctx, filter) return err } +// UpdateArticle 更新文章 +func (db *mongodb) UpdateArticle(ctx context.Context, id int, + fields map[string]interface{}) error { + + collection := db.Database(mongoDBName).Collection(collectionArticle) + + filter := bson.M{"id": id} + params := bson.M{} + for k, v := range fields { + params[k] = v + } + update := bson.M{"$set": params} + fmt.Println(update) + _, err := collection.UpdateOne(ctx, filter, update) + return err +} + // RecoverArticle 恢复文章到草稿 func (db *mongodb) RecoverArticle(ctx context.Context, id int) error { collection := db.Database(mongoDBName).Collection(collectionArticle) @@ -242,45 +255,76 @@ func (db *mongodb) RecoverArticle(ctx context.Context, id int) error { return err } -// UpdateAccount 更新账户 -func (db *mongodb) UpdateAccount(ctx context.Context, name string, fields map[string]interface{}) error { - collection := db.Database(mongoDBName).Collection(collectionAccount) +// LoadAllArticle 读取所有文章 +func (db *mongodb) LoadAllArticle(ctx context.Context) (model.SortedArticles, error) { + collection := db.Database(mongoDBName).Collection(collectionArticle) - filter := bson.M{"username": name} - update := bson.M{} - for k, v := range fields { - update[k] = v + filter := bson.M{"isdraft": false, "deletetime": bson.M{"$eq": time.Time{}}} + cur, err := collection.Find(ctx, filter) + if err != nil { + return nil, err } - _, err := collection.UpdateOne(ctx, filter, update) - return err + defer cur.Close(ctx) + + var articles model.SortedArticles + for cur.Next(ctx) { + obj := model.Article{} + err = cur.Decode(&obj) + if err != nil { + return nil, err + } + articles = append(articles, &obj) + } + sort.Sort(articles) + return articles, nil } -// UpdateBlogger 更新博客 -func (db *mongodb) UpdateBlogger(ctx context.Context, fields map[string]interface{}) error { - collection := db.Database(mongoDBName).Collection(collectionBlogger) +// LoadTrashArticles 读取回收箱 +func (db *mongodb) LoadTrashArticles(ctx context.Context) (model.SortedArticles, error) { + collection := db.Database(mongoDBName).Collection(collectionArticle) - filter := bson.M{} - update := bson.M{} - for k, v := range fields { - update[k] = v + filter := bson.M{"deletetime": bson.M{"$ne": time.Time{}}} + cur, err := collection.Find(ctx, filter) + if err != nil { + return nil, err } - _, err := collection.UpdateOne(ctx, filter, update) - return err + defer cur.Close(ctx) + + var articles model.SortedArticles + for cur.Next(ctx) { + obj := model.Article{} + err = cur.Decode(&obj) + if err != nil { + return nil, err + } + articles = append(articles, &obj) + } + sort.Sort(articles) + return articles, nil } -// UpdateArticle 更新文章 -func (db *mongodb) UpdateArticle(ctx context.Context, article *model.Article) error { - collection := db.Database(mongoDBName).Collection(collectionBlogger) +// LoadDraftArticles 读取草稿箱 +func (db *mongodb) LoadDraftArticles(ctx context.Context) (model.SortedArticles, error) { + collection := db.Database(mongoDBName).Collection(collectionArticle) - filter := bson.M{"id": article.ID} - update := bson.M{"$set": bson.M{ - "title": article.Title, - "content": article.Content, - "updatetime": article.UpdateTime, - "createtime": article.CreateTime, - }} - _, err := collection.UpdateOne(ctx, filter, update) - return err + filter := bson.M{"isdraft": true} + cur, err := collection.Find(ctx, filter) + if err != nil { + return nil, err + } + defer cur.Close(ctx) + + var articles model.SortedArticles + for cur.Next(ctx) { + obj := model.Article{} + err = cur.Decode(&obj) + if err != nil { + return nil, err + } + articles = append(articles, &obj) + } + sort.Sort(articles) + return articles, nil } // counter counter diff --git a/pkg/cache/store/mongodb_test.go b/pkg/cache/store/mongodb_test.go index fe3c61f..33fc1ca 100644 --- a/pkg/cache/store/mongodb_test.go +++ b/pkg/cache/store/mongodb_test.go @@ -1,7 +1,21 @@ // Package store provides ... package store -var store Store +import ( + "context" + "testing" + "time" + + "github.com/eiblog/eiblog/pkg/model" +) + +var ( + store Store + acct *model.Account + blogger *model.Blogger + series *model.Series + article *model.Article +) func init() { var err error @@ -9,4 +23,183 @@ func init() { if err != nil { panic(err) } + // account + acct = &model.Account{ + Username: "deepzz", + Password: "deepzz", + Email: "deepzz@example.com", + PhoneN: "12345678900", + Address: "address", + CreateTime: time.Now(), + } + // blogger + blogger = &model.Blogger{ + BlogName: "Deepzz", + SubTitle: "不抛弃,不放弃", + BeiAn: "beian", + BTitle: "Deepzz's Blog", + Copyright: "Copyright", + } + // series + series = &model.Series{ + Slug: "slug", + Name: "series name", + Desc: "series desc", + CreateTime: time.Now(), + } + // article + article = &model.Article{ + Author: "deepzz", + Slug: "slug", + Title: "title", + Count: 0, + Content: "### count", + SerieID: 0, + Tags: "", + IsDraft: false, + + UpdateTime: time.Now(), + CreateTime: time.Now(), + } +} + +func TestLoadInsertAccount(t *testing.T) { + acct2, err := store.LoadInsertAccount(context.Background(), acct) + if err != nil { + t.Fatal(err) + } + t.Log(acct2) + t.Log(acct == acct2) +} + +func TestUpdateAccount(t *testing.T) { + err := store.UpdateAccount(context.Background(), "deepzz", map[string]interface{}{ + "phonn": "09876543211", + "loginua": "chrome", + "password": "123456", + "logintime": time.Now(), + "logouttime": time.Now(), + }) + if err != nil { + t.Fatal(err) + } +} + +func TestLoadInsertBlogger(t *testing.T) { + blogger2, err := store.LoadInsertBlogger(context.Background(), blogger) + if err != nil { + t.Fatal(err) + } + t.Log(blogger2) + t.Log(blogger == blogger2) +} + +func TestUpdateBlogger(t *testing.T) { + err := store.UpdateBlogger(context.Background(), map[string]interface{}{ + "blogname": "blogname", + }) + if err != nil { + t.Fatal(err) + } +} + +func TestInsertSeries(t *testing.T) { + err := store.InsertSeries(context.Background(), series) + if err != nil { + t.Fatal(err) + } +} + +func TestRemoveSeries(t *testing.T) { + err := store.RemoveSeries(context.Background(), 1) + if err != nil { + t.Fatal(err) + } +} + +func TestUpdateSeries(t *testing.T) { + err := store.UpdateSeries(context.Background(), 2, map[string]interface{}{ + "desc": "update desc", + }) + if err != nil { + t.Fatal(err) + } +} + +func TestLoadAllSeries(t *testing.T) { + series, err := store.LoadAllSeries(context.Background()) + if err != nil { + t.Fatal(err) + } + t.Logf("load all series: %d", len(series)) +} + +func TestInsertArticle(t *testing.T) { + err := store.InsertArticle(context.Background(), article) + if err != nil { + t.Fatal(err) + } +} + +func TestRemoveArticle(t *testing.T) { + err := store.RemoveArticle(context.Background(), 11) + if err != nil { + t.Fatal(err) + } +} + +func TestDeleteArticle(t *testing.T) { + err := store.DeleteArticle(context.Background(), 12) + if err != nil { + t.Fatal(err) + } +} + +// TODO +func TestCleanArticles(t *testing.T) { + err := store.CleanArticles(context.Background()) + if err != nil { + t.Fatal(err) + } +} + +func TestUpdateArticle(t *testing.T) { + err := store.UpdateArticle(context.Background(), 13, map[string]interface{}{ + "title": "new title", + "updatetime": time.Now(), + }) + if err != nil { + t.Fatal(err) + } +} + +func TestRecoverArticle(t *testing.T) { + err := store.RecoverArticle(context.Background(), 12) + if err != nil { + t.Fatal(err) + } +} + +func TestLoadAllArticle(t *testing.T) { + articles, err := store.LoadAllArticle(context.Background()) + if err != nil { + t.Fatal(err) + } + t.Logf("load all articles: %d", len(articles)) +} + +func TestLoadTrashArticles(t *testing.T) { + articles, err := store.LoadTrashArticles(context.Background()) + if err != nil { + t.Fatal(err) + } + t.Logf("load trash articles: %d", len(articles)) +} + +func TestLoadDraftArticles(t *testing.T) { + articles, err := store.LoadDraftArticles(context.Background()) + if err != nil { + t.Fatal(err) + } + t.Logf("load draft articles: %d", len(articles)) } diff --git a/pkg/cache/store/store.go b/pkg/cache/store/store.go index 5815ae7..1c5a098 100644 --- a/pkg/cache/store/store.go +++ b/pkg/cache/store/store.go @@ -19,39 +19,41 @@ var ( type Store interface { // LoadInsertAccount 读取或创建账户 LoadInsertAccount(ctx context.Context, acct *model.Account) (*model.Account, error) + // UpdateAccount 更新账户 + UpdateAccount(ctx context.Context, name string, fields map[string]interface{}) error + // LoadInsertBlogger 读取或创建博客 LoadInsertBlogger(ctx context.Context, blogger *model.Blogger) (*model.Blogger, error) - // LoadAllArticle 读取所有文章 - LoadAllArticle(ctx context.Context) (model.SortedArticles, error) - // LoadTrashArticles 读取回收箱 - LoadTrashArticles(ctx context.Context) (model.SortedArticles, error) - // LoadDraftArticles 读取草稿箱 - LoadDraftArticles(ctx context.Context) (model.SortedArticles, error) + // UpdateBlogger 更新博客 + UpdateBlogger(ctx context.Context, fields map[string]interface{}) error // InsertSeries 创建专题 InsertSeries(ctx context.Context, series *model.Series) error // RemoveSeries 删除专题 RemoveSeries(ctx context.Context, id int) error // UpdateSeries 更新专题 - UpdateSeries(ctx context.Context, series *model.Series) error + UpdateSeries(ctx context.Context, id int, fields map[string]interface{}) error + // LoadAllSeries 读取所有专题 + LoadAllSeries(ctx context.Context) (model.SortedSeries, error) // InsertArticle 创建文章 InsertArticle(ctx context.Context, article *model.Article) error - // DeleteArticle 软删除文章,放入回收箱 - DeleteArticle(ctx context.Context, id int) error // RemoveArticle 硬删除文章 RemoveArticle(ctx context.Context, id int) error + // DeleteArticle 软删除文章,放入回收箱 + DeleteArticle(ctx context.Context, id int) error // CleanArticles 清理回收站文章 CleanArticles(ctx context.Context) error + // UpdateArticle 更新文章 + UpdateArticle(ctx context.Context, id int, fields map[string]interface{}) error // RecoverArticle 恢复文章到草稿 RecoverArticle(ctx context.Context, id int) error - - // UpdateAccount 更新账户 - UpdateAccount(ctx context.Context, name string, fields map[string]interface{}) error - // UpdateBlogger 更新博客 - UpdateBlogger(ctx context.Context, fields map[string]interface{}) error - // UpdateArticle 更新文章 - UpdateArticle(ctx context.Context, article *model.Article) error + // LoadAllArticle 读取所有文章 + LoadAllArticle(ctx context.Context) (model.SortedArticles, error) + // LoadTrashArticles 读取回收箱 + LoadTrashArticles(ctx context.Context) (model.SortedArticles, error) + // LoadDraftArticles 读取草稿箱 + LoadDraftArticles(ctx context.Context) (model.SortedArticles, error) } // Driver 存储驱动 diff --git a/pkg/config/config.go b/pkg/config/config.go index f1e6060..7337b14 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -136,9 +136,9 @@ type Config struct { func init() { // compatibility linux and windows var err error - WorkDir, err = os.Getwd() - if err != nil { - panic(err) + if gopath := os.Getenv("GOPATH"); gopath != "" { + WorkDir = filepath.Join(gopath, "src", "github.com", + "eiblog", "eiblog") } path := filepath.Join(WorkDir, "conf", "app.yml") diff --git a/pkg/model/article.go b/pkg/model/article.go index baf79ea..d08f708 100644 --- a/pkg/model/article.go +++ b/pkg/model/article.go @@ -11,7 +11,7 @@ type Article struct { Title string `gorm:"not null"` // 标题 Count int `gorm:"not null"` // 评论数量 Content string `gorm:"not null"` // markdown内容 - SerieID int32 `gorm:"not null"` // 专题ID + SerieID int `gorm:"not null"` // 专题ID Tags string `gorm:"not null"` // tag,以逗号隔开 IsDraft bool `gorm:"not null"` // 是否是草稿