chore: mongodb unit test

This commit is contained in:
deepzz0
2021-04-27 10:21:27 +08:00
parent e2df642a46
commit bb40570053
5 changed files with 372 additions and 133 deletions

View File

@@ -3,6 +3,7 @@ package store
import (
"context"
"fmt"
"sort"
"time"
@@ -50,7 +51,9 @@ func (db *mongodb) Init(source string) (Store, error) {
}
// LoadInsertAccount 读取或创建账户
func (db *mongodb) LoadInsertAccount(ctx context.Context, acct *model.Account) (*model.Account, error) {
func (db *mongodb) LoadInsertAccount(ctx context.Context,
acct *model.Account) (*model.Account, error) {
collection := db.Database(mongoDBName).Collection(collectionAccount)
filter := bson.M{"username": config.Conf.BlogApp.Account.Username}
@@ -67,8 +70,26 @@ func (db *mongodb) LoadInsertAccount(ctx context.Context, acct *model.Account) (
return acct, err
}
// UpdateAccount 更新账户
func (db *mongodb) UpdateAccount(ctx context.Context, name string,
fields map[string]interface{}) error {
collection := db.Database(mongoDBName).Collection(collectionAccount)
filter := bson.M{"username": name}
params := bson.M{}
for k, v := range fields {
params[k] = v
}
update := bson.M{"$set": params}
_, err := collection.UpdateOne(ctx, filter, update)
return err
}
// LoadInsertBlogger 读取或创建博客
func (db *mongodb) LoadInsertBlogger(ctx context.Context, blogger *model.Blogger) (*model.Blogger, error) {
func (db *mongodb) LoadInsertBlogger(ctx context.Context,
blogger *model.Blogger) (*model.Blogger, error) {
collection := db.Database(mongoDBName).Collection(collectionBlogger)
filter := bson.M{}
@@ -85,76 +106,20 @@ func (db *mongodb) LoadInsertBlogger(ctx context.Context, blogger *model.Blogger
return blogger, err
}
// LoadAllArticle 读取所有文章
func (db *mongodb) LoadAllArticle(ctx context.Context) (model.SortedArticles, error) {
collection := db.Database(mongoDBName).Collection(collectionArticle)
// UpdateBlogger 更新博客
func (db *mongodb) UpdateBlogger(ctx context.Context,
fields map[string]interface{}) error {
filter := bson.M{"isdraft": false, "deletetime": bson.M{"$eq": time.Time{}}}
cur, err := collection.Find(ctx, filter)
if err != nil {
return nil, err
collection := db.Database(mongoDBName).Collection(collectionBlogger)
filter := bson.M{}
params := bson.M{}
for k, v := range fields {
params[k] = v
}
defer cur.Close(ctx)
var articles model.SortedArticles
for cur.Next(ctx) {
article := model.Article{}
err = cur.Decode(&article)
if err != nil {
return nil, err
}
articles = append(articles, &article)
}
sort.Sort(articles)
return articles, nil
}
// LoadTrashArticles 读取回收箱
func (db *mongodb) LoadTrashArticles(ctx context.Context) (model.SortedArticles, error) {
collection := db.Database(mongoDBName).Collection(collectionArticle)
filter := bson.M{"deletetime": bson.M{"$ne": time.Time{}}}
cur, err := collection.Find(ctx, filter)
if err != nil {
return nil, err
}
defer cur.Close(ctx)
var articles model.SortedArticles
for cur.Next(ctx) {
article := model.Article{}
err = cur.Decode(&article)
if err != nil {
return nil, err
}
articles = append(articles, &article)
}
sort.Sort(articles)
return articles, nil
}
// LoadDraftArticles 读取草稿箱
func (db *mongodb) LoadDraftArticles(ctx context.Context) (model.SortedArticles, error) {
collection := db.Database(mongoDBName).Collection(collectionArticle)
filter := bson.M{"isdraft": true}
cur, err := collection.Find(ctx, filter)
if err != nil {
return nil, err
}
defer cur.Close(ctx)
var articles model.SortedArticles
for cur.Next(ctx) {
article := model.Article{}
err = cur.Decode(&article)
if err != nil {
return nil, err
}
articles = append(articles, &article)
}
sort.Sort(articles)
return articles, nil
update := bson.M{"$set": params}
_, err := collection.UpdateOne(ctx, filter, update)
return err
}
// InsertSeries 创建专题
@@ -176,15 +141,45 @@ func (db *mongodb) RemoveSeries(ctx context.Context, id int) error {
}
// UpdateSeries 更新专题
func (db *mongodb) UpdateSeries(ctx context.Context, series *model.Series) error {
func (db *mongodb) UpdateSeries(ctx context.Context, id int,
fields map[string]interface{}) error {
collection := db.Database(mongoDBName).Collection(collectionSeries)
filter := bson.M{"id": series.ID}
update := bson.M{"$set": bson.M{"desc": series.Desc}}
filter := bson.M{"id": id}
params := bson.M{}
for k, v := range fields {
params[k] = v
}
update := bson.M{"$set": params}
_, err := collection.UpdateOne(ctx, filter, update)
return err
}
// LoadAllSeries 查询所有专题
func (db *mongodb) LoadAllSeries(ctx context.Context) (model.SortedSeries, error) {
collection := db.Database(mongoDBName).Collection(collectionSeries)
filter := bson.M{}
cur, err := collection.Find(ctx, filter)
if err != nil {
return nil, err
}
defer cur.Close(ctx)
var series model.SortedSeries
for cur.Next(ctx) {
obj := model.Series{}
err = cur.Decode(&obj)
if err != nil {
return nil, err
}
series = append(series, &obj)
}
sort.Sort(series)
return series, nil
}
// InsertArticle 创建文章
func (db *mongodb) InsertArticle(ctx context.Context, article *model.Article) error {
// 分配ID, 占位至起始id
@@ -203,6 +198,15 @@ func (db *mongodb) InsertArticle(ctx context.Context, article *model.Article) er
return err
}
// RemoveArticle 硬删除文章
func (db *mongodb) RemoveArticle(ctx context.Context, id int) error {
collection := db.Database(mongoDBName).Collection(collectionArticle)
filter := bson.M{"id": id}
_, err := collection.DeleteOne(ctx, filter)
return err
}
// DeleteArticle 软删除文章,放入回收箱
func (db *mongodb) DeleteArticle(ctx context.Context, id int) error {
collection := db.Database(mongoDBName).Collection(collectionArticle)
@@ -213,25 +217,34 @@ func (db *mongodb) DeleteArticle(ctx context.Context, id int) error {
return err
}
// RemoveArticle 硬删除文章
func (db *mongodb) RemoveArticle(ctx context.Context, id int) error {
collection := db.Database(mongoDBName).Collection(collectionArticle)
filter := bson.M{"id": id}
_, err := collection.DeleteOne(ctx, filter)
return err
}
// CleanArticles 清理回收站文章
func (db *mongodb) CleanArticles(ctx context.Context) error {
collection := db.Database(mongoDBName).Collection(collectionArticle)
exp := time.Now().Add(time.Duration(config.Conf.BlogApp.General.Trash) * time.Hour)
fmt.Println(exp)
filter := bson.M{"deletetime": bson.M{"$gt": time.Time{}, "$lt": exp}}
_, err := collection.DeleteMany(ctx, filter)
return err
}
// UpdateArticle 更新文章
func (db *mongodb) UpdateArticle(ctx context.Context, id int,
fields map[string]interface{}) error {
collection := db.Database(mongoDBName).Collection(collectionArticle)
filter := bson.M{"id": id}
params := bson.M{}
for k, v := range fields {
params[k] = v
}
update := bson.M{"$set": params}
fmt.Println(update)
_, err := collection.UpdateOne(ctx, filter, update)
return err
}
// RecoverArticle 恢复文章到草稿
func (db *mongodb) RecoverArticle(ctx context.Context, id int) error {
collection := db.Database(mongoDBName).Collection(collectionArticle)
@@ -242,45 +255,76 @@ func (db *mongodb) RecoverArticle(ctx context.Context, id int) error {
return err
}
// UpdateAccount 更新账户
func (db *mongodb) UpdateAccount(ctx context.Context, name string, fields map[string]interface{}) error {
collection := db.Database(mongoDBName).Collection(collectionAccount)
// LoadAllArticle 读取所有文章
func (db *mongodb) LoadAllArticle(ctx context.Context) (model.SortedArticles, error) {
collection := db.Database(mongoDBName).Collection(collectionArticle)
filter := bson.M{"username": name}
update := bson.M{}
for k, v := range fields {
update[k] = v
filter := bson.M{"isdraft": false, "deletetime": bson.M{"$eq": time.Time{}}}
cur, err := collection.Find(ctx, filter)
if err != nil {
return nil, err
}
_, err := collection.UpdateOne(ctx, filter, update)
return err
defer cur.Close(ctx)
var articles model.SortedArticles
for cur.Next(ctx) {
obj := model.Article{}
err = cur.Decode(&obj)
if err != nil {
return nil, err
}
articles = append(articles, &obj)
}
sort.Sort(articles)
return articles, nil
}
// UpdateBlogger 更新博客
func (db *mongodb) UpdateBlogger(ctx context.Context, fields map[string]interface{}) error {
collection := db.Database(mongoDBName).Collection(collectionBlogger)
// LoadTrashArticles 读取回收箱
func (db *mongodb) LoadTrashArticles(ctx context.Context) (model.SortedArticles, error) {
collection := db.Database(mongoDBName).Collection(collectionArticle)
filter := bson.M{}
update := bson.M{}
for k, v := range fields {
update[k] = v
filter := bson.M{"deletetime": bson.M{"$ne": time.Time{}}}
cur, err := collection.Find(ctx, filter)
if err != nil {
return nil, err
}
_, err := collection.UpdateOne(ctx, filter, update)
return err
defer cur.Close(ctx)
var articles model.SortedArticles
for cur.Next(ctx) {
obj := model.Article{}
err = cur.Decode(&obj)
if err != nil {
return nil, err
}
articles = append(articles, &obj)
}
sort.Sort(articles)
return articles, nil
}
// UpdateArticle 更新文章
func (db *mongodb) UpdateArticle(ctx context.Context, article *model.Article) error {
collection := db.Database(mongoDBName).Collection(collectionBlogger)
// LoadDraftArticles 读取草稿箱
func (db *mongodb) LoadDraftArticles(ctx context.Context) (model.SortedArticles, error) {
collection := db.Database(mongoDBName).Collection(collectionArticle)
filter := bson.M{"id": article.ID}
update := bson.M{"$set": bson.M{
"title": article.Title,
"content": article.Content,
"updatetime": article.UpdateTime,
"createtime": article.CreateTime,
}}
_, err := collection.UpdateOne(ctx, filter, update)
return err
filter := bson.M{"isdraft": true}
cur, err := collection.Find(ctx, filter)
if err != nil {
return nil, err
}
defer cur.Close(ctx)
var articles model.SortedArticles
for cur.Next(ctx) {
obj := model.Article{}
err = cur.Decode(&obj)
if err != nil {
return nil, err
}
articles = append(articles, &obj)
}
sort.Sort(articles)
return articles, nil
}
// counter counter

View File

@@ -1,7 +1,21 @@
// Package store provides ...
package store
var store Store
import (
"context"
"testing"
"time"
"github.com/eiblog/eiblog/pkg/model"
)
var (
store Store
acct *model.Account
blogger *model.Blogger
series *model.Series
article *model.Article
)
func init() {
var err error
@@ -9,4 +23,183 @@ func init() {
if err != nil {
panic(err)
}
// account
acct = &model.Account{
Username: "deepzz",
Password: "deepzz",
Email: "deepzz@example.com",
PhoneN: "12345678900",
Address: "address",
CreateTime: time.Now(),
}
// blogger
blogger = &model.Blogger{
BlogName: "Deepzz",
SubTitle: "不抛弃,不放弃",
BeiAn: "beian",
BTitle: "Deepzz's Blog",
Copyright: "Copyright",
}
// series
series = &model.Series{
Slug: "slug",
Name: "series name",
Desc: "series desc",
CreateTime: time.Now(),
}
// article
article = &model.Article{
Author: "deepzz",
Slug: "slug",
Title: "title",
Count: 0,
Content: "### count",
SerieID: 0,
Tags: "",
IsDraft: false,
UpdateTime: time.Now(),
CreateTime: time.Now(),
}
}
func TestLoadInsertAccount(t *testing.T) {
acct2, err := store.LoadInsertAccount(context.Background(), acct)
if err != nil {
t.Fatal(err)
}
t.Log(acct2)
t.Log(acct == acct2)
}
func TestUpdateAccount(t *testing.T) {
err := store.UpdateAccount(context.Background(), "deepzz", map[string]interface{}{
"phonn": "09876543211",
"loginua": "chrome",
"password": "123456",
"logintime": time.Now(),
"logouttime": time.Now(),
})
if err != nil {
t.Fatal(err)
}
}
func TestLoadInsertBlogger(t *testing.T) {
blogger2, err := store.LoadInsertBlogger(context.Background(), blogger)
if err != nil {
t.Fatal(err)
}
t.Log(blogger2)
t.Log(blogger == blogger2)
}
func TestUpdateBlogger(t *testing.T) {
err := store.UpdateBlogger(context.Background(), map[string]interface{}{
"blogname": "blogname",
})
if err != nil {
t.Fatal(err)
}
}
func TestInsertSeries(t *testing.T) {
err := store.InsertSeries(context.Background(), series)
if err != nil {
t.Fatal(err)
}
}
func TestRemoveSeries(t *testing.T) {
err := store.RemoveSeries(context.Background(), 1)
if err != nil {
t.Fatal(err)
}
}
func TestUpdateSeries(t *testing.T) {
err := store.UpdateSeries(context.Background(), 2, map[string]interface{}{
"desc": "update desc",
})
if err != nil {
t.Fatal(err)
}
}
func TestLoadAllSeries(t *testing.T) {
series, err := store.LoadAllSeries(context.Background())
if err != nil {
t.Fatal(err)
}
t.Logf("load all series: %d", len(series))
}
func TestInsertArticle(t *testing.T) {
err := store.InsertArticle(context.Background(), article)
if err != nil {
t.Fatal(err)
}
}
func TestRemoveArticle(t *testing.T) {
err := store.RemoveArticle(context.Background(), 11)
if err != nil {
t.Fatal(err)
}
}
func TestDeleteArticle(t *testing.T) {
err := store.DeleteArticle(context.Background(), 12)
if err != nil {
t.Fatal(err)
}
}
// TODO
func TestCleanArticles(t *testing.T) {
err := store.CleanArticles(context.Background())
if err != nil {
t.Fatal(err)
}
}
func TestUpdateArticle(t *testing.T) {
err := store.UpdateArticle(context.Background(), 13, map[string]interface{}{
"title": "new title",
"updatetime": time.Now(),
})
if err != nil {
t.Fatal(err)
}
}
func TestRecoverArticle(t *testing.T) {
err := store.RecoverArticle(context.Background(), 12)
if err != nil {
t.Fatal(err)
}
}
func TestLoadAllArticle(t *testing.T) {
articles, err := store.LoadAllArticle(context.Background())
if err != nil {
t.Fatal(err)
}
t.Logf("load all articles: %d", len(articles))
}
func TestLoadTrashArticles(t *testing.T) {
articles, err := store.LoadTrashArticles(context.Background())
if err != nil {
t.Fatal(err)
}
t.Logf("load trash articles: %d", len(articles))
}
func TestLoadDraftArticles(t *testing.T) {
articles, err := store.LoadDraftArticles(context.Background())
if err != nil {
t.Fatal(err)
}
t.Logf("load draft articles: %d", len(articles))
}

View File

@@ -19,39 +19,41 @@ var (
type Store interface {
// LoadInsertAccount 读取或创建账户
LoadInsertAccount(ctx context.Context, acct *model.Account) (*model.Account, error)
// UpdateAccount 更新账户
UpdateAccount(ctx context.Context, name string, fields map[string]interface{}) error
// LoadInsertBlogger 读取或创建博客
LoadInsertBlogger(ctx context.Context, blogger *model.Blogger) (*model.Blogger, error)
// LoadAllArticle 读取所有文章
LoadAllArticle(ctx context.Context) (model.SortedArticles, error)
// LoadTrashArticles 读取回收箱
LoadTrashArticles(ctx context.Context) (model.SortedArticles, error)
// LoadDraftArticles 读取草稿箱
LoadDraftArticles(ctx context.Context) (model.SortedArticles, error)
// UpdateBlogger 更新博客
UpdateBlogger(ctx context.Context, fields map[string]interface{}) error
// InsertSeries 创建专题
InsertSeries(ctx context.Context, series *model.Series) error
// RemoveSeries 删除专题
RemoveSeries(ctx context.Context, id int) error
// UpdateSeries 更新专题
UpdateSeries(ctx context.Context, series *model.Series) error
UpdateSeries(ctx context.Context, id int, fields map[string]interface{}) error
// LoadAllSeries 读取所有专题
LoadAllSeries(ctx context.Context) (model.SortedSeries, error)
// InsertArticle 创建文章
InsertArticle(ctx context.Context, article *model.Article) error
// DeleteArticle 软删除文章,放入回收箱
DeleteArticle(ctx context.Context, id int) error
// RemoveArticle 硬删除文章
RemoveArticle(ctx context.Context, id int) error
// DeleteArticle 软删除文章,放入回收箱
DeleteArticle(ctx context.Context, id int) error
// CleanArticles 清理回收站文章
CleanArticles(ctx context.Context) error
// UpdateArticle 更新文章
UpdateArticle(ctx context.Context, id int, fields map[string]interface{}) error
// RecoverArticle 恢复文章到草稿
RecoverArticle(ctx context.Context, id int) error
// UpdateAccount 更新账户
UpdateAccount(ctx context.Context, name string, fields map[string]interface{}) error
// UpdateBlogger 更新博客
UpdateBlogger(ctx context.Context, fields map[string]interface{}) error
// UpdateArticle 更新文章
UpdateArticle(ctx context.Context, article *model.Article) error
// LoadAllArticle 读取所有文章
LoadAllArticle(ctx context.Context) (model.SortedArticles, error)
// LoadTrashArticles 读取回收箱
LoadTrashArticles(ctx context.Context) (model.SortedArticles, error)
// LoadDraftArticles 读取草稿箱
LoadDraftArticles(ctx context.Context) (model.SortedArticles, error)
}
// Driver 存储驱动

View File

@@ -136,9 +136,9 @@ type Config struct {
func init() {
// compatibility linux and windows
var err error
WorkDir, err = os.Getwd()
if err != nil {
panic(err)
if gopath := os.Getenv("GOPATH"); gopath != "" {
WorkDir = filepath.Join(gopath, "src", "github.com",
"eiblog", "eiblog")
}
path := filepath.Join(WorkDir, "conf", "app.yml")

View File

@@ -11,7 +11,7 @@ type Article struct {
Title string `gorm:"not null"` // 标题
Count int `gorm:"not null"` // 评论数量
Content string `gorm:"not null"` // markdown内容
SerieID int32 `gorm:"not null"` // 专题ID
SerieID int `gorm:"not null"` // 专题ID
Tags string `gorm:"not null"` // tag,以逗号隔开
IsDraft bool `gorm:"not null"` // 是否是草稿