Compare commits

...

33 Commits

Author SHA1 Message Date
henry.chen
fb66b6871e release v1.4.3 2018-02-09 16:15:34 +08:00
henry.chen
5ae76f243e fixed #6,发布文章异步提交,随机 session key等 2018-02-09 13:50:34 +08:00
deepzz0
051b034e51 1. 修复编辑专题:按钮显示"新增专题"错误 2. 编辑专题链接移动到专题名称 2018-02-04 12:39:35 +08:00
Deepzz
27439ecc71 Update install.md 2018-02-01 21:24:00 +08:00
henry.chen
d02c838447 fix archive page bug 2018-01-25 23:09:59 +08:00
Deepzz
d17acf5325 Update amusing.md 2018-01-17 19:16:08 +08:00
deepzz0
b278ca377f update changelog.md 2018-01-14 13:53:32 +08:00
deepzz0
93131441e4 update 2018-01-14 13:38:26 +08:00
deepzz0
ddcc6c2d2e auto archiving by year when the month great than 12 2018-01-14 13:12:59 +08:00
henry.chen
ef63ae9598 fix page archive unable auto update 2018-01-14 02:40:11 +08:00
henry.chen
2ed9db5c7b code logical adjust 2018-01-14 02:02:12 +08:00
deepzz0
06a12bc6f9 update vendor 2018-01-13 18:23:03 +08:00
deepzz0
6524b45751 adjust the code 2018-01-13 18:19:54 +08:00
henry.chen
ceb9e2690b 添加 disqus thread 创建接口 2018-01-13 02:56:35 +08:00
deepzz0
405fbaf24f fix can delete blogroll and about page & fix delete and readd article bug 2018-01-07 20:30:14 +08:00
deepzz0
3245c0e0d3 update vendor & fix upload file url & fix judge file type 2018-01-06 23:24:27 +08:00
Deepzz
badc62e3f0 Update README.md 2018-01-06 11:47:20 +08:00
deepzz0
a5561f257b comment docker-compose.yml backup 2018-01-02 20:21:45 +08:00
deepzz0
eb37b83ebd update README.md 2018-01-01 19:03:16 +08:00
deepzz0
b2fab703fc Merge branch 'master' of github.com:eiblog/eiblog 2018-01-01 18:59:30 +08:00
deepzz0
37deb390d9 docker-compose.yml 添加数据库备份镜像 2018-01-01 18:59:10 +08:00
Deepzz
6fa5088352 更新 ct 服务器地址 2017-12-30 13:50:19 +08:00
Deepzz
e023a33786 Update app.yml
移除 disqus 评论及 Google 分析私人信息配置
2017-12-08 12:19:01 +08:00
henry.chen
6f818c4b5d fix search.html <no value> 2017-12-05 15:08:32 +08:00
henry.chen
9ad22fb2d9 don't use dynamic link: CGO_ENABLED=0 2017-11-30 10:04:54 +08:00
henry.chen
fc37d5e093 fix page:admin/write-post autocomplete tag 2017-11-29 16:17:58 +08:00
henry.chen
61024bfebd update 2017-11-27 18:34:03 +08:00
henry.chen
f20c4a6063 fix docker image: exec user process caused "no such file or directory" 2017-11-27 18:17:41 +08:00
henry.chen
c24e6bf7bd update .travis.yml 2017-11-27 16:43:30 +08:00
henry.chen
ade94168d3 update .travis.yml 2017-11-27 16:32:39 +08:00
henry.chen
552d010650 fix background turn page 2017-11-27 15:21:28 +08:00
deepzz0
1c3106cbb0 update vendor 2017-11-24 22:58:59 +08:00
henry.chen
168937f1b2 fix gopkg.in/mgo import conflict 2017-11-23 13:57:20 +08:00
255 changed files with 6374 additions and 1370 deletions

View File

@@ -1,36 +1,26 @@
sudo: required # 超级权限 sudo: required # 超级权限
dist: trusty # 在ubuntu:trusty dist: trusty # 在ubuntu:trusty
language: go # 声明构建语言环境 language: go # 声明构建语言环境
go: # 只构建最新版本 go: # 只构建最新版本
- tip - tip
services: # docker环境 services: # docker环境
- docker - docker
branches: # 限定项目分支 branches: # 限定项目分支
only: only:
- /^v[0-9](\.[0-9]){2}(-rc[1-9])?$/ - /^v[0-9](\.[0-9]){2}(-rc[1-9])?$/
install: install:
- curl https://glide.sh/get | sh # 安装glide包管理 - curl https://glide.sh/get | sh # 安装glide包管理
script: script:
- glide up - glide up
- GOOS=linux GOARCH=amd64 go build # 编译版本 - GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build # 编译版本
- docker build -t registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog:$TRAVIS_TAG . # 构建镜像 - docker build -t registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog . # 构建镜像
after_success: after_success:
# - if [ "$TRAVIS_BRANCH" =~ ^v[0-9](\.[0-9])+.*$ ]; then - docker login -u="$DOCKER_USERNAME" -p="$DOCKER_PASSWORD" registry.cn-hangzhou.aliyuncs.com
# docker login -u="$DOCKER_USERNAME" -p="$DOCKER_PASSWORD" registry.cn-hangzhou.aliyuncs.com; - docker push registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog
# docker push registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog; - docker tag registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog:$TRAVIS_TAG
# fi
- docker push registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog:$TRAVIS_TAG - docker push registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog:$TRAVIS_TAG
before_deploy: before_deploy:
- ./dist.sh - ./dist.sh
deploy: deploy:
provider: releases provider: releases
api_key: api_key:

View File

@@ -1,5 +1,44 @@
# Eiblog Changelog # Eiblog Changelog
## v1.4.2 (2018-02-09)
* 修复博客初始化后about 页面不能够评论 #6
* 修复编辑专题,按钮显示“添加专题”错误
* 优化“添加文章”从同步改为异步推送feedesdisqus。速度显著提升
* **重要*)头像图片从 avatar.jpg 改为 avatar.png透明
* docker-compose.yml mongodb 去掉端口映射,防止用户将端口暴露至外网
* session key 每次重启随机生成等一些细节的修复
## v1.4.1 (2018-01-14)
* 修复创建新文章disqus 不收录bug
* 修复创建新文章归档页面不刷新bug
* 修复能够删除关于页面和友情链接页面bug
* 修复重复添加文章错误
* 注释掉 docker-compose.yml 自动备份内容,请自行解开
* 添加当月数大于12归档页面使用年份归档
* 优化代码逻辑
## v1.4.0 (2018-01-01)
* fix 搜索页面 bug
* CGO_ENABLED=0 关闭 cgo
* 更新Makefile ct log 服务器
* 数据库数据终于可以备份了
## v1.3.4 (2017-11-29)
* fix page:admin/write-post autocomplete tag
## v1.3.3 (2017-11-27)
* fix docker image: exec user process caused "no such file or directory"
## v1.3.2 (2017-11-17)
* 修复文章自动保存引起的发布文章不成功的bug
## v1.3.1 (2017-11-05)
* 修复调整 关于、友情链接 创建时间出现文章乱序
* 修复评论时间计算错误
* 调整acme文件验证路径
* 更改七牛SDK包为github包。
* 调整七牛配置文件名称app.yml: kodo -> qiniuname -> bucket请提高静态文件版本 staticversion
## v1.3.0 (2017-07-13) ## v1.3.0 (2017-07-13)
* 更改 app.yml 配置项,将大部分配置归在 general 常规配置下。注意,部署时请先更新 app.yml。 * 更改 app.yml 配置项,将大部分配置归在 general 常规配置下。注意,部署时请先更新 app.yml。
* 静态文件采用动态渲染,即用户不再需要管理 view、static 目录。 * 静态文件采用动态渲染,即用户不再需要管理 view、static 目录。

View File

@@ -7,4 +7,4 @@ ADD static/tzdata/Shanghai /etc/localtime
COPY . /eiblog COPY . /eiblog
EXPOSE 9000 EXPOSE 9000
WORKDIR /eiblog WORKDIR /eiblog
CMD ["./eiblog"] CMD ["sh","-c","/eiblog/eiblog"]

View File

@@ -40,7 +40,7 @@ gencert:makedir
@echo "generate rsa cert..." @echo "generate rsa cert..."
@$(acme.sh) --force --issue --dns dns_ali $(sans) --log \ @$(acme.sh) --force --issue --dns dns_ali $(sans) --log \
--renew-hook "ct-submit ctlog.api.venafi.com < $(config)/ssl/domain.rsa.pem > $(config)/scts/rsa/venafi.sct \ --renew-hook "ct-submit ctlog-gen2.api.venafi.com < $(config)/ssl/domain.rsa.pem > $(config)/scts/rsa/venafi.sct \
&& ct-submit ctlog.wosign.com < $(config)/ssl/domain.rsa.pem > $(config)/scts/rsa/wosign.sct" && ct-submit ctlog.wosign.com < $(config)/ssl/domain.rsa.pem > $(config)/scts/rsa/wosign.sct"
@$(acme.sh) --install-cert -d $(cn) \ @$(acme.sh) --install-cert -d $(cn) \
--key-file $(config)/ssl/domain.rsa.key \ --key-file $(config)/ssl/domain.rsa.key \
@@ -49,7 +49,7 @@ gencert:makedir
@echo "generate ecc cert..." @echo "generate ecc cert..."
@$(acme.sh) --force --issue --dns dns_ali $(sans) -k ec-256 --log \ @$(acme.sh) --force --issue --dns dns_ali $(sans) -k ec-256 --log \
--renew-hook "ct-submit ctlog.api.venafi.com < $(config)/ssl/domain.ecc.pem > $(config)/scts/ecc/venafi.sct \ --renew-hook "ct-submit ctlog-gen2.api.venafi.com < $(config)/ssl/domain.ecc.pem > $(config)/scts/ecc/venafi.sct \
&& ct-submit ctlog.wosign.com < $(config)/ssl/domain.ecc.pem > $(config)/scts/ecc/wosign.sct" && ct-submit ctlog.wosign.com < $(config)/ssl/domain.ecc.pem > $(config)/scts/ecc/wosign.sct"
@$(acme.sh) --install-cert -d $(cn) --ecc \ @$(acme.sh) --install-cert -d $(cn) --ecc \
--key-file $(config)/ssl/domain.ecc.key \ --key-file $(config)/ssl/domain.ecc.key \

View File

@@ -74,6 +74,7 @@
8. 开源 `Typecho` 完整后台系统,全功能 `markdown` 编辑器,让你体验什么是简洁清爽。 8. 开源 `Typecho` 完整后台系统,全功能 `markdown` 编辑器,让你体验什么是简洁清爽。
9. 博客后台直接对接 `七牛 SDK`,实现后台上传文件和删除文件的简单功能。 9. 博客后台直接对接 `七牛 SDK`,实现后台上传文件和删除文件的简单功能。
10. 采用 `elasticsearch` 作为站内搜索,添加 `google opensearch` 功能,搜索更加自然。 10. 采用 `elasticsearch` 作为站内搜索,添加 `google opensearch` 功能,搜索更加自然。
11. 自动备份数据库数据到七牛云。
### 文档 ### 文档
@@ -81,10 +82,10 @@
* [安装部署](https://github.com/eiblog/eiblog/blob/master/docs/install.md) * [安装部署](https://github.com/eiblog/eiblog/blob/master/docs/install.md)
* [写作需知](https://github.com/eiblog/eiblog/blob/master/docs/writing.md) * [写作需知](https://github.com/eiblog/eiblog/blob/master/docs/writing.md)
* [好玩的功能](https://github.com/eiblog/eiblog/blob/master/docs/amusing.md) * [好玩的功能](https://github.com/eiblog/eiblog/blob/master/docs/amusing.md)
* [关于备份](https://github.com/eiblog/backup)
### 成功搭建者博客 ### 成功搭建者博客
* [https://razeen.me](https://razeen.me) - Razeen's Blog * [https://razeen.me](https://razeen.me) - Razeen's Blog
* [https://mxthd.me](https://mxthd.me) - 梦醒逃荒岛
如果你的博客使用`Eiblog`搭建,你可以在 [这里](https://github.com/eiblog/eiblog/issues/1) 提交网址。 如果你的博客使用`Eiblog`搭建,你可以在 [这里](https://github.com/eiblog/eiblog/issues/1) 提交网址。

192
api.go
View File

@@ -4,15 +4,14 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"sort"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/eiblog/eiblog/setting" "github.com/eiblog/eiblog/setting"
"github.com/eiblog/utils/logd" "github.com/eiblog/utils/logd"
"github.com/eiblog/utils/mgo"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gopkg.in/mgo.v2/bson"
) )
const ( const (
@@ -56,6 +55,7 @@ func init() {
APIs["file-delete"] = apiFileDelete APIs["file-delete"] = apiFileDelete
} }
// 更新账号信息Email、PhoneNumber、Address
func apiAccount(c *gin.Context) { func apiAccount(c *gin.Context) {
e := c.PostForm("email") e := c.PostForm("email")
pn := c.PostForm("phoneNumber") pn := c.PostForm("phoneNumber")
@@ -66,8 +66,9 @@ func apiAccount(c *gin.Context) {
return return
} }
err := UpdateAccountField(bson.M{"$set": bson.M{"email": e, "phonen": pn, "address": ad}}) err := UpdateAccountField(mgo.M{"$set": mgo.M{"email": e, "phonen": pn, "address": ad}})
if err != nil { if err != nil {
logd.Error(err)
responseNotice(c, NOTICE_NOTICE, err.Error(), "") responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return return
} }
@@ -77,6 +78,7 @@ func apiAccount(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "更新成功", "") responseNotice(c, NOTICE_SUCCESS, "更新成功", "")
} }
// 更新博客信息
func apiBlog(c *gin.Context) { func apiBlog(c *gin.Context) {
bn := c.PostForm("blogName") bn := c.PostForm("blogName")
bt := c.PostForm("bTitle") bt := c.PostForm("bTitle")
@@ -89,8 +91,11 @@ func apiBlog(c *gin.Context) {
return return
} }
err := UpdateAccountField(bson.M{"$set": bson.M{"blogger.blogname": bn, "blogger.btitle": bt, "blogger.beian": ba, "blogger.subtitle": st, "blogger.seriessay": ss, "blogger.archivessay": as}}) err := UpdateAccountField(mgo.M{"$set": mgo.M{"blogger.blogname": bn,
"blogger.btitle": bt, "blogger.beian": ba, "blogger.subtitle": st,
"blogger.seriessay": ss, "blogger.archivessay": as}})
if err != nil { if err != nil {
logd.Error(err)
responseNotice(c, NOTICE_NOTICE, err.Error(), "") responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return return
} }
@@ -105,6 +110,7 @@ func apiBlog(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "更新成功", "") responseNotice(c, NOTICE_SUCCESS, "更新成功", "")
} }
// 更新密码
func apiPassword(c *gin.Context) { func apiPassword(c *gin.Context) {
logd.Debug(c.Request.PostForm.Encode()) logd.Debug(c.Request.PostForm.Encode())
od := c.PostForm("old") od := c.PostForm("old")
@@ -124,8 +130,9 @@ func apiPassword(c *gin.Context) {
} }
newPwd := EncryptPasswd(Ei.Username, nw) newPwd := EncryptPasswd(Ei.Username, nw)
err := UpdateAccountField(bson.M{"$set": bson.M{"password": newPwd}}) err := UpdateAccountField(mgo.M{"$set": mgo.M{"password": newPwd}})
if err != nil { if err != nil {
logd.Error(err)
responseNotice(c, NOTICE_NOTICE, err.Error(), "") responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return return
} }
@@ -133,46 +140,39 @@ func apiPassword(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "更新成功", "") responseNotice(c, NOTICE_SUCCESS, "更新成功", "")
} }
// 删除文章,软删除:移入到回收箱
func apiPostDelete(c *gin.Context) { func apiPostDelete(c *gin.Context) {
var err error var ids []int32
defer func() { for _, v := range c.PostFormArray("cid[]") {
i, err := strconv.Atoi(v)
if err != nil || int32(i) < setting.Conf.General.StartID {
responseNotice(c, NOTICE_NOTICE, "参数错误", "")
return
}
ids = append(ids, int32(i))
}
err := DelArticles(ids...)
if err != nil { if err != nil {
logd.Error(err) logd.Error(err)
responseNotice(c, NOTICE_NOTICE, err.Error(), "") responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return return
} }
responseNotice(c, NOTICE_SUCCESS, "删除成功", "")
}()
err = c.Request.ParseForm() // elasticsearch
if err != nil {
return
}
var ids []int32
var i int
for _, v := range c.Request.PostForm["cid[]"] {
i, err = strconv.Atoi(v)
if err != nil || i < 1 {
err = errors.New("参数错误")
return
}
ids = append(ids, int32(i))
}
err = DelArticles(ids...)
if err != nil {
return
}
// elasticsearch 删除索引
err = ElasticDelIndex(ids) err = ElasticDelIndex(ids)
if err != nil { if err != nil {
return logd.Error(err)
} }
// TODO disqus delete
responseNotice(c, NOTICE_SUCCESS, "删除成功", "")
} }
func apiPostAdd(c *gin.Context) { func apiPostAdd(c *gin.Context) {
var err error var (
var do string err error
var cid int do string
cid int
)
defer func() { defer func() {
switch do { switch do {
case "auto": // 自动保存 case "auto": // 自动保存
@@ -181,18 +181,16 @@ func apiPostAdd(c *gin.Context) {
return return
} }
c.JSON(http.StatusOK, gin.H{"success": SUCCESS, "time": time.Now().Format("15:04:05 PM"), "cid": cid}) c.JSON(http.StatusOK, gin.H{"success": SUCCESS, "time": time.Now().Format("15:04:05 PM"), "cid": cid})
case "save": // 保存草稿 case "save", "publish": // 草稿,发布
if err != nil { if err != nil {
responseNotice(c, NOTICE_NOTICE, err.Error(), "") responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return return
} }
c.Redirect(http.StatusFound, "/admin/manage-draft") uri := "/admin/manage-draft"
case "publish": // 发布 if do == "publish" {
if err != nil { uri = "/admin/manage-posts"
responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return
} }
c.Redirect(http.StatusFound, "/admin/manage-posts") c.Redirect(http.StatusFound, uri)
} }
}() }()
@@ -200,11 +198,11 @@ func apiPostAdd(c *gin.Context) {
slug := c.PostForm("slug") slug := c.PostForm("slug")
title := c.PostForm("title") title := c.PostForm("title")
text := c.PostForm("text") text := c.PostForm("text")
date := c.PostForm("date") date := CheckDate(c.PostForm("date"))
serie := c.PostForm("serie") serie := c.PostForm("serie")
tag := c.PostForm("tags") tag := c.PostForm("tags")
update := c.PostForm("update") update := c.PostForm("update")
if title == "" || text == "" || slug == "" { if slug == "" || title == "" || text == "" {
err = errors.New("参数错误") err = errors.New("参数错误")
return return
} }
@@ -217,13 +215,14 @@ func apiPostAdd(c *gin.Context) {
Title: title, Title: title,
Content: text, Content: text,
Slug: slug, Slug: slug,
CreateTime: CheckDate(date), CreateTime: date,
IsDraft: do != "publish", IsDraft: do != "publish",
Author: Ei.Username, Author: Ei.Username,
SerieID: serieid, SerieID: serieid,
Tags: tags, Tags: tags,
} }
cid, err = strconv.Atoi(c.PostForm("cid")) cid, err = strconv.Atoi(c.PostForm("cid"))
// 新文章
if err != nil || cid < 1 { if err != nil || cid < 1 {
err = AddArticle(artc) err = AddArticle(artc)
if err != nil { if err != nil {
@@ -232,58 +231,55 @@ func apiPostAdd(c *gin.Context) {
} }
cid = int(artc.ID) cid = int(artc.ID)
if !artc.IsDraft { if !artc.IsDraft {
// 异步执行,快
go func() {
// elastic
ElasticIndex(artc) ElasticIndex(artc)
// rss
DoPings(slug) DoPings(slug)
// disqus
ThreadCreate(artc)
}()
} }
return return
} }
// 旧文章
artc.ID = int32(cid) artc.ID = int32(cid)
i, a := GetArticle(artc.ID) _, a := GetArticle(artc.ID)
if a != nil { if a != nil {
artc.IsDraft = false artc.IsDraft = false
artc.Count = a.Count artc.Count = a.Count
artc.UpdateTime = a.UpdateTime artc.UpdateTime = a.UpdateTime
Ei.Articles = append(Ei.Articles[0:i], Ei.Articles[i+1:]...)
DelFromLinkedList(a)
ManageTagsArticle(a, false, DELETE)
ManageSeriesArticle(a, false, DELETE)
ManageArchivesArticle(a, false, DELETE)
delete(Ei.MapArticles, a.Slug)
a = nil
} }
if CheckBool(update) { if CheckBool(update) {
artc.UpdateTime = time.Now() artc.UpdateTime = time.Now()
} }
err = UpdateArticle(bson.M{"id": artc.ID}, artc) // 数据库更新
err = UpdateArticle(mgo.M{"id": artc.ID}, artc)
if err != nil { if err != nil {
logd.Error(err) logd.Error(err)
return return
} }
if !artc.IsDraft { if !artc.IsDraft {
Ei.MapArticles[artc.Slug] = artc ReplaceArticle(a, artc)
Ei.Articles = append(Ei.Articles, artc) // 异步执行,快
sort.Sort(Ei.Articles) go func() {
GenerateExcerptAndRender(artc) // elastic
// elasticsearch 索引
ElasticIndex(artc) ElasticIndex(artc)
// rss
DoPings(slug) DoPings(slug)
if artc.ID >= setting.Conf.General.StartID { // disqus
ManageTagsArticle(artc, true, ADD) if a == nil {
ManageSeriesArticle(artc, true, ADD) ThreadCreate(artc)
ManageArchivesArticle(artc, true, ADD)
AddToLinkedList(artc.ID)
} }
}()
} }
} }
// 只能逐一删除,专题下不能有文章
func apiSerieDelete(c *gin.Context) { func apiSerieDelete(c *gin.Context) {
err := c.Request.ParseForm() for _, v := range c.PostFormArray("mid[]") {
if err != nil {
responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return
}
// 只能逐一删除
for _, v := range c.Request.PostForm["mid[]"] {
id, err := strconv.Atoi(v) id, err := strconv.Atoi(v)
if err != nil || id < 1 { if err != nil || id < 1 {
responseNotice(c, NOTICE_NOTICE, err.Error(), "") responseNotice(c, NOTICE_NOTICE, err.Error(), "")
@@ -299,6 +295,7 @@ func apiSerieDelete(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "删除成功", "") responseNotice(c, NOTICE_SUCCESS, "删除成功", "")
} }
// 添加专题,如果专题有提交 mid 即更新专题
func apiSerieAdd(c *gin.Context) { func apiSerieAdd(c *gin.Context) {
name := c.PostForm("name") name := c.PostForm("name")
slug := c.PostForm("slug") slug := c.PostForm("slug")
@@ -335,24 +332,15 @@ func apiSerieAdd(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "操作成功", "") responseNotice(c, NOTICE_SUCCESS, "操作成功", "")
} }
// 暂未启用 // NOTE 排序专题,暂未实现
func apiSerieSort(c *gin.Context) { func apiSerieSort(c *gin.Context) {
err := c.Request.ParseForm() v := c.PostFormArray("mid[]")
if err != nil {
responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return
}
v := c.Request.PostForm["mid[]"]
logd.Debug(v) logd.Debug(v)
} }
// 删除草稿箱,物理删除
func apiDraftDelete(c *gin.Context) { func apiDraftDelete(c *gin.Context) {
err := c.Request.ParseForm() for _, v := range c.PostFormArray("mid[]") {
if err != nil {
responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return
}
for _, v := range c.Request.PostForm["mid[]"] {
i, err := strconv.Atoi(v) i, err := strconv.Atoi(v)
if err != nil || i < 1 { if err != nil || i < 1 {
responseNotice(c, NOTICE_NOTICE, "参数错误", "") responseNotice(c, NOTICE_NOTICE, "参数错误", "")
@@ -367,15 +355,9 @@ func apiDraftDelete(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "删除成功", "") responseNotice(c, NOTICE_SUCCESS, "删除成功", "")
} }
// 删除垃圾箱,物理删除
func apiTrashDelete(c *gin.Context) { func apiTrashDelete(c *gin.Context) {
logd.Debug(c.PostForm("key")) for _, v := range c.PostFormArray("mid[]") {
logd.Debug(c.Request.PostForm)
err := c.Request.ParseForm()
if err != nil {
responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return
}
for _, v := range c.Request.PostForm["mid[]"] {
i, err := strconv.Atoi(v) i, err := strconv.Atoi(v)
if err != nil || i < 1 { if err != nil || i < 1 {
responseNotice(c, NOTICE_NOTICE, "参数错误", "") responseNotice(c, NOTICE_NOTICE, "参数错误", "")
@@ -390,15 +372,9 @@ func apiTrashDelete(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "删除成功", "") responseNotice(c, NOTICE_SUCCESS, "删除成功", "")
} }
// 从垃圾箱恢复到草稿箱
func apiTrashRecover(c *gin.Context) { func apiTrashRecover(c *gin.Context) {
logd.Debug(c.PostForm("key")) for _, v := range c.PostFormArray("mid[]") {
logd.Debug(c.Request.PostForm)
err := c.Request.ParseForm()
if err != nil {
responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return
}
for _, v := range c.Request.PostForm["mid[]"] {
i, err := strconv.Atoi(v) i, err := strconv.Atoi(v)
if err != nil || i < 1 { if err != nil || i < 1 {
responseNotice(c, NOTICE_NOTICE, "参数错误", "") responseNotice(c, NOTICE_NOTICE, "参数错误", "")
@@ -414,6 +390,7 @@ func apiTrashRecover(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "恢复成功", "") responseNotice(c, NOTICE_SUCCESS, "恢复成功", "")
} }
// 上传文件到 qiniu 云
func apiFileUpload(c *gin.Context) { func apiFileUpload(c *gin.Context) {
type Size interface { type Size interface {
Size() int64 Size() int64
@@ -437,7 +414,7 @@ func apiFileUpload(c *gin.Context) {
c.String(http.StatusBadRequest, err.Error()) c.String(http.StatusBadRequest, err.Error())
return return
} }
typ := c.Request.Header.Get("Content-Type") typ := header.Header.Get("Content-Type")
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"title": filename, "title": filename,
"isImage": typ[:5] == "image", "isImage": typ[:5] == "image",
@@ -446,20 +423,19 @@ func apiFileUpload(c *gin.Context) {
}) })
} }
// 删除七牛 CDN 文件
func apiFileDelete(c *gin.Context) { func apiFileDelete(c *gin.Context) {
var err error defer c.String(http.StatusOK, "删掉了吗?鬼知道。。。")
defer func() {
name := c.PostForm("title")
if name == "" {
logd.Error("参数错误")
return
}
err := FileDelete(name)
if err != nil { if err != nil {
logd.Error(err) logd.Error(err)
} }
c.String(http.StatusOK, "删掉了吗?鬼知道。。。")
}()
name := c.PostForm("title")
if name == "" {
err = errors.New("参数错误")
return
}
err = FileDelete(name)
} }
func responseNotice(c *gin.Context, typ, content, hl string) { func responseNotice(c *gin.Context, typ, content, hl string) {

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"html/template" "html/template"
"net/http" "net/http"
@@ -11,9 +12,9 @@ import (
"github.com/eiblog/eiblog/setting" "github.com/eiblog/eiblog/setting"
"github.com/eiblog/utils/logd" "github.com/eiblog/utils/logd"
"github.com/eiblog/utils/mgo"
"github.com/gin-gonic/contrib/sessions" "github.com/gin-gonic/contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gopkg.in/mgo.v2/bson"
) )
// 是否登录 // 是否登录
@@ -73,7 +74,7 @@ func HandleLoginPost(c *gin.Context) {
session.Save() session.Save()
Ei.LoginIP = c.ClientIP() Ei.LoginIP = c.ClientIP()
Ei.LoginTime = time.Now() Ei.LoginTime = time.Now()
UpdateAccountField(bson.M{"$set": bson.M{"loginip": Ei.LoginIP, "logintime": Ei.LoginTime}}) UpdateAccountField(mgo.M{"$set": mgo.M{"loginip": Ei.LoginIP, "logintime": Ei.LoginTime}})
c.Redirect(http.StatusFound, "/admin/profile") c.Redirect(http.StatusFound, "/admin/profile")
} }
@@ -118,7 +119,8 @@ func HandlePost(c *gin.Context) {
for tag, _ := range Ei.Tags { for tag, _ := range Ei.Tags {
tags = append(tags, T{tag, tag}) tags = append(tags, T{tag, tag})
} }
h["Tags"] = tags str, _ := json.Marshal(tags)
h["Tags"] = string(str)
c.Status(http.StatusOK) c.Status(http.StatusOK)
RenderHTMLBack(c, "admin-post", h) RenderHTMLBack(c, "admin-post", h)
} }

View File

@@ -35,13 +35,14 @@ general:
clean: 1 clean: 1
# 评论相关 # 评论相关
disqus: disqus:
shortname: deepzz shortname: xxxxxx
publickey: wdSgxRm9rdGAlLKFcFdToBe3GT4SibmV7Y8EjJQ0r4GWXeKtxpopMAeIeoI2dTEg publickey: wdSgxRm9rdGAlLKFcFdToBe3GT4SibmV7Y8EjJQ0r4GWXeKtxpopMAeIeoI2dTEg
accesstoken: 50023908f39f4607957e909b495326af accesstoken: 50023908f39f4607957e909b495326af
postscount: https://disqus.com/api/3.0/threads/set.json postscount: https://disqus.com/api/3.0/threads/set.json
postslist: https://disqus.com/api/3.0/threads/listPosts.json postslist: https://disqus.com/api/3.0/threads/listPosts.json
postcreate: https://disqus.com/api/3.0/posts/create.json postcreate: https://disqus.com/api/3.0/posts/create.json
postapprove: https://disqus.com/api/3.0/posts/approve.json postapprove: https://disqus.com/api/3.0/posts/approve.json
threadcreate: https://disqus.com/api/3.0/threads/create.json
# disqus.js 文件名 # disqus.js 文件名
embed: disqus_7d3cf2.js embed: disqus_7d3cf2.js
# 获取评论数量间隔 # 获取评论数量间隔
@@ -49,7 +50,7 @@ disqus:
# 谷歌统计 # 谷歌统计
google: google:
url: https://www.google-analytics.com/collect url: https://www.google-analytics.com/collect
tid: UA-77251712-1 tid: UA-xxxxxx-1
v: "1" v: "1"
t: pageview t: pageview
# 七牛CDN # 七牛CDN

402
db.go
View File

@@ -13,9 +13,7 @@ import (
"github.com/eiblog/blackfriday" "github.com/eiblog/blackfriday"
"github.com/eiblog/eiblog/setting" "github.com/eiblog/eiblog/setting"
"github.com/eiblog/utils/logd" "github.com/eiblog/utils/logd"
db "github.com/eiblog/utils/mgo" "github.com/eiblog/utils/mgo"
"gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
) )
// 数据库及表名 // 数据库及表名
@@ -62,44 +60,24 @@ var (
func init() { func init() {
// 数据库加索引 // 数据库加索引
ms, c := db.Connect(DB, COLLECTION_ACCOUNT) err := mgo.Index(DB, COLLECTION_ACCOUNT, []string{"username"})
index := mgo.Index{ if err != nil {
Key: []string{"username"},
Unique: true,
DropDups: true,
Background: true,
Sparse: true,
}
if err := c.EnsureIndex(index); err != nil {
logd.Fatal(err) logd.Fatal(err)
} }
ms.Close()
ms, c = db.Connect(DB, COLLECTION_ARTICLE) err = mgo.Index(DB, COLLECTION_ARTICLE, []string{"id"})
index = mgo.Index{ if err != nil {
Key: []string{"id"},
Unique: true,
DropDups: true,
Background: true,
Sparse: true,
}
if err := c.EnsureIndex(index); err != nil {
logd.Fatal(err) logd.Fatal(err)
} }
index = mgo.Index{
Key: []string{"slug"}, err = mgo.Index(DB, COLLECTION_ARTICLE, []string{"slug"})
Unique: true, if err != nil {
DropDups: true,
Background: true,
Sparse: true,
}
if err := c.EnsureIndex(index); err != nil {
logd.Fatal(err) logd.Fatal(err)
} }
ms.Close()
// 读取帐号信息 // 读取帐号信息
Ei = loadAccount() loadAccount()
// 获取文章数据 // 获取文章数据
Ei.Articles = loadArticles() loadArticles()
// 生成markdown文档 // 生成markdown文档
go generateMarkdown() go generateMarkdown()
// 启动定时器 // 启动定时器
@@ -109,12 +87,13 @@ func init() {
} }
// 读取或初始化帐号信息 // 读取或初始化帐号信息
func loadAccount() (a *Account) { func loadAccount() {
a = &Account{} Ei = &Account{}
err := db.FindOne(DB, COLLECTION_ACCOUNT, bson.M{"username": setting.Conf.Account.Username}, a) err := mgo.FindOne(DB, COLLECTION_ACCOUNT, mgo.M{"username": setting.Conf.Account.Username}, Ei)
// 初始化用户数据 // 初始化用户数据
if err == mgo.ErrNotFound { if err == mgo.ErrNotFound {
a = &Account{ logd.Printf("Initializing account: %s\n", setting.Conf.Account.Username)
Ei = &Account{
Username: setting.Conf.Account.Username, Username: setting.Conf.Account.Username,
Password: EncryptPasswd(setting.Conf.Account.Username, setting.Conf.Account.Password), Password: EncryptPasswd(setting.Conf.Account.Username, setting.Conf.Account.Password),
Email: setting.Conf.Account.Email, Email: setting.Conf.Account.Email,
@@ -122,29 +101,28 @@ func loadAccount() (a *Account) {
Address: setting.Conf.Account.Address, Address: setting.Conf.Account.Address,
CreateTime: time.Now(), CreateTime: time.Now(),
} }
a.BlogName = setting.Conf.Blogger.BlogName Ei.BlogName = setting.Conf.Blogger.BlogName
a.SubTitle = setting.Conf.Blogger.SubTitle Ei.SubTitle = setting.Conf.Blogger.SubTitle
a.BeiAn = setting.Conf.Blogger.BeiAn Ei.BeiAn = setting.Conf.Blogger.BeiAn
a.BTitle = setting.Conf.Blogger.BTitle Ei.BTitle = setting.Conf.Blogger.BTitle
a.Copyright = setting.Conf.Blogger.Copyright Ei.Copyright = setting.Conf.Blogger.Copyright
err = db.Insert(DB, COLLECTION_ACCOUNT, a) err = mgo.Insert(DB, COLLECTION_ACCOUNT, Ei)
generateTopic() generateTopic()
} else if err != nil { } else if err != nil {
logd.Fatal(err) logd.Fatal(err)
} }
a.CH = make(chan string, 2) Ei.CH = make(chan string, 2)
a.MapArticles = make(map[string]*Article) Ei.MapArticles = make(map[string]*Article)
a.Tags = make(map[string]SortArticles) Ei.Tags = make(map[string]SortArticles)
return
} }
func loadArticles() (artcs SortArticles) { func loadArticles() {
err := db.FindAll(DB, COLLECTION_ARTICLE, bson.M{"isdraft": false, "deletetime": bson.M{"$eq": time.Time{}}}, &artcs) err := mgo.FindAll(DB, COLLECTION_ARTICLE, mgo.M{"isdraft": false, "deletetime": mgo.M{"$eq": time.Time{}}}, &Ei.Articles)
if err != nil { if err != nil {
logd.Fatal(err) logd.Fatal(err)
} }
sort.Sort(artcs) sort.Sort(Ei.Articles)
for i, v := range artcs { for i, v := range Ei.Articles {
// 渲染文章 // 渲染文章
GenerateExcerptAndRender(v) GenerateExcerptAndRender(v)
Ei.MapArticles[v.Slug] = v Ei.MapArticles[v.Slug] = v
@@ -153,18 +131,15 @@ func loadArticles() (artcs SortArticles) {
continue continue
} }
if i > 0 { if i > 0 {
v.Prev = artcs[i-1] v.Prev = Ei.Articles[i-1]
} }
if artcs[i+1].ID >= setting.Conf.General.StartID { if Ei.Articles[i+1].ID >= setting.Conf.General.StartID {
v.Next = artcs[i+1] v.Next = Ei.Articles[i+1]
} }
ManageTagsArticle(v, false, ADD) upArticle(v, false)
ManageSeriesArticle(v, false, ADD)
ManageArchivesArticle(v, false, ADD)
} }
Ei.CH <- SERIES_MD Ei.CH <- SERIES_MD
Ei.CH <- ARCHIVE_MD Ei.CH <- ARCHIVE_MD
return
} }
// generate series,archive markdown // generate series,archive markdown
@@ -183,7 +158,8 @@ func generateMarkdown() {
buffer.WriteString("\n\n") buffer.WriteString("\n\n")
for _, artc := range serie.Articles { for _, artc := range serie.Articles {
//eg. * [标题一](/post/hello-world.html) <span class="date">(Man 02, 2006)</span> //eg. * [标题一](/post/hello-world.html) <span class="date">(Man 02, 2006)</span>
buffer.WriteString("* [" + artc.Title + "](/post/" + artc.Slug + ".html) <span class=\"date\">(" + artc.CreateTime.Format("Jan 02, 2006") + ")</span>\n") buffer.WriteString("* [" + artc.Title + "](/post/" + artc.Slug +
".html) <span class=\"date\">(" + artc.CreateTime.Format("Jan 02, 2006") + ")</span>\n")
} }
buffer.WriteByte('\n') buffer.WriteByte('\n')
} }
@@ -191,15 +167,31 @@ func generateMarkdown() {
case ARCHIVE_MD: case ARCHIVE_MD:
sort.Sort(Ei.Archives) sort.Sort(Ei.Archives)
var buffer bytes.Buffer var buffer bytes.Buffer
buffer.WriteString(Ei.ArchivesSay) buffer.WriteString(Ei.ArchivesSay + "\n")
buffer.WriteString("\n\n")
var (
currentYear string
gt12Month = len(Ei.Archives) > 12
)
for _, archive := range Ei.Archives { for _, archive := range Ei.Archives {
buffer.WriteString(fmt.Sprintf("### %s", archive.Time.Format("2006年01月"))) if gt12Month {
buffer.WriteString("\n\n") year := archive.Time.Format("2006 年")
for _, artc := range archive.Articles { if currentYear != year {
buffer.WriteString("* [" + artc.Title + "](/post/" + artc.Slug + ".html) <span class=\"date\">(" + artc.CreateTime.Format("Jan 02, 2006") + ")</span>\n") currentYear = year
buffer.WriteString(fmt.Sprintf("\n### %s\n\n", archive.Time.Format("2006 年")))
}
} else {
buffer.WriteString(fmt.Sprintf("\n### %s\n\n", archive.Time.Format("2006年1月")))
}
for i, artc := range archive.Articles {
if i == 0 && gt12Month {
buffer.WriteString("* *[" + artc.Title + "](/post/" + artc.Slug +
".html) <span class=\"date\">(" + artc.CreateTime.Format("Jan 02, 2006") + ")</span>*\n")
} else {
buffer.WriteString("* [" + artc.Title + "](/post/" + artc.Slug +
".html) <span class=\"date\">(" + artc.CreateTime.Format("Jan 02, 2006") + ")</span>\n")
}
} }
buffer.WriteByte('\n')
} }
Ei.PageArchives = string(renderPage(buffer.Bytes())) Ei.PageArchives = string(renderPage(buffer.Bytes()))
} }
@@ -209,26 +201,29 @@ func generateMarkdown() {
// init account: generate blogroll and about page // init account: generate blogroll and about page
func generateTopic() { func generateTopic() {
about := &Article{ about := &Article{
ID: db.NextVal(DB, COUNTER_ARTICLE), ID: mgo.NextVal(DB, COUNTER_ARTICLE),
Author: setting.Conf.Account.Username, Author: setting.Conf.Account.Username,
Title: "关于", Title: "关于",
Slug: "about", Slug: "about",
CreateTime: time.Time{}, CreateTime: time.Time{},
UpdateTime: time.Time{}, UpdateTime: time.Time{},
} }
// 推送到 disqus
go func() { ThreadCreate(about) }()
blogroll := &Article{ blogroll := &Article{
ID: db.NextVal(DB, COUNTER_ARTICLE), ID: mgo.NextVal(DB, COUNTER_ARTICLE),
Author: setting.Conf.Account.Username, Author: setting.Conf.Account.Username,
Title: "友情链接", Title: "友情链接",
Slug: "blogroll", Slug: "blogroll",
CreateTime: time.Time{}, CreateTime: time.Time{},
UpdateTime: time.Time{}, UpdateTime: time.Time{},
} }
err := db.Insert(DB, COLLECTION_ARTICLE, blogroll) err := mgo.Insert(DB, COLLECTION_ARTICLE, blogroll)
if err != nil { if err != nil {
logd.Fatal(err) logd.Fatal(err)
} }
err = db.Insert(DB, COLLECTION_ARTICLE, about) err = mgo.Insert(DB, COLLECTION_ARTICLE, about)
if err != nil { if err != nil {
logd.Fatal(err) logd.Fatal(err)
} }
@@ -273,97 +268,6 @@ func PageList(p, n int) (prev int, next int, artcs []*Article) {
return return
} }
// 管理 tag
func ManageTagsArticle(artc *Article, s bool, do string) {
switch do {
case ADD:
for _, tag := range artc.Tags {
Ei.Tags[tag] = append(Ei.Tags[tag], artc)
if s {
sort.Sort(Ei.Tags[tag])
}
}
case DELETE:
for _, tag := range artc.Tags {
for i, v := range Ei.Tags[tag] {
if v == artc {
Ei.Tags[tag] = append(Ei.Tags[tag][0:i], Ei.Tags[tag][i+1:]...)
if len(Ei.Tags[tag]) == 0 {
delete(Ei.Tags, tag)
}
return
}
}
}
}
}
// 管理专题
func ManageSeriesArticle(artc *Article, s bool, do string) {
switch do {
case ADD:
for i, serie := range Ei.Series {
if serie.ID == artc.SerieID {
Ei.Series[i].Articles = append(Ei.Series[i].Articles, artc)
if s {
sort.Sort(Ei.Series[i].Articles)
Ei.CH <- SERIES_MD
return
}
}
}
case DELETE:
for i, serie := range Ei.Series {
if serie.ID == artc.SerieID {
for j, v := range serie.Articles {
if v == artc {
Ei.Series[i].Articles = append(Ei.Series[i].Articles[0:j], Ei.Series[i].Articles[j+1:]...)
Ei.CH <- SERIES_MD
return
}
}
}
}
}
}
// 管理归档
func ManageArchivesArticle(artc *Article, s bool, do string) {
switch do {
case ADD:
add := false
y, m, _ := artc.CreateTime.Date()
for i, archive := range Ei.Archives {
ay, am, _ := archive.Time.Date()
if y == ay && m == am {
add = true
Ei.Archives[i].Articles = append(Ei.Archives[i].Articles, artc)
if s {
sort.Sort(Ei.Archives[i].Articles)
Ei.CH <- ARCHIVE_MD
break
}
}
}
if !add {
Ei.Archives = append(Ei.Archives, &Archive{Time: artc.CreateTime, Articles: SortArticles{artc}})
}
case DELETE:
for i, archive := range Ei.Archives {
ay, am, _ := archive.Time.Date()
if y, m, _ := artc.CreateTime.Date(); ay == y && am == m {
for j, v := range archive.Articles {
if v == artc {
Ei.Archives[i].Articles = append(Ei.Archives[i].Articles[0:j], Ei.Archives[i].Articles[j+1:]...)
Ei.CH <- ARCHIVE_MD
return
}
}
}
}
}
}
// 渲染markdown操作和截取摘要操作 // 渲染markdown操作和截取摘要操作
var reg = regexp.MustCompile(setting.Conf.General.Identifier) var reg = regexp.MustCompile(setting.Conf.General.Identifier)
@@ -401,23 +305,128 @@ func GenerateExcerptAndRender(artc *Article) {
// 读取草稿箱 // 读取草稿箱
func LoadDraft() (artcs SortArticles, err error) { func LoadDraft() (artcs SortArticles, err error) {
err = db.FindAll(DB, COLLECTION_ARTICLE, bson.M{"isdraft": true}, &artcs) err = mgo.FindAll(DB, COLLECTION_ARTICLE, mgo.M{"isdraft": true}, &artcs)
sort.Sort(artcs) sort.Sort(artcs)
return return
} }
// 读取回收箱 // 读取回收箱
func LoadTrash() (artcs SortArticles, err error) { func LoadTrash() (artcs SortArticles, err error) {
err = db.FindAll(DB, COLLECTION_ARTICLE, bson.M{"deletetime": bson.M{"$ne": time.Time{}}}, &artcs) err = mgo.FindAll(DB, COLLECTION_ARTICLE, mgo.M{"deletetime": mgo.M{"$ne": time.Time{}}}, &artcs)
sort.Sort(artcs) sort.Sort(artcs)
return return
} }
// 添加文章到tag、serie、archive
func upArticle(artc *Article, needSort bool) {
// tag
for _, tag := range artc.Tags {
Ei.Tags[tag] = append(Ei.Tags[tag], artc)
if needSort {
sort.Sort(Ei.Tags[tag])
}
}
// serie
for i, serie := range Ei.Series {
if serie.ID == artc.SerieID {
Ei.Series[i].Articles = append(Ei.Series[i].Articles, artc)
if needSort {
sort.Sort(Ei.Series[i].Articles)
Ei.CH <- SERIES_MD
}
break
}
}
// archive
y, m, _ := artc.CreateTime.Date()
for i, archive := range Ei.Archives {
if ay, am, _ := archive.Time.Date(); y == ay && m == am {
Ei.Archives[i].Articles = append(Ei.Archives[i].Articles, artc)
if needSort {
sort.Sort(Ei.Archives[i].Articles)
Ei.CH <- ARCHIVE_MD
}
return
}
}
Ei.Archives = append(Ei.Archives, &Archive{Time: artc.CreateTime,
Articles: SortArticles{artc}})
if needSort {
Ei.CH <- ARCHIVE_MD
}
}
// 删除文章从tag、serie、archive
func dropArticle(artc *Article) {
// tag
for _, tag := range artc.Tags {
for i, v := range Ei.Tags[tag] {
if v == artc {
Ei.Tags[tag] = append(Ei.Tags[tag][0:i], Ei.Tags[tag][i+1:]...)
if len(Ei.Tags[tag]) == 0 {
delete(Ei.Tags, tag)
}
}
}
}
// serie
for i, serie := range Ei.Series {
if serie.ID == artc.SerieID {
for j, v := range serie.Articles {
if v == artc {
Ei.Series[i].Articles = append(Ei.Series[i].Articles[0:j],
Ei.Series[i].Articles[j+1:]...)
Ei.CH <- SERIES_MD
break
}
}
}
}
// archive
for i, archive := range Ei.Archives {
ay, am, _ := archive.Time.Date()
if y, m, _ := artc.CreateTime.Date(); ay == y && am == m {
for j, v := range archive.Articles {
if v == artc {
Ei.Archives[i].Articles = append(Ei.Archives[i].Articles[0:j],
Ei.Archives[i].Articles[j+1:]...)
if len(Ei.Archives[i].Articles) == 0 {
Ei.Archives = append(Ei.Archives[:i], Ei.Archives[i+1:]...)
}
Ei.CH <- ARCHIVE_MD
break
}
}
}
}
}
// 替换文章
func ReplaceArticle(oldArtc *Article, newArtc *Article) {
if oldArtc != nil {
i, artc := GetArticle(oldArtc.ID)
DelFromLinkedList(artc)
Ei.Articles = append(Ei.Articles[:i], Ei.Articles[i+1:]...)
delete(Ei.MapArticles, artc.Slug)
dropArticle(oldArtc)
}
Ei.MapArticles[newArtc.Slug] = newArtc
Ei.Articles = append(Ei.Articles, newArtc)
sort.Sort(Ei.Articles)
GenerateExcerptAndRender(newArtc)
AddToLinkedList(newArtc.ID)
upArticle(newArtc, true)
}
// 添加文章 // 添加文章
func AddArticle(artc *Article) error { func AddArticle(artc *Article) error {
// 分配ID, 占位至起始id // 分配ID, 占位至起始id
for { for {
if id := db.NextVal(DB, COUNTER_ARTICLE); id < setting.Conf.General.StartID { if id := mgo.NextVal(DB, COUNTER_ARTICLE); id < setting.Conf.General.StartID {
continue continue
} else { } else {
artc.ID = id artc.ID = id
@@ -425,22 +434,22 @@ func AddArticle(artc *Article) error {
} }
} }
if !artc.IsDraft { err := mgo.Insert(DB, COLLECTION_ARTICLE, artc)
if err != nil {
return err
}
// 正式发布文章 // 正式发布文章
if !artc.IsDraft {
defer GenerateExcerptAndRender(artc) defer GenerateExcerptAndRender(artc)
Ei.MapArticles[artc.Slug] = artc Ei.MapArticles[artc.Slug] = artc
Ei.Articles = append([]*Article{artc}, Ei.Articles...) Ei.Articles = append([]*Article{artc}, Ei.Articles...)
sort.Sort(Ei.Articles) sort.Sort(Ei.Articles)
AddToLinkedList(artc.ID) AddToLinkedList(artc.ID)
ManageTagsArticle(artc, true, ADD)
ManageSeriesArticle(artc, true, ADD) upArticle(artc, true)
ManageArchivesArticle(artc, true, ADD)
Ei.CH <- ARCHIVE_MD
if artc.SerieID > 0 {
Ei.CH <- SERIES_MD
} }
} return nil
return db.Insert(DB, COLLECTION_ARTICLE, artc)
} }
// 删除文章,移入回收箱 // 删除文章,移入回收箱
@@ -452,17 +461,13 @@ func DelArticles(ids ...int32) error {
DelFromLinkedList(artc) DelFromLinkedList(artc)
Ei.Articles = append(Ei.Articles[:i], Ei.Articles[i+1:]...) Ei.Articles = append(Ei.Articles[:i], Ei.Articles[i+1:]...)
delete(Ei.MapArticles, artc.Slug) delete(Ei.MapArticles, artc.Slug)
ManageTagsArticle(artc, false, DELETE)
ManageSeriesArticle(artc, false, DELETE) err := UpdateArticle(mgo.M{"id": id}, mgo.M{"$set": mgo.M{"deletetime": time.Now()}})
ManageArchivesArticle(artc, false, DELETE)
err := UpdateArticle(bson.M{"id": id}, bson.M{"$set": bson.M{"deletetime": time.Now()}})
if err != nil { if err != nil {
return err return err
} }
artc = nil dropArticle(artc)
} }
Ei.CH <- ARCHIVE_MD
Ei.CH <- SERIES_MD
return nil return nil
} }
@@ -509,34 +514,36 @@ func timer() {
delT := time.NewTicker(time.Duration(setting.Conf.General.Clean) * time.Hour) delT := time.NewTicker(time.Duration(setting.Conf.General.Clean) * time.Hour)
for { for {
<-delT.C <-delT.C
db.Remove(DB, COLLECTION_ARTICLE, bson.M{"deletetime": bson.M{"$gt": time.Time{}, "$lt": time.Now().Add(time.Duration(setting.Conf.General.Trash) * time.Hour)}}) mgo.Remove(DB, COLLECTION_ARTICLE, mgo.M{"deletetime": mgo.M{"$gt": time.Time{},
"$lt": time.Now().Add(time.Duration(setting.Conf.General.Trash) * time.Hour)}})
} }
} }
// 操作帐号字段 // 操作帐号字段
func UpdateAccountField(M bson.M) error { func UpdateAccountField(M mgo.M) error {
return db.Update(DB, COLLECTION_ACCOUNT, bson.M{"username": Ei.Username}, M) return mgo.Update(DB, COLLECTION_ACCOUNT, mgo.M{"username": Ei.Username}, M)
} }
// 删除草稿箱或回收箱,永久删除 // 删除草稿箱或回收箱,永久删除
func RemoveArticle(id int32) error { func RemoveArticle(id int32) error {
return db.Remove(DB, COLLECTION_ARTICLE, bson.M{"id": id}) return mgo.Remove(DB, COLLECTION_ARTICLE, mgo.M{"id": id})
} }
// 恢复删除文章到草稿箱 // 恢复删除文章到草稿箱
func RecoverArticle(id int32) error { func RecoverArticle(id int32) error {
return db.Update(DB, COLLECTION_ARTICLE, bson.M{"id": id}, bson.M{"$set": bson.M{"deletetime": time.Time{}, "isdraft": true}}) return mgo.Update(DB, COLLECTION_ARTICLE, mgo.M{"id": id},
mgo.M{"$set": mgo.M{"deletetime": time.Time{}, "isdraft": true}})
} }
// 更新文章 // 更新文章
func UpdateArticle(query, update interface{}) error { func UpdateArticle(query, update interface{}) error {
return db.Update(DB, COLLECTION_ARTICLE, query, update) return mgo.Update(DB, COLLECTION_ARTICLE, query, update)
} }
// 编辑文档 // 编辑文档
func QueryArticle(id int32) *Article { func QueryArticle(id int32) *Article {
artc := &Article{} artc := &Article{}
if err := db.FindOne(DB, COLLECTION_ARTICLE, bson.M{"id": id}, artc); err != nil { if err := mgo.FindOne(DB, COLLECTION_ARTICLE, mgo.M{"id": id}, artc); err != nil {
return nil return nil
} }
return artc return artc
@@ -544,17 +551,18 @@ func QueryArticle(id int32) *Article {
// 添加专题 // 添加专题
func AddSerie(name, slug, desc string) error { func AddSerie(name, slug, desc string) error {
serie := &Serie{db.NextVal(DB, COUNTER_SERIE), name, slug, desc, time.Now(), nil} serie := &Serie{mgo.NextVal(DB, COUNTER_SERIE), name, slug, desc, time.Now(), nil}
Ei.Series = append(Ei.Series, serie) Ei.Series = append(Ei.Series, serie)
sort.Sort(Ei.Series) sort.Sort(Ei.Series)
Ei.CH <- SERIES_MD Ei.CH <- SERIES_MD
return UpdateAccountField(bson.M{"$addToSet": bson.M{"blogger.series": serie}}) return UpdateAccountField(mgo.M{"$addToSet": mgo.M{"blogger.series": serie}})
} }
// 更新专题 // 更新专题
func UpdateSerie(serie *Serie) error { func UpdateSerie(serie *Serie) error {
Ei.CH <- SERIES_MD Ei.CH <- SERIES_MD
return db.Update(DB, COLLECTION_ACCOUNT, bson.M{"username": Ei.Username, "blogger.series.id": serie.ID}, bson.M{"$set": bson.M{"blogger.series.$": serie}}) return mgo.Update(DB, COLLECTION_ACCOUNT, mgo.M{"username": Ei.Username,
"blogger.series.id": serie.ID}, mgo.M{"$set": mgo.M{"blogger.series.$": serie}})
} }
// 删除专题 // 删除专题
@@ -564,7 +572,7 @@ func DelSerie(id int32) error {
if len(serie.Articles) > 0 { if len(serie.Articles) > 0 {
return fmt.Errorf("请删除该专题下的所有文章") return fmt.Errorf("请删除该专题下的所有文章")
} }
err := UpdateAccountField(bson.M{"$pull": bson.M{"blogger.series": bson.M{"id": id}}}) err := UpdateAccountField(mgo.M{"$pull": mgo.M{"blogger.series": mgo.M{"id": id}}})
if err != nil { if err != nil {
return err return err
} }
@@ -588,24 +596,24 @@ func QuerySerie(id int32) *Serie {
// 后台分页 // 后台分页
func PageListBack(se int, kw string, draft, del bool, p, n int) (max int, artcs []*Article) { func PageListBack(se int, kw string, draft, del bool, p, n int) (max int, artcs []*Article) {
M := bson.M{} M := mgo.M{}
if draft { if draft {
M["isdraft"] = true M["isdraft"] = true
} else if del { } else if del {
M["deletetime"] = bson.M{"$ne": time.Time{}} M["deletetime"] = mgo.M{"$ne": time.Time{}}
} else { } else {
M["isdraft"] = false M["isdraft"] = false
M["deletetime"] = bson.M{"$eq": time.Time{}} M["deletetime"] = mgo.M{"$eq": time.Time{}}
if se > 0 { if se > 0 {
M["serieid"] = se M["serieid"] = se
} }
if kw != "" { if kw != "" {
M["title"] = bson.M{"$regex": kw, "$options": "$i"} M["title"] = mgo.M{"$regex": kw, "$options": "$i"}
} }
} }
ms, c := db.Connect(DB, COLLECTION_ARTICLE) ms, c := mgo.Connect(DB, COLLECTION_ARTICLE)
defer ms.Close() defer ms.Close()
err := c.Find(M).Select(bson.M{"content": 0}).Sort("-createtime").Limit(n).Skip((p - 1) * n).All(&artcs) err := c.Find(M).Select(mgo.M{"content": 0}).Sort("-createtime").Limit(n).Skip((p - 1) * n).All(&artcs)
if err != nil { if err != nil {
logd.Error(err) logd.Error(err)
} }

147
disqus.go
View File

@@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
@@ -17,6 +18,12 @@ import (
var ErrDisqusConfig = errors.New("disqus config incorrect") var ErrDisqusConfig = errors.New("disqus config incorrect")
func correctDisqusConfig() bool {
return setting.Conf.Disqus.PostsCount != "" &&
setting.Conf.Disqus.PublicKey != "" &&
setting.Conf.Disqus.ShortName != ""
}
// 定时获取所有文章评论数量 // 定时获取所有文章评论数量
type postsCountResp struct { type postsCountResp struct {
Code int Code int
@@ -28,9 +35,7 @@ type postsCountResp struct {
} }
func PostsCount() error { func PostsCount() error {
if setting.Conf.Disqus.PostsCount == "" || if !correctDisqusConfig() {
setting.Conf.Disqus.PublicKey == "" ||
setting.Conf.Disqus.ShortName == "" {
return ErrDisqusConfig return ErrDisqusConfig
} }
@@ -41,20 +46,19 @@ func PostsCount() error {
} }
}) })
baseUrl := setting.Conf.Disqus.PostsCount + vals := url.Values{}
"?api_key=" + setting.Conf.Disqus.PublicKey + vals.Set("api_key", setting.Conf.Disqus.PublicKey)
"&forum=" + setting.Conf.Disqus.ShortName + "&" vals.Set("forum", setting.Conf.Disqus.ShortName)
var count, index int var count, index int
for index < len(Ei.Articles) { for index < len(Ei.Articles) {
var threads []string
for ; index < len(Ei.Articles) && count < 50; index++ { for ; index < len(Ei.Articles) && count < 50; index++ {
artc := Ei.Articles[index] artc := Ei.Articles[index]
threads = append(threads, fmt.Sprintf("thread:ident=post-%s", artc.Slug)) vals.Add("thread:ident", "post-"+artc.Slug)
count++ count++
} }
count = 0 count = 0
url := baseUrl + strings.Join(threads, "&") resp, err := http.Get(setting.Conf.Disqus.PostsCount + "?" + vals.Encode())
resp, err := http.Get(url)
if err != nil { if err != nil {
return err return err
} }
@@ -113,16 +117,18 @@ type postDetail struct {
} }
func PostsList(slug, cursor string) (*postsListResp, error) { func PostsList(slug, cursor string) (*postsListResp, error) {
if setting.Conf.Disqus.PostsList == "" || if !correctDisqusConfig() {
setting.Conf.Disqus.PublicKey == "" ||
setting.Conf.Disqus.ShortName == "" {
return nil, ErrDisqusConfig return nil, ErrDisqusConfig
} }
url := setting.Conf.Disqus.PostsList + "?limit=50&api_key=" + vals := url.Values{}
setting.Conf.Disqus.PublicKey + "&forum=" + setting.Conf.Disqus.ShortName + vals.Set("api_key", setting.Conf.Disqus.PublicKey)
"&cursor=" + cursor + "&thread:ident=post-" + slug vals.Set("forum", setting.Conf.Disqus.ShortName)
resp, err := http.Get(url) vals.Set("thread:ident", "post-"+slug)
vals.Set("cursor", cursor)
vals.Set("limit", "50")
resp, err := http.Get(setting.Conf.Disqus.PostsList + "?" + vals.Encode())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -144,15 +150,15 @@ func PostsList(slug, cursor string) (*postsListResp, error) {
return result, nil return result, nil
} }
type PostCreate struct { type PostComment struct {
Message string `json:"message"` Message string
Parent string `json:"parent"` Parent string
Thread string `json:"thread"` Thread string
AuthorEmail string `json:"author_email"` AuthorEmail string
AuthorName string `json:"autor_name"` AuthorName string
IpAddress string `json:"ip_address"` IpAddress string
Identifier string `json:"identifier"` Identifier string
UserAgent string `json:"user_agent"` UserAgent string
} }
type postCreateResp struct { type postCreateResp struct {
@@ -161,19 +167,21 @@ type postCreateResp struct {
} }
// 评论文章 // 评论文章
func PostComment(pc *PostCreate) (*postCreateResp, error) { func PostCreate(pc *PostComment) (*postCreateResp, error) {
if setting.Conf.Disqus.PostsList == "" || if !correctDisqusConfig() {
setting.Conf.Disqus.PublicKey == "" ||
setting.Conf.Disqus.ShortName == "" {
return nil, ErrDisqusConfig return nil, ErrDisqusConfig
} }
url := setting.Conf.Disqus.PostCreate +
"?api_key=E8Uh5l5fHZ6gD8U3KycjAIAk46f68Zw7C6eW8WSjZvCLXebZ7p0r1yrYDrLilk2F" +
"&message=" + pc.Message + "&parent=" + pc.Parent +
"&thread=" + pc.Thread + "&author_email=" + pc.AuthorEmail +
"&author_name=" + pc.AuthorName
request, err := http.NewRequest("POST", url, nil) vals := url.Values{}
vals.Set("api_key", "E8Uh5l5fHZ6gD8U3KycjAIAk46f68Zw7C6eW8WSjZvCLXebZ7p0r1yrYDrLilk2F")
vals.Set("message", pc.Message)
vals.Set("parent", pc.Parent)
vals.Set("thread", pc.Thread)
vals.Set("author_email", pc.AuthorEmail)
vals.Set("author_name", pc.AuthorName)
// vals.Set("state", "approved")
request, err := http.NewRequest("POST", setting.Conf.Disqus.PostCreate, strings.NewReader(vals.Encode()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -201,24 +209,23 @@ func PostComment(pc *PostCreate) (*postCreateResp, error) {
// 批准评论通过 // 批准评论通过
type approvedResp struct { type approvedResp struct {
Code int `json:"code"` Code int
Response []struct { Response []struct {
Id string `json:"id"` Id string
} `json:"response"` }
} }
func PostApprove(post string) error { func PostApprove(post string) error {
if setting.Conf.Disqus.PostsList == "" || if !correctDisqusConfig() {
setting.Conf.Disqus.PublicKey == "" ||
setting.Conf.Disqus.ShortName == "" {
return ErrDisqusConfig return ErrDisqusConfig
} }
url := setting.Conf.Disqus.PostApprove + vals := url.Values{}
"?api_key=" + setting.Conf.Disqus.PublicKey + vals.Set("api_key", setting.Conf.Disqus.PublicKey)
"&access_token=" + setting.Conf.Disqus.AccessToken + vals.Set("access_token", setting.Conf.Disqus.AccessToken)
"&post=" + post vals.Set("post", post)
request, err := http.NewRequest("POST", url, nil)
request, err := http.NewRequest("POST", setting.Conf.Disqus.PostApprove, strings.NewReader(vals.Encode()))
if err != nil { if err != nil {
return err return err
} }
@@ -246,3 +253,49 @@ func PostApprove(post string) error {
return nil return nil
} }
// 创建thread
type threadCreateResp struct {
Code int
Response struct {
Id string
}
}
func ThreadCreate(artc *Article) error {
if !correctDisqusConfig() {
return ErrDisqusConfig
}
vals := url.Values{}
vals.Set("api_key", setting.Conf.Disqus.PublicKey)
vals.Set("access_token", setting.Conf.Disqus.AccessToken)
vals.Set("forum", setting.Conf.Disqus.ShortName)
vals.Set("title", artc.Title+" | "+Ei.BTitle)
vals.Set("identifier", "post-"+artc.Slug)
urlPath := fmt.Sprintf("https://%s/post/%s.html", setting.Conf.Mode.Domain, artc.Slug)
vals.Set("url", urlPath)
resp, err := http.PostForm(setting.Conf.Disqus.ThreadCreate, vals)
if err != nil {
return err
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return errors.New(string(b))
}
result := &threadCreateResp{}
err = json.Unmarshal(b, result)
if err != nil {
return err
}
artc.Thread = result.Response.Id
return nil
}

View File

@@ -8,18 +8,29 @@ func TestDisqus(t *testing.T) {
PostsCount() PostsCount()
} }
func TestPostComment(t *testing.T) { func TestPostCreate(t *testing.T) {
pc := &PostCreate{ pc := &PostComment{
Message: "hahahaha", Message: "hahahaha",
Thread: "52799014", Thread: "52799014",
AuthorEmail: "deepzz.qi@gmail.com", AuthorEmail: "deepzz.qi@gmail.com",
AuthorName: "deepzz", AuthorName: "deepzz",
} }
id, err := PostComment(pc) id, err := PostCreate(pc)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
t.Log("post success", id) t.Log("post success", id)
} }
func TestThreadCreate(t *testing.T) {
tc := &Article{
Title: "测试test7",
Slug: "test7",
}
err := ThreadCreate(tc)
if err != nil {
t.Fatal(err)
}
}

View File

@@ -6,8 +6,6 @@ services:
volumes: volumes:
- /data/eiblog/mgodb:/data/db - /data/eiblog/mgodb:/data/db
restart: always restart: always
ports:
- 27017:27017
elasticsearch: elasticsearch:
image: elasticsearch:2.4.1 image: elasticsearch:2.4.1
container_name: eisearch container_name: eisearch
@@ -35,3 +33,14 @@ services:
ports: ports:
- "9000:9000" - "9000:9000"
restart: always restart: always
# backup:
# image: registry.cn-hangzhou.aliyuncs.com/deepzz/backup
# container_name: backup
# links:
# - mongodb
# environment:
# - QINIU_BUCKET=xxxx
# - QINIU_DOMAIN=xx.example.com
# - ACCESS_KEY=xxxxxxxxxx
# - SECRECT_KEY=xxxxxxxxxx
# restart: always

View File

@@ -18,3 +18,7 @@ twitter:
![twitter-card2](http://7xokm2.com1.z0.glb.clouddn.com/img/twitter-pub2.png) ![twitter-card2](http://7xokm2.com1.z0.glb.clouddn.com/img/twitter-pub2.png)
可以看到``之前是没有内容的,该内容是我们文章的描述。 可以看到``之前是没有内容的,该内容是我们文章的描述。
### Google OpenSearch
在 Chrome 浏览器上,你可以在输入网站后按 TAB 键进入搜索模式,如:
![opensearch](http://7xokm2.com1.z0.glb.clouddn.com/opensearch.gif)

View File

@@ -90,7 +90,7 @@ $ docker run -d --name eisearch \
| ------------------ | ---------------------------------------- | ---------------------------------------- | | ------------------ | ---------------------------------------- | ---------------------------------------- |
| favicon.ico | st.example.com/static/img/favicon.ico | cdn 中的文件名为 `static/img/favicon.ico`。你也可以复制 favicon.ico 到 static 文件夹下,通过 example.com/favicon.ico 也是能够访问到。docker 用户可能需要重新打包镜像。 | | favicon.ico | st.example.com/static/img/favicon.ico | cdn 中的文件名为 `static/img/favicon.ico`。你也可以复制 favicon.ico 到 static 文件夹下,通过 example.com/favicon.ico 也是能够访问到。docker 用户可能需要重新打包镜像。 |
| bg04.jpg | st.example.com/static/img/bg04.jpg | 首页左侧的大背景图,需要更名请到 views/st_blog.css 修改。 | | bg04.jpg | st.example.com/static/img/bg04.jpg | 首页左侧的大背景图,需要更名请到 views/st_blog.css 修改。 |
| avatar.jpg | st.example.com/static/img/avatar.jpg | 头像 | | avatar.png | st.example.com/static/img/avatar.png | 头像 |
| blank.gif | st.example.com/static/img/blank.gif | 空白图片,[下载](https://st.deepzz.com/static/img/blank.gif) | | blank.gif | st.example.com/static/img/blank.gif | 空白图片,[下载](https://st.deepzz.com/static/img/blank.gif) |
| default_avatar.png | st.example.com/static/img/default_avatar.png | disqus 默认图片,[下载](https://st.deepzz.com/static/img/default_avatar.png) | | default_avatar.png | st.example.com/static/img/default_avatar.png | disqus 默认图片,[下载](https://st.deepzz.com/static/img/default_avatar.png) |
| disqus.js | st.example.com/static/js/disqus_xxx.js | disqus 文件,你可以通过 https://short_name.disqus.com/embed.js 下载你的专属文件,并上传到七牛。更新配置文件 app.yml。 | | disqus.js | st.example.com/static/js/disqus_xxx.js | disqus 文件,你可以通过 https://short_name.disqus.com/embed.js 下载你的专属文件,并上传到七牛。更新配置文件 app.yml。 |

View File

@@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"regexp" "regexp"
"strings" "strings"
@@ -23,10 +24,20 @@ const (
ES_DATE = `{"range":{"date":{"gte":"%s","lte": "%s","format": "yyyy-MM-dd||yyyy-MM||yyyy"}}}` // 2016-10||/M ES_DATE = `{"range":{"date":{"gte":"%s","lte": "%s","format": "yyyy-MM-dd||yyyy-MM||yyyy"}}}` // 2016-10||/M
) )
var es *ElasticService var (
ErrUninitializedES = errors.New("uninitialized elasticsearch")
es *ElasticService
)
// 初始化 Elasticsearch 服务器 // 初始化 Elasticsearch 服务器
func init() { func init() {
_, err := net.LookupIP("elasticsearch")
if err != nil {
logd.Info(err)
return
}
es = &ElasticService{url: "http://elasticsearch:9200", c: new(http.Client)} es = &ElasticService{url: "http://elasticsearch:9200", c: new(http.Client)}
initIndex() initIndex()
} }
@@ -41,7 +52,11 @@ func initIndex() {
} }
// 查询 // 查询
func Elasticsearch(qStr string, size, from int) *ESSearchResult { func Elasticsearch(qStr string, size, from int) (*ESSearchResult, error) {
if es == nil {
return nil, ErrUninitializedES
}
// 分析查询字符串 // 分析查询字符串
reg := regexp.MustCompile(`(tag|slug|date):`) reg := regexp.MustCompile(`(tag|slug|date):`)
indexs := reg.FindAllStringIndex(qStr, -1) indexs := reg.FindAllStringIndex(qStr, -1)
@@ -92,14 +107,17 @@ func Elasticsearch(qStr string, size, from int) *ESSearchResult {
} }
docs, err := IndexQueryDSL(INDEX, TYPE, size, from, []byte(dsl)) docs, err := IndexQueryDSL(INDEX, TYPE, size, from, []byte(dsl))
if err != nil { if err != nil {
logd.Error(err) return nil, err
return nil
} }
return docs return docs, nil
} }
// 添加或更新索引 // 添加或更新索引
func ElasticIndex(artc *Article) error { func ElasticIndex(artc *Article) error {
if es == nil {
return ErrUninitializedES
}
img := PickFirstImage(artc.Content) img := PickFirstImage(artc.Content)
mapping := map[string]interface{}{ mapping := map[string]interface{}{
"title": artc.Title, "title": artc.Title,
@@ -115,6 +133,10 @@ func ElasticIndex(artc *Article) error {
// 删除索引 // 删除索引
func ElasticDelIndex(ids []int32) error { func ElasticDelIndex(ids []int32) error {
if es == nil {
return ErrUninitializedES
}
var target []string var target []string
for _, id := range ids { for _, id := range ids {
target = append(target, fmt.Sprint(id)) target = append(target, fmt.Sprint(id))

View File

@@ -199,10 +199,12 @@ func HandleSearchPage(c *gin.Context) {
start = 1 start = 1
} }
h["Word"] = q h["Word"] = q
var result *ESSearchResult
vals := c.Request.URL.Query() vals := c.Request.URL.Query()
result = Elasticsearch(q, setting.Conf.General.PageNum, start-1) result, err := Elasticsearch(q, setting.Conf.General.PageNum, start-1)
if result != nil { if err != nil {
logd.Error(err)
} else {
result.Took /= 1000 result.Took /= 1000
for i, v := range result.Hits.Hits { for i, v := range result.Hits.Hits {
if artc := Ei.MapArticles[result.Hits.Hits[i].Source.Slug]; len(v.Highlight.Content) == 0 && artc != nil { if artc := Ei.MapArticles[result.Hits.Hits[i].Source.Slug]; len(v.Highlight.Content) == 0 && artc != nil {
@@ -379,7 +381,8 @@ func HandleDisqus(c *gin.Context) {
} }
// 发表评论 // 发表评论
// [thread:[5279901489] parent:[] identifier:[post-troubleshooting-https] next:[] author_name:[你好] author_email:[chenqijing2@163.com] message:[fdsfdsf]] // [thread:[5279901489] parent:[] identifier:[post-troubleshooting-https]
// next:[] author_name:[你好] author_email:[chenqijing2@163.com] message:[fdsfdsf]]
type DisqusCreate struct { type DisqusCreate struct {
ErrNo int `json:"errno"` ErrNo int `json:"errno"`
ErrMsg string `json:"errmsg"` ErrMsg string `json:"errmsg"`
@@ -400,7 +403,7 @@ func HandleDisqusCreate(c *gin.Context) {
resp.ErrMsg = "参数错误" resp.ErrMsg = "参数错误"
return return
} }
pc := &PostCreate{ pc := &PostComment{
Message: msg, Message: msg,
Parent: c.PostForm("parent"), Parent: c.PostForm("parent"),
Thread: thread, Thread: thread,
@@ -410,7 +413,7 @@ func HandleDisqusCreate(c *gin.Context) {
IpAddress: c.ClientIP(), IpAddress: c.ClientIP(),
} }
postDetail, err := PostComment(pc) postDetail, err := PostCreate(pc)
if err != nil { if err != nil {
logd.Error(err) logd.Error(err)
resp.ErrNo = FAIL resp.ErrNo = FAIL

18
glide.lock generated
View File

@@ -1,28 +1,28 @@
hash: c733fa4abeda21b59b001578b37a168bd33038d337b61198cc5fd94be8bfdf77 hash: c733fa4abeda21b59b001578b37a168bd33038d337b61198cc5fd94be8bfdf77
updated: 2017-11-05T12:08:01.167405372+08:00 updated: 2018-01-13T18:22:28.620808+08:00
imports: imports:
- name: github.com/boj/redistore - name: github.com/boj/redistore
version: 4562487a4bee9a7c272b72bfaeda4917d0a47ab9 version: 4562487a4bee9a7c272b72bfaeda4917d0a47ab9
- name: github.com/deepzz0/logd - name: github.com/deepzz0/logd
version: 2bbe53d047054777f3a171cdfc6dca7aa9f8af78 version: f91dd8c6316f0e156e93895a96739b67577b6a63
- name: github.com/eiblog/blackfriday - name: github.com/eiblog/blackfriday
version: c0ec111761ae784fe31cc076f2fa0e2d2216d623 version: c0ec111761ae784fe31cc076f2fa0e2d2216d623
- name: github.com/eiblog/utils - name: github.com/eiblog/utils
version: ddfd888542f9a093000f71c3709009c1440a0789 version: e8f16268dae939f920ddc55f1c9e46a97a5e3559
subpackages: subpackages:
- logd - logd
- mgo - mgo
- tmpl - tmpl
- uuid - uuid
- name: github.com/garyburd/redigo - name: github.com/garyburd/redigo
version: 47dc60e71eed504e3ef8e77ee3c6fe720f3be57f version: d1ed5c67e5794de818ea85e6b522fda02623a484
subpackages: subpackages:
- internal - internal
- redis - redis
- name: github.com/gin-gonic/autotls - name: github.com/gin-gonic/autotls
version: 8ca25fbde72bb72a00466215b94b489c71fcb815 version: 8ca25fbde72bb72a00466215b94b489c71fcb815
- name: github.com/gin-gonic/contrib - name: github.com/gin-gonic/contrib
version: 5aa1e38d1d932e45fa5032bd1b8739e1a548e596 version: 88aede40372d4bcb11e45168a8c30d99e44cf617
subpackages: subpackages:
- sessions - sessions
- name: github.com/gin-gonic/gin - name: github.com/gin-gonic/gin
@@ -43,7 +43,7 @@ imports:
- name: github.com/manucorporat/sse - name: github.com/manucorporat/sse
version: ee05b128a739a0fb76c7ebd3ae4810c1de808d6d version: ee05b128a739a0fb76c7ebd3ae4810c1de808d6d
- name: github.com/mattn/go-isatty - name: github.com/mattn/go-isatty
version: a5cdd64afdee435007ee3e9f6ed4684af949d568 version: 6ca4dbf54d38eea1a992b3c722a76a5d1c4cb25c
- name: github.com/qiniu/api.v7 - name: github.com/qiniu/api.v7
version: b7c7d6a2ce0aff8e5e7d14c39c3cde867efa1123 version: b7c7d6a2ce0aff8e5e7d14c39c3cde867efa1123
subpackages: subpackages:
@@ -61,7 +61,7 @@ imports:
- name: github.com/shurcooL/sanitized_anchor_name - name: github.com/shurcooL/sanitized_anchor_name
version: 86672fcb3f950f35f2e675df2240550f2a50762f version: 86672fcb3f950f35f2e675df2240550f2a50762f
- name: golang.org/x/crypto - name: golang.org/x/crypto
version: bd6f299fb381e4c3393d1c4b1f0b94f5e77650c8 version: 13931e22f9e72ea58bb73048bc752b48c6d4d4ac
subpackages: subpackages:
- acme - acme
- acme/autocert - acme/autocert
@@ -70,7 +70,7 @@ imports:
subpackages: subpackages:
- context - context
- name: golang.org/x/sys - name: golang.org/x/sys
version: 8eb05f94d449fdf134ec24630ce69ada5b469c1c version: 810d7000345868fc619eb81f46307107118f4ae1
subpackages: subpackages:
- unix - unix
- name: gopkg.in/go-playground/validator.v8 - name: gopkg.in/go-playground/validator.v8
@@ -83,7 +83,7 @@ imports:
- internal/sasl - internal/sasl
- internal/scram - internal/scram
- name: gopkg.in/yaml.v2 - name: gopkg.in/yaml.v2
version: eb3733d160e74a9c7e442f435eb3bea458e1d19f version: d670f9405373e636a5a2765eea47fac0c9bc91a4
- name: qiniupkg.com/x - name: qiniupkg.com/x
version: 946c4a16076d6d98aeb78619e2bd4012357f7228 version: 946c4a16076d6d98aeb78619e2bd4012357f7228
subpackages: subpackages:

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"bytes" "bytes"
"encoding/xml" "encoding/xml"
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
@@ -64,7 +65,7 @@ func (p *pingRPC) PingFunc(slug string) {
if len(setting.Conf.PingRPCs) == 0 { if len(setting.Conf.PingRPCs) == 0 {
return return
} }
p.Params.Param[1].Value = "https://" + setting.Conf.Mode.Domain + "/post/" + slug + ".html" p.Params.Param[1].Value = fmt.Sprintf("https://%s/post/%s.html", setting.Conf.Mode.Domain, slug)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
buf.WriteString(xml.Header) buf.WriteString(xml.Header)
enc := xml.NewEncoder(buf) enc := xml.NewEncoder(buf)

View File

@@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/url"
"path/filepath" "path/filepath"
"github.com/eiblog/eiblog/setting" "github.com/eiblog/eiblog/setting"
@@ -12,18 +11,6 @@ import (
"github.com/qiniu/api.v7/storage" "github.com/qiniu/api.v7/storage"
) )
type bucket struct {
name string
domain string
accessKey string
secretKey string
}
type PutRet struct {
Hash string `json:"hash"`
Key string `json:"key"`
}
// 进度条 // 进度条
func onProgress(fsize, uploaded int64) { func onProgress(fsize, uploaded int64) {
d := int(float64(uploaded) / float64(fsize) * 100) d := int(float64(uploaded) / float64(fsize) * 100)
@@ -41,10 +28,6 @@ func FileUpload(name string, size int64, data io.Reader) (string, error) {
} }
key := getKey(name) key := getKey(name)
if key == "" {
return "", errors.New("不支持的文件类型")
}
mac := qbox.NewMac(setting.Conf.Qiniu.AccessKey, setting.Conf.Qiniu.SecretKey) mac := qbox.NewMac(setting.Conf.Qiniu.AccessKey, setting.Conf.Qiniu.SecretKey)
// 设置上传的策略 // 设置上传的策略
putPolicy := &storage.PutPolicy{ putPolicy := &storage.PutPolicy{
@@ -63,22 +46,20 @@ func FileUpload(name string, size int64, data io.Reader) (string, error) {
// uploader // uploader
uploader := storage.NewFormUploader(cfg) uploader := storage.NewFormUploader(cfg)
ret := new(storage.PutRet) ret := new(storage.PutRet)
putExtra := &storage.PutExtra{OnProgress: onProgress} putExtra := &storage.PutExtra{}
err := uploader.Put(nil, ret, upToken, key, data, size, putExtra) err := uploader.Put(nil, ret, upToken, key, data, size, putExtra)
if err != nil { if err != nil {
return "", err return "", err
} }
url := "https://" + setting.Conf.Qiniu.Domain + "/" + url.QueryEscape(key) url := "https://" + setting.Conf.Qiniu.Domain + "/" + key
return url, nil return url, nil
} }
// 删除文件 // 删除文件
func FileDelete(name string) error { func FileDelete(name string) error {
key := getKey(name) key := getKey(name)
if key == "" {
return errors.New("不支持的文件类型")
}
mac := qbox.NewMac(setting.Conf.Qiniu.AccessKey, setting.Conf.Qiniu.SecretKey) mac := qbox.NewMac(setting.Conf.Qiniu.AccessKey, setting.Conf.Qiniu.SecretKey)
// 上传配置 // 上传配置
@@ -105,9 +86,12 @@ func getKey(name string) string {
key = "blog/img/" + name key = "blog/img/" + name
case ".mov", ".mp4": case ".mov", ".mp4":
key = "blog/video/" + name key = "blog/video/" + name
case ".go", ".js", ".css", ".cpp", ".php", ".rb", ".java", ".py", ".sql", ".lua", ".html", ".sh", ".xml", ".cs": case ".go", ".js", ".css", ".cpp", ".php", ".rb",
".java", ".py", ".sql", ".lua", ".html",
".sh", ".xml", ".cs":
key = "blog/code/" + name key = "blog/code/" + name
case ".txt", ".md", ".ini", ".yaml", ".yml", ".doc", ".ppt", ".pdf": case ".txt", ".md", ".ini", ".yaml", ".yml",
".doc", ".ppt", ".pdf":
key = "blog/document/" + name key = "blog/document/" + name
case ".zip", ".rar", ".tar", ".gz": case ".zip", ".rar", ".tar", ".gz":
key = "blog/archive/" + name key = "blog/archive/" + name

View File

@@ -2,8 +2,9 @@
package main package main
import ( import (
"crypto/rand"
"fmt" "fmt"
"html/template" "text/template"
"time" "time"
"github.com/eiblog/eiblog/setting" "github.com/eiblog/eiblog/setting"
@@ -27,7 +28,12 @@ func init() {
} }
router = gin.Default() router = gin.Default()
store := sessions.NewCookieStore([]byte("eiblog321")) b := make([]byte, 16)
_, err := rand.Read(b)
if err != nil {
logd.Fatal(err)
}
store := sessions.NewCookieStore(b)
store.Options(sessions.Options{ store.Options(sessions.Options{
MaxAge: 86400 * 7, MaxAge: 86400 * 7,
Path: "/", Path: "/",
@@ -43,7 +49,7 @@ func init() {
} }
return false return false
}) })
_, err := Tmpl.ParseFiles(files...) _, err = Tmpl.ParseFiles(files...)
if err != nil { if err != nil {
logd.Fatal(err) logd.Fatal(err)
} }

View File

@@ -42,6 +42,7 @@ type Config struct {
PostsList string PostsList string
PostCreate string PostCreate string
PostApprove string PostApprove string
ThreadCreate string
Embed string Embed string
Interval int Interval int
} }

View File

@@ -67,10 +67,20 @@ type LogOption struct {
Mails Emailer // 告警邮件 Mails Emailer // 告警邮件
} }
func osSep() string {
var sep string
if os.IsPathSeparator('\\') {
sep = "\\"
} else {
sep = "/"
}
return sep
}
// 新建日志打印器 // 新建日志打印器
func New(option LogOption) *Logger { func New(option LogOption) *Logger {
wd, _ := os.Getwd() wd, _ := os.Getwd()
index := strings.LastIndex(wd, "/") index := strings.LastIndex(wd, osSep())
logger := &Logger{ logger := &Logger{
obj: wd[index+1:], obj: wd[index+1:],
out: option.Out, out: option.Out,

View File

@@ -29,6 +29,10 @@ import (
"time" "time"
) )
var (
_ ConnWithTimeout = (*conn)(nil)
)
// conn is the low-level implementation of Conn // conn is the low-level implementation of Conn
type conn struct { type conn struct {
// Shared // Shared
@@ -72,6 +76,7 @@ type DialOption struct {
type dialOptions struct { type dialOptions struct {
readTimeout time.Duration readTimeout time.Duration
writeTimeout time.Duration writeTimeout time.Duration
dialer *net.Dialer
dial func(network, addr string) (net.Conn, error) dial func(network, addr string) (net.Conn, error)
db int db int
password string password string
@@ -94,17 +99,27 @@ func DialWriteTimeout(d time.Duration) DialOption {
}} }}
} }
// DialConnectTimeout specifies the timeout for connecting to the Redis server. // DialConnectTimeout specifies the timeout for connecting to the Redis server when
// no DialNetDial option is specified.
func DialConnectTimeout(d time.Duration) DialOption { func DialConnectTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) { return DialOption{func(do *dialOptions) {
dialer := net.Dialer{Timeout: d} do.dialer.Timeout = d
do.dial = dialer.Dial }}
}
// DialKeepAlive specifies the keep-alive period for TCP connections to the Redis server
// when no DialNetDial option is specified.
// If zero, keep-alives are not enabled. If no DialKeepAlive option is specified then
// the default of 5 minutes is used to ensure that half-closed TCP sessions are detected.
func DialKeepAlive(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.dialer.KeepAlive = d
}} }}
} }
// DialNetDial specifies a custom dial function for creating TCP // DialNetDial specifies a custom dial function for creating TCP
// connections. If this option is left out, then net.Dial is // connections, otherwise a net.Dialer customized via the other options is used.
// used. DialNetDial overrides DialConnectTimeout. // DialNetDial overrides DialConnectTimeout and DialKeepAlive.
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption { func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
return DialOption{func(do *dialOptions) { return DialOption{func(do *dialOptions) {
do.dial = dial do.dial = dial
@@ -154,11 +169,16 @@ func DialUseTLS(useTLS bool) DialOption {
// address using the specified options. // address using the specified options.
func Dial(network, address string, options ...DialOption) (Conn, error) { func Dial(network, address string, options ...DialOption) (Conn, error) {
do := dialOptions{ do := dialOptions{
dial: net.Dial, dialer: &net.Dialer{
KeepAlive: time.Minute * 5,
},
} }
for _, option := range options { for _, option := range options {
option.f(&do) option.f(&do)
} }
if do.dial == nil {
do.dial = do.dialer.Dial
}
netConn, err := do.dial(network, address) netConn, err := do.dial(network, address)
if err != nil { if err != nil {
@@ -166,7 +186,12 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
} }
if do.useTLS { if do.useTLS {
tlsConfig := cloneTLSClientConfig(do.tlsConfig, do.skipVerify) var tlsConfig *tls.Config
if do.tlsConfig == nil {
tlsConfig = &tls.Config{InsecureSkipVerify: do.skipVerify}
} else {
tlsConfig = cloneTLSConfig(do.tlsConfig)
}
if tlsConfig.ServerName == "" { if tlsConfig.ServerName == "" {
host, _, err := net.SplitHostPort(address) host, _, err := net.SplitHostPort(address)
if err != nil { if err != nil {
@@ -555,10 +580,17 @@ func (c *conn) Flush() error {
return nil return nil
} }
func (c *conn) Receive() (reply interface{}, err error) { func (c *conn) Receive() (interface{}, error) {
if c.readTimeout != 0 { return c.ReceiveWithTimeout(c.readTimeout)
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
} }
func (c *conn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
var deadline time.Time
if timeout != 0 {
deadline = time.Now().Add(timeout)
}
c.conn.SetReadDeadline(deadline)
if reply, err = c.readReply(); err != nil { if reply, err = c.readReply(); err != nil {
return nil, c.fatal(err) return nil, c.fatal(err)
} }
@@ -581,6 +613,10 @@ func (c *conn) Receive() (reply interface{}, err error) {
} }
func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
return c.DoWithTimeout(c.readTimeout, cmd, args...)
}
func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
c.mu.Lock() c.mu.Lock()
pending := c.pending pending := c.pending
c.pending = 0 c.pending = 0
@@ -604,9 +640,11 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
return nil, c.fatal(err) return nil, c.fatal(err)
} }
if c.readTimeout != 0 { var deadline time.Time
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) if readTimeout != 0 {
deadline = time.Now().Add(readTimeout)
} }
c.conn.SetReadDeadline(deadline)
if cmd == "" { if cmd == "" {
reply := make([]interface{}, pending) reply := make([]interface{}, pending)

View File

@@ -34,14 +34,16 @@ import (
type testConn struct { type testConn struct {
io.Reader io.Reader
io.Writer io.Writer
readDeadline time.Time
writeDeadline time.Time
} }
func (*testConn) Close() error { return nil } func (*testConn) Close() error { return nil }
func (*testConn) LocalAddr() net.Addr { return nil } func (*testConn) LocalAddr() net.Addr { return nil }
func (*testConn) RemoteAddr() net.Addr { return nil } func (*testConn) RemoteAddr() net.Addr { return nil }
func (*testConn) SetDeadline(t time.Time) error { return nil } func (c *testConn) SetDeadline(t time.Time) error { c.readDeadline = t; c.writeDeadline = t; return nil }
func (*testConn) SetReadDeadline(t time.Time) error { return nil } func (c *testConn) SetReadDeadline(t time.Time) error { c.readDeadline = t; return nil }
func (*testConn) SetWriteDeadline(t time.Time) error { return nil } func (c *testConn) SetWriteDeadline(t time.Time) error { c.writeDeadline = t; return nil }
func dialTestConn(r string, w io.Writer) redis.DialOption { func dialTestConn(r string, w io.Writer) redis.DialOption {
return redis.DialNetDial(func(network, addr string) (net.Conn, error) { return redis.DialNetDial(func(network, addr string) (net.Conn, error) {
@@ -821,3 +823,45 @@ Bjqn3yoLHaoZVvbWOi0C2TCN4FjXjaLNZGifQPbIcaA=
clientTLSConfig.RootCAs = x509.NewCertPool() clientTLSConfig.RootCAs = x509.NewCertPool()
clientTLSConfig.RootCAs.AddCert(certificate) clientTLSConfig.RootCAs.AddCert(certificate)
} }
func TestWithTimeout(t *testing.T) {
for _, recv := range []bool{true, false} {
for _, defaultTimout := range []time.Duration{0, time.Minute} {
var buf bytes.Buffer
nc := &testConn{Reader: strings.NewReader("+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n"), Writer: &buf}
c, _ := redis.Dial("", "", redis.DialReadTimeout(defaultTimout), redis.DialNetDial(func(network, addr string) (net.Conn, error) { return nc, nil }))
for i := 0; i < 4; i++ {
var minDeadline, maxDeadline time.Time
// Alternate between default and specified timeout.
if i%2 == 0 {
if defaultTimout != 0 {
minDeadline = time.Now().Add(defaultTimout)
}
if recv {
c.Receive()
} else {
c.Do("PING")
}
if defaultTimout != 0 {
maxDeadline = time.Now().Add(defaultTimout)
}
} else {
timeout := 10 * time.Minute
minDeadline = time.Now().Add(timeout)
if recv {
redis.ReceiveWithTimeout(c, timeout)
} else {
redis.DoWithTimeout(c, timeout, "PING")
}
maxDeadline = time.Now().Add(timeout)
}
// Expect set deadline in expected range.
if nc.readDeadline.Before(minDeadline) || nc.readDeadline.After(maxDeadline) {
t.Errorf("recv %v, %d: do deadline error: %v, %v, %v", recv, i, minDeadline, nc.readDeadline, maxDeadline)
}
}
}
}
}

View File

@@ -4,11 +4,7 @@ package redis
import "crypto/tls" import "crypto/tls"
// similar cloneTLSClientConfig in the stdlib, but also honor skipVerify for the nil case func cloneTLSConfig(cfg *tls.Config) *tls.Config {
func cloneTLSClientConfig(cfg *tls.Config, skipVerify bool) *tls.Config {
if cfg == nil {
return &tls.Config{InsecureSkipVerify: skipVerify}
}
return &tls.Config{ return &tls.Config{
Rand: cfg.Rand, Rand: cfg.Rand,
Time: cfg.Time, Time: cfg.Time,

View File

@@ -1,14 +1,10 @@
// +build go1.7 // +build go1.7,!go1.8
package redis package redis
import "crypto/tls" import "crypto/tls"
// similar cloneTLSClientConfig in the stdlib, but also honor skipVerify for the nil case func cloneTLSConfig(cfg *tls.Config) *tls.Config {
func cloneTLSClientConfig(cfg *tls.Config, skipVerify bool) *tls.Config {
if cfg == nil {
return &tls.Config{InsecureSkipVerify: skipVerify}
}
return &tls.Config{ return &tls.Config{
Rand: cfg.Rand, Rand: cfg.Rand,
Time: cfg.Time, Time: cfg.Time,

9
vendor/github.com/garyburd/redigo/redis/go18.go generated vendored Normal file
View File

@@ -0,0 +1,9 @@
// +build go1.8
package redis
import "crypto/tls"
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
return cfg.Clone()
}

View File

@@ -18,6 +18,11 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"log" "log"
"time"
)
var (
_ ConnWithTimeout = (*loggingConn)(nil)
) )
// NewLoggingConn returns a logging wrapper around a connection. // NewLoggingConn returns a logging wrapper around a connection.
@@ -104,6 +109,12 @@ func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{},
return reply, err return reply, err
} }
func (c *loggingConn) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (interface{}, error) {
reply, err := DoWithTimeout(c.Conn, timeout, commandName, args...)
c.print("DoWithTimeout", commandName, args, reply, err)
return reply, err
}
func (c *loggingConn) Send(commandName string, args ...interface{}) error { func (c *loggingConn) Send(commandName string, args ...interface{}) error {
err := c.Conn.Send(commandName, args...) err := c.Conn.Send(commandName, args...)
c.print("Send", commandName, args, nil, err) c.print("Send", commandName, args, nil, err)
@@ -115,3 +126,9 @@ func (c *loggingConn) Receive() (interface{}, error) {
c.print("Receive", "", nil, reply, err) c.print("Receive", "", nil, reply, err)
return reply, err return reply, err
} }
func (c *loggingConn) ReceiveWithTimeout(timeout time.Duration) (interface{}, error) {
reply, err := ReceiveWithTimeout(c.Conn, timeout)
c.print("ReceiveWithTimeout", "", nil, reply, err)
return reply, err
}

View File

@@ -28,6 +28,11 @@ import (
"github.com/garyburd/redigo/internal" "github.com/garyburd/redigo/internal"
) )
var (
_ ConnWithTimeout = (*pooledConnection)(nil)
_ ConnWithTimeout = (*errorConnection)(nil)
)
var nowFunc = time.Now // for testing var nowFunc = time.Now // for testing
// ErrPoolExhausted is returned from a pool connection method (Do, Send, // ErrPoolExhausted is returned from a pool connection method (Do, Send,
@@ -96,7 +101,7 @@ var (
// return nil, err // return nil, err
// } // }
// return c, nil // return c, nil
// } // },
// } // }
// //
// Use the TestOnBorrow function to check the health of an idle connection // Use the TestOnBorrow function to check the health of an idle connection
@@ -418,6 +423,16 @@ func (pc *pooledConnection) Do(commandName string, args ...interface{}) (reply i
return pc.c.Do(commandName, args...) return pc.c.Do(commandName, args...)
} }
func (pc *pooledConnection) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error) {
cwt, ok := pc.c.(ConnWithTimeout)
if !ok {
return nil, errTimeoutNotSupported
}
ci := internal.LookupCommandInfo(commandName)
pc.state = (pc.state | ci.Set) &^ ci.Clear
return cwt.DoWithTimeout(timeout, commandName, args...)
}
func (pc *pooledConnection) Send(commandName string, args ...interface{}) error { func (pc *pooledConnection) Send(commandName string, args ...interface{}) error {
ci := internal.LookupCommandInfo(commandName) ci := internal.LookupCommandInfo(commandName)
pc.state = (pc.state | ci.Set) &^ ci.Clear pc.state = (pc.state | ci.Set) &^ ci.Clear
@@ -432,11 +447,23 @@ func (pc *pooledConnection) Receive() (reply interface{}, err error) {
return pc.c.Receive() return pc.c.Receive()
} }
func (pc *pooledConnection) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
cwt, ok := pc.c.(ConnWithTimeout)
if !ok {
return nil, errTimeoutNotSupported
}
return cwt.ReceiveWithTimeout(timeout)
}
type errorConnection struct{ err error } type errorConnection struct{ err error }
func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err } func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err }
func (ec errorConnection) DoWithTimeout(time.Duration, string, ...interface{}) (interface{}, error) {
return nil, ec.err
}
func (ec errorConnection) Send(string, ...interface{}) error { return ec.err } func (ec errorConnection) Send(string, ...interface{}) error { return ec.err }
func (ec errorConnection) Err() error { return ec.err } func (ec errorConnection) Err() error { return ec.err }
func (ec errorConnection) Close() error { return ec.err } func (ec errorConnection) Close() error { return nil }
func (ec errorConnection) Flush() error { return ec.err } func (ec errorConnection) Flush() error { return ec.err }
func (ec errorConnection) Receive() (interface{}, error) { return nil, ec.err } func (ec errorConnection) Receive() (interface{}, error) { return nil, ec.err }
func (ec errorConnection) ReceiveWithTimeout(time.Duration) (interface{}, error) { return nil, ec.err }

View File

@@ -14,7 +14,10 @@
package redis package redis
import "errors" import (
"errors"
"time"
)
// Subscription represents a subscribe or unsubscribe notification. // Subscription represents a subscribe or unsubscribe notification.
type Subscription struct { type Subscription struct {
@@ -103,7 +106,17 @@ func (c PubSubConn) Ping(data string) error {
// or error. The return value is intended to be used directly in a type switch // or error. The return value is intended to be used directly in a type switch
// as illustrated in the PubSubConn example. // as illustrated in the PubSubConn example.
func (c PubSubConn) Receive() interface{} { func (c PubSubConn) Receive() interface{} {
reply, err := Values(c.Conn.Receive()) return c.receiveInternal(c.Conn.Receive())
}
// ReceiveWithTimeout is like Receive, but it allows the application to
// override the connection's default timeout.
func (c PubSubConn) ReceiveWithTimeout(timeout time.Duration) interface{} {
return c.receiveInternal(ReceiveWithTimeout(c.Conn, timeout))
}
func (c PubSubConn) receiveInternal(replyArg interface{}, errArg error) interface{} {
reply, err := Values(replyArg, errArg)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -0,0 +1,165 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// +build go1.7
package redis_test
import (
"context"
"fmt"
"time"
"github.com/garyburd/redigo/redis"
)
// listenPubSubChannels listens for messages on Redis pubsub channels. The
// onStart function is called after the channels are subscribed. The onMessage
// function is called for each message.
func listenPubSubChannels(ctx context.Context, redisServerAddr string,
onStart func() error,
onMessage func(channel string, data []byte) error,
channels ...string) error {
// A ping is set to the server with this period to test for the health of
// the connection and server.
const healthCheckPeriod = time.Minute
c, err := redis.Dial("tcp", redisServerAddr,
// Read timeout on server should be greater than ping period.
redis.DialReadTimeout(healthCheckPeriod+10*time.Second),
redis.DialWriteTimeout(10*time.Second))
if err != nil {
return err
}
defer c.Close()
psc := redis.PubSubConn{Conn: c}
if err := psc.Subscribe(redis.Args{}.AddFlat(channels)...); err != nil {
return err
}
done := make(chan error, 1)
// Start a goroutine to receive notifications from the server.
go func() {
for {
switch n := psc.Receive().(type) {
case error:
done <- n
return
case redis.Message:
if err := onMessage(n.Channel, n.Data); err != nil {
done <- err
return
}
case redis.Subscription:
switch n.Count {
case len(channels):
// Notify application when all channels are subscribed.
if err := onStart(); err != nil {
done <- err
return
}
case 0:
// Return from the goroutine when all channels are unsubscribed.
done <- nil
return
}
}
}
}()
ticker := time.NewTicker(healthCheckPeriod)
defer ticker.Stop()
loop:
for err == nil {
select {
case <-ticker.C:
// Send ping to test health of connection and server. If
// corresponding pong is not received, then receive on the
// connection will timeout and the receive goroutine will exit.
if err = psc.Ping(""); err != nil {
break loop
}
case <-ctx.Done():
break loop
case err := <-done:
// Return error from the receive goroutine.
return err
}
}
// Signal the receiving goroutine to exit by unsubscribing from all channels.
psc.Unsubscribe()
// Wait for goroutine to complete.
return <-done
}
func publish() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("PUBLISH", "c1", "hello")
c.Do("PUBLISH", "c2", "world")
c.Do("PUBLISH", "c1", "goodbye")
}
// This example shows how receive pubsub notifications with cancelation and
// health checks.
func ExamplePubSubConn() {
redisServerAddr, err := serverAddr()
if err != nil {
fmt.Println(err)
return
}
ctx, cancel := context.WithCancel(context.Background())
err = listenPubSubChannels(ctx,
redisServerAddr,
func() error {
// The start callback is a good place to backfill missed
// notifications. For the purpose of this example, a goroutine is
// started to send notifications.
go publish()
return nil
},
func(channel string, message []byte) error {
fmt.Printf("channel: %s, message: %s\n", channel, message)
// For the purpose of this example, cancel the listener's context
// after receiving last message sent by publish().
if string(message) == "goodbye" {
cancel()
}
return nil
},
"c1", "c2")
if err != nil {
fmt.Println(err)
return
}
// Output:
// channel: c1, message: hello
// channel: c2, message: world
// channel: c1, message: goodbye
}

View File

@@ -15,93 +15,13 @@
package redis_test package redis_test
import ( import (
"fmt"
"reflect" "reflect"
"sync"
"testing" "testing"
"time"
"github.com/garyburd/redigo/redis" "github.com/garyburd/redigo/redis"
) )
func publish(channel, value interface{}) {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("PUBLISH", channel, value)
}
// Applications can receive pushed messages from one goroutine and manage subscriptions from another goroutine.
func ExamplePubSubConn() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
var wg sync.WaitGroup
wg.Add(2)
psc := redis.PubSubConn{Conn: c}
// This goroutine receives and prints pushed notifications from the server.
// The goroutine exits when the connection is unsubscribed from all
// channels or there is an error.
go func() {
defer wg.Done()
for {
switch n := psc.Receive().(type) {
case redis.Message:
fmt.Printf("Message: %s %s\n", n.Channel, n.Data)
case redis.PMessage:
fmt.Printf("PMessage: %s %s %s\n", n.Pattern, n.Channel, n.Data)
case redis.Subscription:
fmt.Printf("Subscription: %s %s %d\n", n.Kind, n.Channel, n.Count)
if n.Count == 0 {
return
}
case error:
fmt.Printf("error: %v\n", n)
return
}
}
}()
// This goroutine manages subscriptions for the connection.
go func() {
defer wg.Done()
psc.Subscribe("example")
psc.PSubscribe("p*")
// The following function calls publish a message using another
// connection to the Redis server.
publish("example", "hello")
publish("example", "world")
publish("pexample", "foo")
publish("pexample", "bar")
// Unsubscribe from all connections. This will cause the receiving
// goroutine to exit.
psc.Unsubscribe()
psc.PUnsubscribe()
}()
wg.Wait()
// Output:
// Subscription: subscribe example 1
// Subscription: psubscribe p* 2
// Message: example hello
// Message: example world
// PMessage: p* pexample foo
// PMessage: p* pexample bar
// Subscription: unsubscribe example 1
// Subscription: punsubscribe p* 0
}
func expectPushed(t *testing.T, c redis.PubSubConn, message string, expected interface{}) { func expectPushed(t *testing.T, c redis.PubSubConn, message string, expected interface{}) {
actual := c.Receive() actual := c.Receive()
if !reflect.DeepEqual(actual, expected) { if !reflect.DeepEqual(actual, expected) {
@@ -145,4 +65,10 @@ func TestPushed(t *testing.T) {
c.Conn.Send("PING") c.Conn.Send("PING")
c.Conn.Flush() c.Conn.Flush()
expectPushed(t, c, `Send("PING")`, redis.Pong{}) expectPushed(t, c, `Send("PING")`, redis.Pong{})
c.Ping("timeout")
got := c.ReceiveWithTimeout(time.Minute)
if want := (redis.Pong{Data: "timeout"}); want != got {
t.Errorf("recv /w timeout got %v, want %v", got, want)
}
} }

View File

@@ -14,6 +14,11 @@
package redis package redis
import (
"errors"
"time"
)
// Error represents an error returned in a command reply. // Error represents an error returned in a command reply.
type Error string type Error string
@@ -59,3 +64,54 @@ type Scanner interface {
// loss of information. // loss of information.
RedisScan(src interface{}) error RedisScan(src interface{}) error
} }
// ConnWithTimeout is an optional interface that allows the caller to override
// a connection's default read timeout. This interface is useful for executing
// the BLPOP, BRPOP, BRPOPLPUSH, XREAD and other commands that block at the
// server.
//
// A connection's default read timeout is set with the DialReadTimeout dial
// option. Applications should rely on the default timeout for commands that do
// not block at the server.
//
// All of the Conn implementations in this package satisfy the ConnWithTimeout
// interface.
//
// Use the DoWithTimeout and ReceiveWithTimeout helper functions to simplify
// use of this interface.
type ConnWithTimeout interface {
Conn
// Do sends a command to the server and returns the received reply.
// The timeout overrides the read timeout set when dialing the
// connection.
DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error)
// Receive receives a single reply from the Redis server. The timeout
// overrides the read timeout set when dialing the connection.
ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error)
}
var errTimeoutNotSupported = errors.New("redis: connection does not support ConnWithTimeout")
// DoWithTimeout executes a Redis command with the specified read timeout. If
// the connection does not satisfy the ConnWithTimeout interface, then an error
// is returned.
func DoWithTimeout(c Conn, timeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
cwt, ok := c.(ConnWithTimeout)
if !ok {
return nil, errTimeoutNotSupported
}
return cwt.DoWithTimeout(timeout, cmd, args...)
}
// ReceiveWithTimeout receives a reply with the specified read timeout. If the
// connection does not satisfy the ConnWithTimeout interface, then an error is
// returned.
func ReceiveWithTimeout(c Conn, timeout time.Duration) (interface{}, error) {
cwt, ok := c.(ConnWithTimeout)
if !ok {
return nil, errTimeoutNotSupported
}
return cwt.ReceiveWithTimeout(timeout)
}

71
vendor/github.com/garyburd/redigo/redis/redis_test.go generated vendored Normal file
View File

@@ -0,0 +1,71 @@
// Copyright 2017 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis_test
import (
"testing"
"time"
"github.com/garyburd/redigo/redis"
)
type timeoutTestConn int
func (tc timeoutTestConn) Do(string, ...interface{}) (interface{}, error) {
return time.Duration(-1), nil
}
func (tc timeoutTestConn) DoWithTimeout(timeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
return timeout, nil
}
func (tc timeoutTestConn) Receive() (interface{}, error) {
return time.Duration(-1), nil
}
func (tc timeoutTestConn) ReceiveWithTimeout(timeout time.Duration) (interface{}, error) {
return timeout, nil
}
func (tc timeoutTestConn) Send(string, ...interface{}) error { return nil }
func (tc timeoutTestConn) Err() error { return nil }
func (tc timeoutTestConn) Close() error { return nil }
func (tc timeoutTestConn) Flush() error { return nil }
func testTimeout(t *testing.T, c redis.Conn) {
r, err := c.Do("PING")
if r != time.Duration(-1) || err != nil {
t.Errorf("Do() = %v, %v, want %v, %v", r, err, time.Duration(-1), nil)
}
r, err = redis.DoWithTimeout(c, time.Minute, "PING")
if r != time.Minute || err != nil {
t.Errorf("DoWithTimeout() = %v, %v, want %v, %v", r, err, time.Minute, nil)
}
r, err = c.Receive()
if r != time.Duration(-1) || err != nil {
t.Errorf("Receive() = %v, %v, want %v, %v", r, err, time.Duration(-1), nil)
}
r, err = redis.ReceiveWithTimeout(c, time.Minute)
if r != time.Minute || err != nil {
t.Errorf("ReceiveWithTimeout() = %v, %v, want %v, %v", r, err, time.Minute, nil)
}
}
func TestConnTimeout(t *testing.T) {
testTimeout(t, timeoutTestConn(0))
}
func TestPoolConnTimeout(t *testing.T) {
p := &redis.Pool{Dial: func() (redis.Conn, error) { return timeoutTestConn(0), nil }}
testTimeout(t, p.Get())
}

View File

@@ -140,6 +140,11 @@ func dial() (redis.Conn, error) {
return redis.DialDefaultServer() return redis.DialDefaultServer()
} }
// serverAddr wraps DefaultServerAddr() with a more suitable function name for examples.
func serverAddr() (string, error) {
return redis.DefaultServerAddr()
}
func ExampleBool() { func ExampleBool() {
c, err := dial() c, err := dial()
if err != nil { if err != nil {

View File

@@ -38,6 +38,7 @@ var (
ErrNegativeInt = errNegativeInt ErrNegativeInt = errNegativeInt
serverPath = flag.String("redis-server", "redis-server", "Path to redis server binary") serverPath = flag.String("redis-server", "redis-server", "Path to redis server binary")
serverAddress = flag.String("redis-address", "127.0.0.1", "The address of the server")
serverBasePort = flag.Int("redis-port", 16379, "Beginning of port range for test servers") serverBasePort = flag.Int("redis-port", 16379, "Beginning of port range for test servers")
serverLogName = flag.String("redis-log", "", "Write Redis server logs to `filename`") serverLogName = flag.String("redis-log", "", "Write Redis server logs to `filename`")
serverLog = ioutil.Discard serverLog = ioutil.Discard
@@ -126,28 +127,32 @@ func stopDefaultServer() {
} }
} }
// startDefaultServer starts the default server if not already running. // DefaultServerAddr starts the test server if not already started and returns
func startDefaultServer() error { // the address of that server.
func DefaultServerAddr() (string, error) {
defaultServerMu.Lock() defaultServerMu.Lock()
defer defaultServerMu.Unlock() defer defaultServerMu.Unlock()
addr := fmt.Sprintf("%v:%d", *serverAddress, *serverBasePort)
if defaultServer != nil || defaultServerErr != nil { if defaultServer != nil || defaultServerErr != nil {
return defaultServerErr return addr, defaultServerErr
} }
defaultServer, defaultServerErr = NewServer( defaultServer, defaultServerErr = NewServer(
"default", "default",
"--port", strconv.Itoa(*serverBasePort), "--port", strconv.Itoa(*serverBasePort),
"--bind", *serverAddress,
"--save", "", "--save", "",
"--appendonly", "no") "--appendonly", "no")
return defaultServerErr return addr, defaultServerErr
} }
// DialDefaultServer starts the test server if not already started and dials a // DialDefaultServer starts the test server if not already started and dials a
// connection to the server. // connection to the server.
func DialDefaultServer() (Conn, error) { func DialDefaultServer() (Conn, error) {
if err := startDefaultServer(); err != nil { addr, err := DefaultServerAddr()
if err != nil {
return nil, err return nil, err
} }
c, err := Dial("tcp", fmt.Sprintf(":%d", *serverBasePort), DialReadTimeout(1*time.Second), DialWriteTimeout(1*time.Second)) c, err := Dial("tcp", addr, DialReadTimeout(1*time.Second), DialWriteTimeout(1*time.Second))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -34,3 +34,5 @@ Each author is responsible of maintaining his own code, although if you submit a
+ [gin-oauth2](https://github.com/zalando/gin-oauth2) - for working with OAuth2 + [gin-oauth2](https://github.com/zalando/gin-oauth2) - for working with OAuth2
+ [static](https://github.com/hyperboloide/static) An alternative static assets handler for the gin framework. + [static](https://github.com/hyperboloide/static) An alternative static assets handler for the gin framework.
+ [xss-mw](https://github.com/dvwright/xss-mw) - XssMw is a middleware designed to "auto remove XSS" from user submitted input + [xss-mw](https://github.com/dvwright/xss-mw) - XssMw is a middleware designed to "auto remove XSS" from user submitted input
+ [gin-helmet](https://github.com/danielkov/gin-helmet) - Collection of simple security middleware.
+ [gin-jwt-session](https://github.com/ScottHuangZL/gin-jwt-session) - middleware to provide JWT/Session/Flashes, easy to use while also provide options for adjust if necessary. Provide sample too.

View File

@@ -10,6 +10,10 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
type loggerEntryWithFields interface {
WithFields(fields logrus.Fields) *logrus.Entry
}
// Ginrus returns a gin.HandlerFunc (middleware) that logs requests using logrus. // Ginrus returns a gin.HandlerFunc (middleware) that logs requests using logrus.
// //
// Requests with errors are logged using logrus.Error(). // Requests with errors are logged using logrus.Error().
@@ -18,7 +22,7 @@ import (
// It receives: // It receives:
// 1. A time package format string (e.g. time.RFC3339). // 1. A time package format string (e.g. time.RFC3339).
// 2. A boolean stating whether to use UTC time zone or local. // 2. A boolean stating whether to use UTC time zone or local.
func Ginrus(logger *logrus.Logger, timeFormat string, utc bool) gin.HandlerFunc { func Ginrus(logger loggerEntryWithFields, timeFormat string, utc bool) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
start := time.Now() start := time.Now()
// some evil middlewares modify this values // some evil middlewares modify this values

View File

@@ -2,6 +2,10 @@ language: go
go: go:
- tip - tip
os:
- linux
- osx
before_install: before_install:
- go get github.com/mattn/goveralls - go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover - go get golang.org/x/tools/cmd/cover

View File

@@ -3,7 +3,7 @@
package isatty package isatty
// IsCygwinTerminal() return true if the file descriptor is a cygwin or msys2 // IsCygwinTerminal return true if the file descriptor is a cygwin or msys2
// terminal. This is also always false on this environment. // terminal. This is also always false on this environment.
func IsCygwinTerminal(fd uintptr) bool { func IsCygwinTerminal(fd uintptr) bool {
return false return false

View File

@@ -946,7 +946,7 @@ func TestNonce_add(t *testing.T) {
c.addNonce(http.Header{"Replay-Nonce": {}}) c.addNonce(http.Header{"Replay-Nonce": {}})
c.addNonce(http.Header{"Replay-Nonce": {"nonce"}}) c.addNonce(http.Header{"Replay-Nonce": {"nonce"}})
nonces := map[string]struct{}{"nonce": struct{}{}} nonces := map[string]struct{}{"nonce": {}}
if !reflect.DeepEqual(c.nonces, nonces) { if !reflect.DeepEqual(c.nonces, nonces) {
t.Errorf("c.nonces = %q; want %q", c.nonces, nonces) t.Errorf("c.nonces = %q; want %q", c.nonces, nonces)
} }

View File

@@ -24,7 +24,9 @@ import (
"fmt" "fmt"
"io" "io"
mathrand "math/rand" mathrand "math/rand"
"net"
"net/http" "net/http"
"path"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -80,8 +82,9 @@ func defaultHostPolicy(context.Context, string) error {
} }
// Manager is a stateful certificate manager built on top of acme.Client. // Manager is a stateful certificate manager built on top of acme.Client.
// It obtains and refreshes certificates automatically, // It obtains and refreshes certificates automatically using "tls-sni-01",
// as well as providing them to a TLS server via tls.Config. // "tls-sni-02" and "http-01" challenge types, as well as providing them
// to a TLS server via tls.Config.
// //
// You must specify a cache implementation, such as DirCache, // You must specify a cache implementation, such as DirCache,
// to reuse obtained certificates across program restarts. // to reuse obtained certificates across program restarts.
@@ -150,15 +153,26 @@ type Manager struct {
stateMu sync.Mutex stateMu sync.Mutex
state map[string]*certState // keyed by domain name state map[string]*certState // keyed by domain name
// tokenCert is keyed by token domain name, which matches server name
// of ClientHello. Keys always have ".acme.invalid" suffix.
tokenCertMu sync.RWMutex
tokenCert map[string]*tls.Certificate
// renewal tracks the set of domains currently running renewal timers. // renewal tracks the set of domains currently running renewal timers.
// It is keyed by domain name. // It is keyed by domain name.
renewalMu sync.Mutex renewalMu sync.Mutex
renewal map[string]*domainRenewal renewal map[string]*domainRenewal
// tokensMu guards the rest of the fields: tryHTTP01, certTokens and httpTokens.
tokensMu sync.RWMutex
// tryHTTP01 indicates whether the Manager should try "http-01" challenge type
// during the authorization flow.
tryHTTP01 bool
// httpTokens contains response body values for http-01 challenges
// and is keyed by the URL path at which a challenge response is expected
// to be provisioned.
// The entries are stored for the duration of the authorization flow.
httpTokens map[string][]byte
// certTokens contains temporary certificates for tls-sni challenges
// and is keyed by token domain name, which matches server name of ClientHello.
// Keys always have ".acme.invalid" suffix.
// The entries are stored for the duration of the authorization flow.
certTokens map[string]*tls.Certificate
} }
// GetCertificate implements the tls.Config.GetCertificate hook. // GetCertificate implements the tls.Config.GetCertificate hook.
@@ -185,14 +199,16 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
return nil, errors.New("acme/autocert: server name contains invalid character") return nil, errors.New("acme/autocert: server name contains invalid character")
} }
// In the worst-case scenario, the timeout needs to account for caching, host policy,
// domain ownership verification and certificate issuance.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel() defer cancel()
// check whether this is a token cert requested for TLS-SNI challenge // check whether this is a token cert requested for TLS-SNI challenge
if strings.HasSuffix(name, ".acme.invalid") { if strings.HasSuffix(name, ".acme.invalid") {
m.tokenCertMu.RLock() m.tokensMu.RLock()
defer m.tokenCertMu.RUnlock() defer m.tokensMu.RUnlock()
if cert := m.tokenCert[name]; cert != nil { if cert := m.certTokens[name]; cert != nil {
return cert, nil return cert, nil
} }
if cert, err := m.cacheGet(ctx, name); err == nil { if cert, err := m.cacheGet(ctx, name); err == nil {
@@ -224,6 +240,68 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
return cert, nil return cert, nil
} }
// HTTPHandler configures the Manager to provision ACME "http-01" challenge responses.
// It returns an http.Handler that responds to the challenges and must be
// running on port 80. If it receives a request that is not an ACME challenge,
// it delegates the request to the optional fallback handler.
//
// If fallback is nil, the returned handler redirects all GET and HEAD requests
// to the default TLS port 443 with 302 Found status code, preserving the original
// request path and query. It responds with 400 Bad Request to all other HTTP methods.
// The fallback is not protected by the optional HostPolicy.
//
// Because the fallback handler is run with unencrypted port 80 requests,
// the fallback should not serve TLS-only requests.
//
// If HTTPHandler is never called, the Manager will only use TLS SNI
// challenges for domain verification.
func (m *Manager) HTTPHandler(fallback http.Handler) http.Handler {
m.tokensMu.Lock()
defer m.tokensMu.Unlock()
m.tryHTTP01 = true
if fallback == nil {
fallback = http.HandlerFunc(handleHTTPRedirect)
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, "/.well-known/acme-challenge/") {
fallback.ServeHTTP(w, r)
return
}
// A reasonable context timeout for cache and host policy only,
// because we don't wait for a new certificate issuance here.
ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
defer cancel()
if err := m.hostPolicy()(ctx, r.Host); err != nil {
http.Error(w, err.Error(), http.StatusForbidden)
return
}
data, err := m.httpToken(ctx, r.URL.Path)
if err != nil {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
w.Write(data)
})
}
func handleHTTPRedirect(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" && r.Method != "HEAD" {
http.Error(w, "Use HTTPS", http.StatusBadRequest)
return
}
target := "https://" + stripPort(r.Host) + r.URL.RequestURI()
http.Redirect(w, r, target, http.StatusFound)
}
func stripPort(hostport string) string {
host, _, err := net.SplitHostPort(hostport)
if err != nil {
return hostport
}
return net.JoinHostPort(host, "443")
}
// cert returns an existing certificate either from m.state or cache. // cert returns an existing certificate either from m.state or cache.
// If a certificate is found in cache but not in m.state, the latter will be filled // If a certificate is found in cache but not in m.state, the latter will be filled
// with the cached value. // with the cached value.
@@ -442,13 +520,14 @@ func (m *Manager) certState(domain string) (*certState, error) {
// authorizedCert starts the domain ownership verification process and requests a new cert upon success. // authorizedCert starts the domain ownership verification process and requests a new cert upon success.
// The key argument is the certificate private key. // The key argument is the certificate private key.
func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain string) (der [][]byte, leaf *x509.Certificate, err error) { func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain string) (der [][]byte, leaf *x509.Certificate, err error) {
if err := m.verify(ctx, domain); err != nil {
return nil, nil, err
}
client, err := m.acmeClient(ctx) client, err := m.acmeClient(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if err := m.verify(ctx, client, domain); err != nil {
return nil, nil, err
}
csr, err := certRequest(key, domain) csr, err := certRequest(key, domain)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -464,98 +543,171 @@ func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain
return der, leaf, nil return der, leaf, nil
} }
// verify starts a new identifier (domain) authorization flow. // verify runs the identifier (domain) authorization flow
// It prepares a challenge response and then blocks until the authorization // using each applicable ACME challenge type.
// is marked as "completed" by the CA (either succeeded or failed). func (m *Manager) verify(ctx context.Context, client *acme.Client, domain string) error {
// // The list of challenge types we'll try to fulfill
// verify returns nil iff the verification was successful. // in this specific order.
func (m *Manager) verify(ctx context.Context, domain string) error { challengeTypes := []string{"tls-sni-02", "tls-sni-01"}
client, err := m.acmeClient(ctx) m.tokensMu.RLock()
if err != nil { if m.tryHTTP01 {
return err challengeTypes = append(challengeTypes, "http-01")
} }
m.tokensMu.RUnlock()
// start domain authorization and get the challenge var nextTyp int // challengeType index of the next challenge type to try
for {
// Start domain authorization and get the challenge.
authz, err := client.Authorize(ctx, domain) authz, err := client.Authorize(ctx, domain)
if err != nil { if err != nil {
return err return err
} }
// maybe don't need to at all // No point in accepting challenges if the authorization status
if authz.Status == acme.StatusValid { // is in a final state.
switch authz.Status {
case acme.StatusValid:
return nil // already authorized
case acme.StatusInvalid:
return fmt.Errorf("acme/autocert: invalid authorization %q", authz.URI)
}
// Pick the next preferred challenge.
var chal *acme.Challenge
for chal == nil && nextTyp < len(challengeTypes) {
chal = pickChallenge(challengeTypes[nextTyp], authz.Challenges)
nextTyp++
}
if chal == nil {
return fmt.Errorf("acme/autocert: unable to authorize %q; tried %q", domain, challengeTypes)
}
cleanup, err := m.fulfill(ctx, client, chal)
if err != nil {
continue
}
defer cleanup()
if _, err := client.Accept(ctx, chal); err != nil {
continue
}
// A challenge is fulfilled and accepted: wait for the CA to validate.
if _, err := client.WaitAuthorization(ctx, authz.URI); err == nil {
return nil
}
}
}
// fulfill provisions a response to the challenge chal.
// The cleanup is non-nil only if provisioning succeeded.
func (m *Manager) fulfill(ctx context.Context, client *acme.Client, chal *acme.Challenge) (cleanup func(), err error) {
switch chal.Type {
case "tls-sni-01":
cert, name, err := client.TLSSNI01ChallengeCert(chal.Token)
if err != nil {
return nil, err
}
m.putCertToken(ctx, name, &cert)
return func() { go m.deleteCertToken(name) }, nil
case "tls-sni-02":
cert, name, err := client.TLSSNI02ChallengeCert(chal.Token)
if err != nil {
return nil, err
}
m.putCertToken(ctx, name, &cert)
return func() { go m.deleteCertToken(name) }, nil
case "http-01":
resp, err := client.HTTP01ChallengeResponse(chal.Token)
if err != nil {
return nil, err
}
p := client.HTTP01ChallengePath(chal.Token)
m.putHTTPToken(ctx, p, resp)
return func() { go m.deleteHTTPToken(p) }, nil
}
return nil, fmt.Errorf("acme/autocert: unknown challenge type %q", chal.Type)
}
func pickChallenge(typ string, chal []*acme.Challenge) *acme.Challenge {
for _, c := range chal {
if c.Type == typ {
return c
}
}
return nil return nil
} }
// pick a challenge: prefer tls-sni-02 over tls-sni-01 // putCertToken stores the cert under the named key in both m.certTokens map
// TODO: consider authz.Combinations
var chal *acme.Challenge
for _, c := range authz.Challenges {
if c.Type == "tls-sni-02" {
chal = c
break
}
if c.Type == "tls-sni-01" {
chal = c
}
}
if chal == nil {
return errors.New("acme/autocert: no supported challenge type found")
}
// create a token cert for the challenge response
var (
cert tls.Certificate
name string
)
switch chal.Type {
case "tls-sni-01":
cert, name, err = client.TLSSNI01ChallengeCert(chal.Token)
case "tls-sni-02":
cert, name, err = client.TLSSNI02ChallengeCert(chal.Token)
default:
err = fmt.Errorf("acme/autocert: unknown challenge type %q", chal.Type)
}
if err != nil {
return err
}
m.putTokenCert(ctx, name, &cert)
defer func() {
// verification has ended at this point
// don't need token cert anymore
go m.deleteTokenCert(name)
}()
// ready to fulfill the challenge
if _, err := client.Accept(ctx, chal); err != nil {
return err
}
// wait for the CA to validate
_, err = client.WaitAuthorization(ctx, authz.URI)
return err
}
// putTokenCert stores the cert under the named key in both m.tokenCert map
// and m.Cache. // and m.Cache.
func (m *Manager) putTokenCert(ctx context.Context, name string, cert *tls.Certificate) { func (m *Manager) putCertToken(ctx context.Context, name string, cert *tls.Certificate) {
m.tokenCertMu.Lock() m.tokensMu.Lock()
defer m.tokenCertMu.Unlock() defer m.tokensMu.Unlock()
if m.tokenCert == nil { if m.certTokens == nil {
m.tokenCert = make(map[string]*tls.Certificate) m.certTokens = make(map[string]*tls.Certificate)
} }
m.tokenCert[name] = cert m.certTokens[name] = cert
m.cachePut(ctx, name, cert) m.cachePut(ctx, name, cert)
} }
// deleteTokenCert removes the token certificate for the specified domain name // deleteCertToken removes the token certificate for the specified domain name
// from both m.tokenCert map and m.Cache. // from both m.certTokens map and m.Cache.
func (m *Manager) deleteTokenCert(name string) { func (m *Manager) deleteCertToken(name string) {
m.tokenCertMu.Lock() m.tokensMu.Lock()
defer m.tokenCertMu.Unlock() defer m.tokensMu.Unlock()
delete(m.tokenCert, name) delete(m.certTokens, name)
if m.Cache != nil { if m.Cache != nil {
m.Cache.Delete(context.Background(), name) m.Cache.Delete(context.Background(), name)
} }
} }
// httpToken retrieves an existing http-01 token value from an in-memory map
// or the optional cache.
func (m *Manager) httpToken(ctx context.Context, tokenPath string) ([]byte, error) {
m.tokensMu.RLock()
defer m.tokensMu.RUnlock()
if v, ok := m.httpTokens[tokenPath]; ok {
return v, nil
}
if m.Cache == nil {
return nil, fmt.Errorf("acme/autocert: no token at %q", tokenPath)
}
return m.Cache.Get(ctx, httpTokenCacheKey(tokenPath))
}
// putHTTPToken stores an http-01 token value using tokenPath as key
// in both in-memory map and the optional Cache.
//
// It ignores any error returned from Cache.Put.
func (m *Manager) putHTTPToken(ctx context.Context, tokenPath, val string) {
m.tokensMu.Lock()
defer m.tokensMu.Unlock()
if m.httpTokens == nil {
m.httpTokens = make(map[string][]byte)
}
b := []byte(val)
m.httpTokens[tokenPath] = b
if m.Cache != nil {
m.Cache.Put(ctx, httpTokenCacheKey(tokenPath), b)
}
}
// deleteHTTPToken removes an http-01 token value from both in-memory map
// and the optional Cache, ignoring any error returned from the latter.
//
// If m.Cache is non-nil, it blocks until Cache.Delete returns without a timeout.
func (m *Manager) deleteHTTPToken(tokenPath string) {
m.tokensMu.Lock()
defer m.tokensMu.Unlock()
delete(m.httpTokens, tokenPath)
if m.Cache != nil {
m.Cache.Delete(context.Background(), httpTokenCacheKey(tokenPath))
}
}
// httpTokenCacheKey returns a key at which an http-01 token value may be stored
// in the Manager's optional Cache.
func httpTokenCacheKey(tokenPath string) string {
return "http-01-" + path.Base(tokenPath)
}
// renew starts a cert renewal timer loop, one per domain. // renew starts a cert renewal timer loop, one per domain.
// //
// The loop is scheduled in two cases: // The loop is scheduled in two cases:

View File

@@ -23,6 +23,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -48,6 +49,16 @@ var authzTmpl = template.Must(template.New("authz").Parse(`{
"uri": "{{.}}/challenge/2", "uri": "{{.}}/challenge/2",
"type": "tls-sni-02", "type": "tls-sni-02",
"token": "token-02" "token": "token-02"
},
{
"uri": "{{.}}/challenge/dns-01",
"type": "dns-01",
"token": "token-dns-01"
},
{
"uri": "{{.}}/challenge/http-01",
"type": "http-01",
"token": "token-http-01"
} }
] ]
}`)) }`))
@@ -419,6 +430,146 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
} }
func TestVerifyHTTP01(t *testing.T) {
var (
http01 http.Handler
authzCount int // num. of created authorizations
didAcceptHTTP01 bool
)
verifyHTTPToken := func() {
r := httptest.NewRequest("GET", "/.well-known/acme-challenge/token-http-01", nil)
w := httptest.NewRecorder()
http01.ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Errorf("http token: w.Code = %d; want %d", w.Code, http.StatusOK)
}
if v := string(w.Body.Bytes()); !strings.HasPrefix(v, "token-http-01.") {
t.Errorf("http token value = %q; want 'token-http-01.' prefix", v)
}
}
// ACME CA server stub, only the needed bits.
// TODO: Merge this with startACMEServerStub, making it a configurable CA for testing.
var ca *httptest.Server
ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Replay-Nonce", "nonce")
if r.Method == "HEAD" {
// a nonce request
return
}
switch r.URL.Path {
// Discovery.
case "/":
if err := discoTmpl.Execute(w, ca.URL); err != nil {
t.Errorf("discoTmpl: %v", err)
}
// Client key registration.
case "/new-reg":
w.Write([]byte("{}"))
// New domain authorization.
case "/new-authz":
authzCount++
w.Header().Set("Location", fmt.Sprintf("%s/authz/%d", ca.URL, authzCount))
w.WriteHeader(http.StatusCreated)
if err := authzTmpl.Execute(w, ca.URL); err != nil {
t.Errorf("authzTmpl: %v", err)
}
// Accept tls-sni-02.
case "/challenge/2":
w.Write([]byte("{}"))
// Reject tls-sni-01.
case "/challenge/1":
http.Error(w, "won't accept tls-sni-01", http.StatusBadRequest)
// Should not accept dns-01.
case "/challenge/dns-01":
t.Errorf("dns-01 challenge was accepted")
http.Error(w, "won't accept dns-01", http.StatusBadRequest)
// Accept http-01.
case "/challenge/http-01":
didAcceptHTTP01 = true
verifyHTTPToken()
w.Write([]byte("{}"))
// Authorization statuses.
// Make tls-sni-xxx invalid.
case "/authz/1", "/authz/2":
w.Write([]byte(`{"status": "invalid"}`))
case "/authz/3", "/authz/4":
w.Write([]byte(`{"status": "valid"}`))
default:
http.NotFound(w, r)
t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path)
}
}))
defer ca.Close()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
m := &Manager{
Client: &acme.Client{
Key: key,
DirectoryURL: ca.URL,
},
}
http01 = m.HTTPHandler(nil)
if err := m.verify(context.Background(), m.Client, "example.org"); err != nil {
t.Errorf("m.verify: %v", err)
}
// Only tls-sni-01, tls-sni-02 and http-01 must be accepted
// The dns-01 challenge is unsupported.
if authzCount != 3 {
t.Errorf("authzCount = %d; want 3", authzCount)
}
if !didAcceptHTTP01 {
t.Error("did not accept http-01 challenge")
}
}
func TestHTTPHandlerDefaultFallback(t *testing.T) {
tt := []struct {
method, url string
wantCode int
wantLocation string
}{
{"GET", "http://example.org", 302, "https://example.org/"},
{"GET", "http://example.org/foo", 302, "https://example.org/foo"},
{"GET", "http://example.org/foo/bar/", 302, "https://example.org/foo/bar/"},
{"GET", "http://example.org/?a=b", 302, "https://example.org/?a=b"},
{"GET", "http://example.org/foo?a=b", 302, "https://example.org/foo?a=b"},
{"GET", "http://example.org:80/foo?a=b", 302, "https://example.org:443/foo?a=b"},
{"GET", "http://example.org:80/foo%20bar", 302, "https://example.org:443/foo%20bar"},
{"GET", "http://[2602:d1:xxxx::c60a]:1234", 302, "https://[2602:d1:xxxx::c60a]:443/"},
{"GET", "http://[2602:d1:xxxx::c60a]", 302, "https://[2602:d1:xxxx::c60a]/"},
{"GET", "http://[2602:d1:xxxx::c60a]/foo?a=b", 302, "https://[2602:d1:xxxx::c60a]/foo?a=b"},
{"HEAD", "http://example.org", 302, "https://example.org/"},
{"HEAD", "http://example.org/foo", 302, "https://example.org/foo"},
{"HEAD", "http://example.org/foo/bar/", 302, "https://example.org/foo/bar/"},
{"HEAD", "http://example.org/?a=b", 302, "https://example.org/?a=b"},
{"HEAD", "http://example.org/foo?a=b", 302, "https://example.org/foo?a=b"},
{"POST", "http://example.org", 400, ""},
{"PUT", "http://example.org", 400, ""},
{"GET", "http://example.org/.well-known/acme-challenge/x", 404, ""},
}
var m Manager
h := m.HTTPHandler(nil)
for i, test := range tt {
r := httptest.NewRequest(test.method, test.url, nil)
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
if w.Code != test.wantCode {
t.Errorf("%d: w.Code = %d; want %d", i, w.Code, test.wantCode)
t.Errorf("%d: body: %s", i, w.Body.Bytes())
}
if v := w.Header().Get("Location"); v != test.wantLocation {
t.Errorf("%d: Location = %q; want %q", i, v, test.wantLocation)
}
}
}
func TestAccountKeyCache(t *testing.T) { func TestAccountKeyCache(t *testing.T) {
m := Manager{Cache: newMemCache()} m := Manager{Cache: newMemCache()}
ctx := context.Background() ctx := context.Background()

View File

@@ -22,11 +22,12 @@ func ExampleNewListener() {
} }
func ExampleManager() { func ExampleManager() {
m := autocert.Manager{ m := &autocert.Manager{
Cache: autocert.DirCache("secret-dir"), Cache: autocert.DirCache("secret-dir"),
Prompt: autocert.AcceptTOS, Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist("example.org"), HostPolicy: autocert.HostWhitelist("example.org"),
} }
go http.ListenAndServe(":http", m.HTTPHandler(nil))
s := &http.Server{ s := &http.Server{
Addr: ":https", Addr: ":https",
TLSConfig: &tls.Config{GetCertificate: m.GetCertificate}, TLSConfig: &tls.Config{GetCertificate: m.GetCertificate},

228
vendor/golang.org/x/crypto/argon2/argon2.go generated vendored Normal file
View File

@@ -0,0 +1,228 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package argon2 implements the key derivation function Argon2.
// Argon2 was selected as the winner of the Password Hashing Competition and can
// be used to derive cryptographic keys from passwords.
// Argon2 is specfifed at https://github.com/P-H-C/phc-winner-argon2/blob/master/argon2-specs.pdf
package argon2
import (
"encoding/binary"
"sync"
"golang.org/x/crypto/blake2b"
)
// The Argon2 version implemented by this package.
const Version = 0x13
const (
argon2d = iota
argon2i
argon2id
)
// Key derives a key from the password, salt, and cost parameters using Argon2i
// returning a byte slice of length keyLen that can be used as cryptographic key.
// The CPU cost and parallism degree must be greater than zero.
//
// For example, you can get a derived key for e.g. AES-256 (which needs a 32-byte key) by doing:
// `key := argon2.Key([]byte("some password"), salt, 4, 32*1024, 4, 32)`
//
// The recommended parameters for interactive logins as of 2017 are time=4, memory=32*1024.
// The number of threads can be adjusted to the numbers of available CPUs.
// The time parameter specifies the number of passes over the memory and the memory
// parameter specifies the size of the memory in KiB. For example memory=32*1024 sets the
// memory cost to ~32 MB.
// The cost parameters should be increased as memory latency and CPU parallelism increases.
// Remember to get a good random salt.
func Key(password, salt []byte, time, memory uint32, threads uint8, keyLen uint32) []byte {
return deriveKey(argon2i, password, salt, nil, nil, time, memory, threads, keyLen)
}
func deriveKey(mode int, password, salt, secret, data []byte, time, memory uint32, threads uint8, keyLen uint32) []byte {
if time < 1 {
panic("argon2: number of rounds too small")
}
if threads < 1 {
panic("argon2: parallelism degree too low")
}
h0 := initHash(password, salt, secret, data, time, memory, uint32(threads), keyLen, mode)
memory = memory / (syncPoints * uint32(threads)) * (syncPoints * uint32(threads))
if memory < 2*syncPoints*uint32(threads) {
memory = 2 * syncPoints * uint32(threads)
}
B := initBlocks(&h0, memory, uint32(threads))
processBlocks(B, time, memory, uint32(threads), mode)
return extractKey(B, memory, uint32(threads), keyLen)
}
const (
blockLength = 128
syncPoints = 4
)
type block [blockLength]uint64
func initHash(password, salt, key, data []byte, time, memory, threads, keyLen uint32, mode int) [blake2b.Size + 8]byte {
var (
h0 [blake2b.Size + 8]byte
params [24]byte
tmp [4]byte
)
b2, _ := blake2b.New512(nil)
binary.LittleEndian.PutUint32(params[0:4], threads)
binary.LittleEndian.PutUint32(params[4:8], keyLen)
binary.LittleEndian.PutUint32(params[8:12], memory)
binary.LittleEndian.PutUint32(params[12:16], time)
binary.LittleEndian.PutUint32(params[16:20], uint32(Version))
binary.LittleEndian.PutUint32(params[20:24], uint32(mode))
b2.Write(params[:])
binary.LittleEndian.PutUint32(tmp[:], uint32(len(password)))
b2.Write(tmp[:])
b2.Write(password)
binary.LittleEndian.PutUint32(tmp[:], uint32(len(salt)))
b2.Write(tmp[:])
b2.Write(salt)
binary.LittleEndian.PutUint32(tmp[:], uint32(len(key)))
b2.Write(tmp[:])
b2.Write(key)
binary.LittleEndian.PutUint32(tmp[:], uint32(len(data)))
b2.Write(tmp[:])
b2.Write(data)
b2.Sum(h0[:0])
return h0
}
func initBlocks(h0 *[blake2b.Size + 8]byte, memory, threads uint32) []block {
var block0 [1024]byte
B := make([]block, memory)
for lane := uint32(0); lane < threads; lane++ {
j := lane * (memory / threads)
binary.LittleEndian.PutUint32(h0[blake2b.Size+4:], lane)
binary.LittleEndian.PutUint32(h0[blake2b.Size:], 0)
blake2bHash(block0[:], h0[:])
for i := range B[j+0] {
B[j+0][i] = binary.LittleEndian.Uint64(block0[i*8:])
}
binary.LittleEndian.PutUint32(h0[blake2b.Size:], 1)
blake2bHash(block0[:], h0[:])
for i := range B[j+1] {
B[j+1][i] = binary.LittleEndian.Uint64(block0[i*8:])
}
}
return B
}
func processBlocks(B []block, time, memory, threads uint32, mode int) {
lanes := memory / threads
segments := lanes / syncPoints
processSegment := func(n, slice, lane uint32, wg *sync.WaitGroup) {
var addresses, in, zero block
if mode == argon2i || (mode == argon2id && n == 0 && slice < syncPoints/2) {
in[0] = uint64(n)
in[1] = uint64(lane)
in[2] = uint64(slice)
in[3] = uint64(memory)
in[4] = uint64(time)
in[5] = uint64(mode)
}
index := uint32(0)
if n == 0 && slice == 0 {
index = 2 // we have already generated the first two blocks
if mode == argon2i || mode == argon2id {
in[6]++
processBlock(&addresses, &in, &zero)
processBlock(&addresses, &addresses, &zero)
}
}
offset := lane*lanes + slice*segments + index
var random uint64
for index < segments {
prev := offset - 1
if index == 0 && slice == 0 {
prev += lanes // last block in lane
}
if mode == argon2i || (mode == argon2id && n == 0 && slice < syncPoints/2) {
if index%blockLength == 0 {
in[6]++
processBlock(&addresses, &in, &zero)
processBlock(&addresses, &addresses, &zero)
}
random = addresses[index%blockLength]
} else {
random = B[prev][0]
}
newOffset := indexAlpha(random, lanes, segments, threads, n, slice, lane, index)
processBlockXOR(&B[offset], &B[prev], &B[newOffset])
index, offset = index+1, offset+1
}
wg.Done()
}
for n := uint32(0); n < time; n++ {
for slice := uint32(0); slice < syncPoints; slice++ {
var wg sync.WaitGroup
for lane := uint32(0); lane < threads; lane++ {
wg.Add(1)
go processSegment(n, slice, lane, &wg)
}
wg.Wait()
}
}
}
func extractKey(B []block, memory, threads, keyLen uint32) []byte {
lanes := memory / threads
for lane := uint32(0); lane < threads-1; lane++ {
for i, v := range B[(lane*lanes)+lanes-1] {
B[memory-1][i] ^= v
}
}
var block [1024]byte
for i, v := range B[memory-1] {
binary.LittleEndian.PutUint64(block[i*8:], v)
}
key := make([]byte, keyLen)
blake2bHash(key, block[:])
return key
}
func indexAlpha(rand uint64, lanes, segments, threads, n, slice, lane, index uint32) uint32 {
refLane := uint32(rand>>32) % threads
if n == 0 && slice == 0 {
refLane = lane
}
m, s := 3*segments, ((slice+1)%syncPoints)*segments
if lane == refLane {
m += index
}
if n == 0 {
m, s = slice*segments, 0
if slice == 0 || lane == refLane {
m += index
}
}
if index == 0 || lane == refLane {
m--
}
return phi(rand, uint64(m), uint64(s), refLane, lanes)
}
func phi(rand, m, s uint64, lane, lanes uint32) uint32 {
p := rand & 0xFFFFFFFF
p = (p * p) >> 32
p = (p * m) >> 32
return lane*lanes + uint32((s+m-(p+1))%uint64(lanes))
}

233
vendor/golang.org/x/crypto/argon2/argon2_test.go generated vendored Normal file
View File

@@ -0,0 +1,233 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package argon2
import (
"bytes"
"encoding/hex"
"testing"
)
var (
genKatPassword = []byte{
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
}
genKatSalt = []byte{0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02}
genKatSecret = []byte{0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03}
genKatAAD = []byte{0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04}
)
func TestArgon2(t *testing.T) {
defer func(sse4 bool) { useSSE4 = sse4 }(useSSE4)
if useSSE4 {
t.Log("SSE4.1 version")
testArgon2i(t)
testArgon2d(t)
testArgon2id(t)
useSSE4 = false
}
t.Log("generic version")
testArgon2i(t)
testArgon2d(t)
testArgon2id(t)
}
func testArgon2d(t *testing.T) {
want := []byte{
0x51, 0x2b, 0x39, 0x1b, 0x6f, 0x11, 0x62, 0x97,
0x53, 0x71, 0xd3, 0x09, 0x19, 0x73, 0x42, 0x94,
0xf8, 0x68, 0xe3, 0xbe, 0x39, 0x84, 0xf3, 0xc1,
0xa1, 0x3a, 0x4d, 0xb9, 0xfa, 0xbe, 0x4a, 0xcb,
}
hash := deriveKey(argon2d, genKatPassword, genKatSalt, genKatSecret, genKatAAD, 3, 32, 4, 32)
if !bytes.Equal(hash, want) {
t.Errorf("derived key does not match - got: %s , want: %s", hex.EncodeToString(hash), hex.EncodeToString(want))
}
}
func testArgon2i(t *testing.T) {
want := []byte{
0xc8, 0x14, 0xd9, 0xd1, 0xdc, 0x7f, 0x37, 0xaa,
0x13, 0xf0, 0xd7, 0x7f, 0x24, 0x94, 0xbd, 0xa1,
0xc8, 0xde, 0x6b, 0x01, 0x6d, 0xd3, 0x88, 0xd2,
0x99, 0x52, 0xa4, 0xc4, 0x67, 0x2b, 0x6c, 0xe8,
}
hash := deriveKey(argon2i, genKatPassword, genKatSalt, genKatSecret, genKatAAD, 3, 32, 4, 32)
if !bytes.Equal(hash, want) {
t.Errorf("derived key does not match - got: %s , want: %s", hex.EncodeToString(hash), hex.EncodeToString(want))
}
}
func testArgon2id(t *testing.T) {
want := []byte{
0x0d, 0x64, 0x0d, 0xf5, 0x8d, 0x78, 0x76, 0x6c,
0x08, 0xc0, 0x37, 0xa3, 0x4a, 0x8b, 0x53, 0xc9,
0xd0, 0x1e, 0xf0, 0x45, 0x2d, 0x75, 0xb6, 0x5e,
0xb5, 0x25, 0x20, 0xe9, 0x6b, 0x01, 0xe6, 0x59,
}
hash := deriveKey(argon2id, genKatPassword, genKatSalt, genKatSecret, genKatAAD, 3, 32, 4, 32)
if !bytes.Equal(hash, want) {
t.Errorf("derived key does not match - got: %s , want: %s", hex.EncodeToString(hash), hex.EncodeToString(want))
}
}
func TestVectors(t *testing.T) {
password, salt := []byte("password"), []byte("somesalt")
for i, v := range testVectors {
want, err := hex.DecodeString(v.hash)
if err != nil {
t.Fatalf("Test %d: failed to decode hash: %v", i, err)
}
hash := deriveKey(v.mode, password, salt, nil, nil, v.time, v.memory, v.threads, uint32(len(want)))
if !bytes.Equal(hash, want) {
t.Errorf("Test %d - got: %s want: %s", i, hex.EncodeToString(hash), hex.EncodeToString(want))
}
}
}
func benchmarkArgon2(mode int, time, memory uint32, threads uint8, keyLen uint32, b *testing.B) {
password := []byte("password")
salt := []byte("choosing random salts is hard")
b.ReportAllocs()
for i := 0; i < b.N; i++ {
deriveKey(mode, password, salt, nil, nil, time, memory, threads, keyLen)
}
}
func BenchmarkArgon2i(b *testing.B) {
b.Run(" Time: 3 Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2i, 3, 32*1024, 1, 32, b) })
b.Run(" Time: 4 Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2i, 4, 32*1024, 1, 32, b) })
b.Run(" Time: 5 Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2i, 5, 32*1024, 1, 32, b) })
b.Run(" Time: 3 Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2i, 3, 64*1024, 4, 32, b) })
b.Run(" Time: 4 Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2i, 4, 64*1024, 4, 32, b) })
b.Run(" Time: 5 Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2i, 5, 64*1024, 4, 32, b) })
}
func BenchmarkArgon2d(b *testing.B) {
b.Run(" Time: 3, Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2d, 3, 32*1024, 1, 32, b) })
b.Run(" Time: 4, Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2d, 4, 32*1024, 1, 32, b) })
b.Run(" Time: 5, Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2d, 5, 32*1024, 1, 32, b) })
b.Run(" Time: 3, Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2d, 3, 64*1024, 4, 32, b) })
b.Run(" Time: 4, Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2d, 4, 64*1024, 4, 32, b) })
b.Run(" Time: 5, Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2d, 5, 64*1024, 4, 32, b) })
}
func BenchmarkArgon2id(b *testing.B) {
b.Run(" Time: 3, Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2id, 3, 32*1024, 1, 32, b) })
b.Run(" Time: 4, Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2id, 4, 32*1024, 1, 32, b) })
b.Run(" Time: 5, Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2id, 5, 32*1024, 1, 32, b) })
b.Run(" Time: 3, Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2id, 3, 64*1024, 4, 32, b) })
b.Run(" Time: 4, Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2id, 4, 64*1024, 4, 32, b) })
b.Run(" Time: 5, Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2id, 5, 64*1024, 4, 32, b) })
}
// Generated with the CLI of https://github.com/P-H-C/phc-winner-argon2/blob/master/argon2-specs.pdf
var testVectors = []struct {
mode int
time, memory uint32
threads uint8
hash string
}{
{
mode: argon2i, time: 1, memory: 64, threads: 1,
hash: "b9c401d1844a67d50eae3967dc28870b22e508092e861a37",
},
{
mode: argon2d, time: 1, memory: 64, threads: 1,
hash: "8727405fd07c32c78d64f547f24150d3f2e703a89f981a19",
},
{
mode: argon2id, time: 1, memory: 64, threads: 1,
hash: "655ad15eac652dc59f7170a7332bf49b8469be1fdb9c28bb",
},
{
mode: argon2i, time: 2, memory: 64, threads: 1,
hash: "8cf3d8f76a6617afe35fac48eb0b7433a9a670ca4a07ed64",
},
{
mode: argon2d, time: 2, memory: 64, threads: 1,
hash: "3be9ec79a69b75d3752acb59a1fbb8b295a46529c48fbb75",
},
{
mode: argon2id, time: 2, memory: 64, threads: 1,
hash: "068d62b26455936aa6ebe60060b0a65870dbfa3ddf8d41f7",
},
{
mode: argon2i, time: 2, memory: 64, threads: 2,
hash: "2089f3e78a799720f80af806553128f29b132cafe40d059f",
},
{
mode: argon2d, time: 2, memory: 64, threads: 2,
hash: "68e2462c98b8bc6bb60ec68db418ae2c9ed24fc6748a40e9",
},
{
mode: argon2id, time: 2, memory: 64, threads: 2,
hash: "350ac37222f436ccb5c0972f1ebd3bf6b958bf2071841362",
},
{
mode: argon2i, time: 3, memory: 256, threads: 2,
hash: "f5bbf5d4c3836af13193053155b73ec7476a6a2eb93fd5e6",
},
{
mode: argon2d, time: 3, memory: 256, threads: 2,
hash: "f4f0669218eaf3641f39cc97efb915721102f4b128211ef2",
},
{
mode: argon2id, time: 3, memory: 256, threads: 2,
hash: "4668d30ac4187e6878eedeacf0fd83c5a0a30db2cc16ef0b",
},
{
mode: argon2i, time: 4, memory: 4096, threads: 4,
hash: "a11f7b7f3f93f02ad4bddb59ab62d121e278369288a0d0e7",
},
{
mode: argon2d, time: 4, memory: 4096, threads: 4,
hash: "935598181aa8dc2b720914aa6435ac8d3e3a4210c5b0fb2d",
},
{
mode: argon2id, time: 4, memory: 4096, threads: 4,
hash: "145db9733a9f4ee43edf33c509be96b934d505a4efb33c5a",
},
{
mode: argon2i, time: 4, memory: 1024, threads: 8,
hash: "0cdd3956aa35e6b475a7b0c63488822f774f15b43f6e6e17",
},
{
mode: argon2d, time: 4, memory: 1024, threads: 8,
hash: "83604fc2ad0589b9d055578f4d3cc55bc616df3578a896e9",
},
{
mode: argon2id, time: 4, memory: 1024, threads: 8,
hash: "8dafa8e004f8ea96bf7c0f93eecf67a6047476143d15577f",
},
{
mode: argon2i, time: 2, memory: 64, threads: 3,
hash: "5cab452fe6b8479c8661def8cd703b611a3905a6d5477fe6",
},
{
mode: argon2d, time: 2, memory: 64, threads: 3,
hash: "22474a423bda2ccd36ec9afd5119e5c8949798cadf659f51",
},
{
mode: argon2id, time: 2, memory: 64, threads: 3,
hash: "4a15b31aec7c2590b87d1f520be7d96f56658172deaa3079",
},
{
mode: argon2i, time: 3, memory: 1024, threads: 6,
hash: "d236b29c2b2a09babee842b0dec6aa1e83ccbdea8023dced",
},
{
mode: argon2d, time: 3, memory: 1024, threads: 6,
hash: "a3351b0319a53229152023d9206902f4ef59661cdca89481",
},
{
mode: argon2id, time: 3, memory: 1024, threads: 6,
hash: "1640b932f4b60e272f5d2207b9a9c626ffa1bd88d2349016",
},
}

53
vendor/golang.org/x/crypto/argon2/blake2b.go generated vendored Normal file
View File

@@ -0,0 +1,53 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package argon2
import (
"encoding/binary"
"hash"
"golang.org/x/crypto/blake2b"
)
// blake2bHash computes an arbitrary long hash value of in
// and writes the hash to out.
func blake2bHash(out []byte, in []byte) {
var b2 hash.Hash
if n := len(out); n < blake2b.Size {
b2, _ = blake2b.New(n, nil)
} else {
b2, _ = blake2b.New512(nil)
}
var buffer [blake2b.Size]byte
binary.LittleEndian.PutUint32(buffer[:4], uint32(len(out)))
b2.Write(buffer[:4])
b2.Write(in)
if len(out) <= blake2b.Size {
b2.Sum(out[:0])
return
}
outLen := len(out)
b2.Sum(buffer[:0])
b2.Reset()
copy(out, buffer[:32])
out = out[32:]
for len(out) > blake2b.Size {
b2.Write(buffer[:])
b2.Sum(buffer[:0])
copy(out, buffer[:32])
out = out[32:]
b2.Reset()
}
if outLen%blake2b.Size > 0 { // outLen > 64
r := ((outLen + 31) / 32) - 2 // ⌈τ /32⌉-2
b2, _ = blake2b.New(outLen-32*r, nil)
}
b2.Write(buffer[:])
b2.Sum(out[:0])
}

59
vendor/golang.org/x/crypto/argon2/blamka_amd64.go generated vendored Normal file
View File

@@ -0,0 +1,59 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package argon2
func init() {
useSSE4 = supportsSSE4()
}
//go:noescape
func supportsSSE4() bool
//go:noescape
func mixBlocksSSE2(out, a, b, c *block)
//go:noescape
func xorBlocksSSE2(out, a, b, c *block)
//go:noescape
func blamkaSSE4(b *block)
func processBlockSSE(out, in1, in2 *block, xor bool) {
var t block
mixBlocksSSE2(&t, in1, in2, &t)
if useSSE4 {
blamkaSSE4(&t)
} else {
for i := 0; i < blockLength; i += 16 {
blamkaGeneric(
&t[i+0], &t[i+1], &t[i+2], &t[i+3],
&t[i+4], &t[i+5], &t[i+6], &t[i+7],
&t[i+8], &t[i+9], &t[i+10], &t[i+11],
&t[i+12], &t[i+13], &t[i+14], &t[i+15],
)
}
for i := 0; i < blockLength/8; i += 2 {
blamkaGeneric(
&t[i], &t[i+1], &t[16+i], &t[16+i+1],
&t[32+i], &t[32+i+1], &t[48+i], &t[48+i+1],
&t[64+i], &t[64+i+1], &t[80+i], &t[80+i+1],
&t[96+i], &t[96+i+1], &t[112+i], &t[112+i+1],
)
}
}
if xor {
xorBlocksSSE2(out, in1, in2, &t)
} else {
mixBlocksSSE2(out, in1, in2, &t)
}
}
func processBlock(out, in1, in2 *block) {
processBlockSSE(out, in1, in2, false)
}
func processBlockXOR(out, in1, in2 *block) {
processBlockSSE(out, in1, in2, true)
}

252
vendor/golang.org/x/crypto/argon2/blamka_amd64.s generated vendored Normal file
View File

@@ -0,0 +1,252 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build amd64,!gccgo,!appengine
#include "textflag.h"
DATA ·c40<>+0x00(SB)/8, $0x0201000706050403
DATA ·c40<>+0x08(SB)/8, $0x0a09080f0e0d0c0b
GLOBL ·c40<>(SB), (NOPTR+RODATA), $16
DATA ·c48<>+0x00(SB)/8, $0x0100070605040302
DATA ·c48<>+0x08(SB)/8, $0x09080f0e0d0c0b0a
GLOBL ·c48<>(SB), (NOPTR+RODATA), $16
#define SHUFFLE(v2, v3, v4, v5, v6, v7, t1, t2) \
MOVO v4, t1; \
MOVO v5, v4; \
MOVO t1, v5; \
MOVO v6, t1; \
PUNPCKLQDQ v6, t2; \
PUNPCKHQDQ v7, v6; \
PUNPCKHQDQ t2, v6; \
PUNPCKLQDQ v7, t2; \
MOVO t1, v7; \
MOVO v2, t1; \
PUNPCKHQDQ t2, v7; \
PUNPCKLQDQ v3, t2; \
PUNPCKHQDQ t2, v2; \
PUNPCKLQDQ t1, t2; \
PUNPCKHQDQ t2, v3
#define SHUFFLE_INV(v2, v3, v4, v5, v6, v7, t1, t2) \
MOVO v4, t1; \
MOVO v5, v4; \
MOVO t1, v5; \
MOVO v2, t1; \
PUNPCKLQDQ v2, t2; \
PUNPCKHQDQ v3, v2; \
PUNPCKHQDQ t2, v2; \
PUNPCKLQDQ v3, t2; \
MOVO t1, v3; \
MOVO v6, t1; \
PUNPCKHQDQ t2, v3; \
PUNPCKLQDQ v7, t2; \
PUNPCKHQDQ t2, v6; \
PUNPCKLQDQ t1, t2; \
PUNPCKHQDQ t2, v7
#define HALF_ROUND(v0, v1, v2, v3, v4, v5, v6, v7, t0, c40, c48) \
MOVO v0, t0; \
PMULULQ v2, t0; \
PADDQ v2, v0; \
PADDQ t0, v0; \
PADDQ t0, v0; \
PXOR v0, v6; \
PSHUFD $0xB1, v6, v6; \
MOVO v4, t0; \
PMULULQ v6, t0; \
PADDQ v6, v4; \
PADDQ t0, v4; \
PADDQ t0, v4; \
PXOR v4, v2; \
PSHUFB c40, v2; \
MOVO v0, t0; \
PMULULQ v2, t0; \
PADDQ v2, v0; \
PADDQ t0, v0; \
PADDQ t0, v0; \
PXOR v0, v6; \
PSHUFB c48, v6; \
MOVO v4, t0; \
PMULULQ v6, t0; \
PADDQ v6, v4; \
PADDQ t0, v4; \
PADDQ t0, v4; \
PXOR v4, v2; \
MOVO v2, t0; \
PADDQ v2, t0; \
PSRLQ $63, v2; \
PXOR t0, v2; \
MOVO v1, t0; \
PMULULQ v3, t0; \
PADDQ v3, v1; \
PADDQ t0, v1; \
PADDQ t0, v1; \
PXOR v1, v7; \
PSHUFD $0xB1, v7, v7; \
MOVO v5, t0; \
PMULULQ v7, t0; \
PADDQ v7, v5; \
PADDQ t0, v5; \
PADDQ t0, v5; \
PXOR v5, v3; \
PSHUFB c40, v3; \
MOVO v1, t0; \
PMULULQ v3, t0; \
PADDQ v3, v1; \
PADDQ t0, v1; \
PADDQ t0, v1; \
PXOR v1, v7; \
PSHUFB c48, v7; \
MOVO v5, t0; \
PMULULQ v7, t0; \
PADDQ v7, v5; \
PADDQ t0, v5; \
PADDQ t0, v5; \
PXOR v5, v3; \
MOVO v3, t0; \
PADDQ v3, t0; \
PSRLQ $63, v3; \
PXOR t0, v3
#define LOAD_MSG_0(block, off) \
MOVOU 8*(off+0)(block), X0; \
MOVOU 8*(off+2)(block), X1; \
MOVOU 8*(off+4)(block), X2; \
MOVOU 8*(off+6)(block), X3; \
MOVOU 8*(off+8)(block), X4; \
MOVOU 8*(off+10)(block), X5; \
MOVOU 8*(off+12)(block), X6; \
MOVOU 8*(off+14)(block), X7
#define STORE_MSG_0(block, off) \
MOVOU X0, 8*(off+0)(block); \
MOVOU X1, 8*(off+2)(block); \
MOVOU X2, 8*(off+4)(block); \
MOVOU X3, 8*(off+6)(block); \
MOVOU X4, 8*(off+8)(block); \
MOVOU X5, 8*(off+10)(block); \
MOVOU X6, 8*(off+12)(block); \
MOVOU X7, 8*(off+14)(block)
#define LOAD_MSG_1(block, off) \
MOVOU 8*off+0*8(block), X0; \
MOVOU 8*off+16*8(block), X1; \
MOVOU 8*off+32*8(block), X2; \
MOVOU 8*off+48*8(block), X3; \
MOVOU 8*off+64*8(block), X4; \
MOVOU 8*off+80*8(block), X5; \
MOVOU 8*off+96*8(block), X6; \
MOVOU 8*off+112*8(block), X7
#define STORE_MSG_1(block, off) \
MOVOU X0, 8*off+0*8(block); \
MOVOU X1, 8*off+16*8(block); \
MOVOU X2, 8*off+32*8(block); \
MOVOU X3, 8*off+48*8(block); \
MOVOU X4, 8*off+64*8(block); \
MOVOU X5, 8*off+80*8(block); \
MOVOU X6, 8*off+96*8(block); \
MOVOU X7, 8*off+112*8(block)
#define BLAMKA_ROUND_0(block, off, t0, t1, c40, c48) \
LOAD_MSG_0(block, off); \
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, t0, c40, c48); \
SHUFFLE(X2, X3, X4, X5, X6, X7, t0, t1); \
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, t0, c40, c48); \
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, t0, t1); \
STORE_MSG_0(block, off)
#define BLAMKA_ROUND_1(block, off, t0, t1, c40, c48) \
LOAD_MSG_1(block, off); \
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, t0, c40, c48); \
SHUFFLE(X2, X3, X4, X5, X6, X7, t0, t1); \
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, t0, c40, c48); \
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, t0, t1); \
STORE_MSG_1(block, off)
// func blamkaSSE4(b *block)
TEXT ·blamkaSSE4(SB), 4, $0-8
MOVQ b+0(FP), AX
MOVOU ·c40<>(SB), X10
MOVOU ·c48<>(SB), X11
BLAMKA_ROUND_0(AX, 0, X8, X9, X10, X11)
BLAMKA_ROUND_0(AX, 16, X8, X9, X10, X11)
BLAMKA_ROUND_0(AX, 32, X8, X9, X10, X11)
BLAMKA_ROUND_0(AX, 48, X8, X9, X10, X11)
BLAMKA_ROUND_0(AX, 64, X8, X9, X10, X11)
BLAMKA_ROUND_0(AX, 80, X8, X9, X10, X11)
BLAMKA_ROUND_0(AX, 96, X8, X9, X10, X11)
BLAMKA_ROUND_0(AX, 112, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 0, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 2, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 4, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 6, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 8, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 10, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 12, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 14, X8, X9, X10, X11)
RET
// func mixBlocksSSE2(out, a, b, c *block)
TEXT ·mixBlocksSSE2(SB), 4, $0-32
MOVQ out+0(FP), DX
MOVQ a+8(FP), AX
MOVQ b+16(FP), BX
MOVQ a+24(FP), CX
MOVQ $128, BP
loop:
MOVOU 0(AX), X0
MOVOU 0(BX), X1
MOVOU 0(CX), X2
PXOR X1, X0
PXOR X2, X0
MOVOU X0, 0(DX)
ADDQ $16, AX
ADDQ $16, BX
ADDQ $16, CX
ADDQ $16, DX
SUBQ $2, BP
JA loop
RET
// func xorBlocksSSE2(out, a, b, c *block)
TEXT ·xorBlocksSSE2(SB), 4, $0-32
MOVQ out+0(FP), DX
MOVQ a+8(FP), AX
MOVQ b+16(FP), BX
MOVQ a+24(FP), CX
MOVQ $128, BP
loop:
MOVOU 0(AX), X0
MOVOU 0(BX), X1
MOVOU 0(CX), X2
MOVOU 0(DX), X3
PXOR X1, X0
PXOR X2, X0
PXOR X3, X0
MOVOU X0, 0(DX)
ADDQ $16, AX
ADDQ $16, BX
ADDQ $16, CX
ADDQ $16, DX
SUBQ $2, BP
JA loop
RET
// func supportsSSE4() bool
TEXT ·supportsSSE4(SB), 4, $0-1
MOVL $1, AX
CPUID
SHRL $19, CX // Bit 19 indicates SSE4 support
ANDL $1, CX // CX != 0 if support SSE4
MOVB CX, ret+0(FP)
RET

163
vendor/golang.org/x/crypto/argon2/blamka_generic.go generated vendored Normal file
View File

@@ -0,0 +1,163 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package argon2
var useSSE4 bool
func processBlockGeneric(out, in1, in2 *block, xor bool) {
var t block
for i := range t {
t[i] = in1[i] ^ in2[i]
}
for i := 0; i < blockLength; i += 16 {
blamkaGeneric(
&t[i+0], &t[i+1], &t[i+2], &t[i+3],
&t[i+4], &t[i+5], &t[i+6], &t[i+7],
&t[i+8], &t[i+9], &t[i+10], &t[i+11],
&t[i+12], &t[i+13], &t[i+14], &t[i+15],
)
}
for i := 0; i < blockLength/8; i += 2 {
blamkaGeneric(
&t[i], &t[i+1], &t[16+i], &t[16+i+1],
&t[32+i], &t[32+i+1], &t[48+i], &t[48+i+1],
&t[64+i], &t[64+i+1], &t[80+i], &t[80+i+1],
&t[96+i], &t[96+i+1], &t[112+i], &t[112+i+1],
)
}
if xor {
for i := range t {
out[i] ^= in1[i] ^ in2[i] ^ t[i]
}
} else {
for i := range t {
out[i] = in1[i] ^ in2[i] ^ t[i]
}
}
}
func blamkaGeneric(t00, t01, t02, t03, t04, t05, t06, t07, t08, t09, t10, t11, t12, t13, t14, t15 *uint64) {
v00, v01, v02, v03 := *t00, *t01, *t02, *t03
v04, v05, v06, v07 := *t04, *t05, *t06, *t07
v08, v09, v10, v11 := *t08, *t09, *t10, *t11
v12, v13, v14, v15 := *t12, *t13, *t14, *t15
v00 += v04 + 2*uint64(uint32(v00))*uint64(uint32(v04))
v12 ^= v00
v12 = v12>>32 | v12<<32
v08 += v12 + 2*uint64(uint32(v08))*uint64(uint32(v12))
v04 ^= v08
v04 = v04>>24 | v04<<40
v00 += v04 + 2*uint64(uint32(v00))*uint64(uint32(v04))
v12 ^= v00
v12 = v12>>16 | v12<<48
v08 += v12 + 2*uint64(uint32(v08))*uint64(uint32(v12))
v04 ^= v08
v04 = v04>>63 | v04<<1
v01 += v05 + 2*uint64(uint32(v01))*uint64(uint32(v05))
v13 ^= v01
v13 = v13>>32 | v13<<32
v09 += v13 + 2*uint64(uint32(v09))*uint64(uint32(v13))
v05 ^= v09
v05 = v05>>24 | v05<<40
v01 += v05 + 2*uint64(uint32(v01))*uint64(uint32(v05))
v13 ^= v01
v13 = v13>>16 | v13<<48
v09 += v13 + 2*uint64(uint32(v09))*uint64(uint32(v13))
v05 ^= v09
v05 = v05>>63 | v05<<1
v02 += v06 + 2*uint64(uint32(v02))*uint64(uint32(v06))
v14 ^= v02
v14 = v14>>32 | v14<<32
v10 += v14 + 2*uint64(uint32(v10))*uint64(uint32(v14))
v06 ^= v10
v06 = v06>>24 | v06<<40
v02 += v06 + 2*uint64(uint32(v02))*uint64(uint32(v06))
v14 ^= v02
v14 = v14>>16 | v14<<48
v10 += v14 + 2*uint64(uint32(v10))*uint64(uint32(v14))
v06 ^= v10
v06 = v06>>63 | v06<<1
v03 += v07 + 2*uint64(uint32(v03))*uint64(uint32(v07))
v15 ^= v03
v15 = v15>>32 | v15<<32
v11 += v15 + 2*uint64(uint32(v11))*uint64(uint32(v15))
v07 ^= v11
v07 = v07>>24 | v07<<40
v03 += v07 + 2*uint64(uint32(v03))*uint64(uint32(v07))
v15 ^= v03
v15 = v15>>16 | v15<<48
v11 += v15 + 2*uint64(uint32(v11))*uint64(uint32(v15))
v07 ^= v11
v07 = v07>>63 | v07<<1
v00 += v05 + 2*uint64(uint32(v00))*uint64(uint32(v05))
v15 ^= v00
v15 = v15>>32 | v15<<32
v10 += v15 + 2*uint64(uint32(v10))*uint64(uint32(v15))
v05 ^= v10
v05 = v05>>24 | v05<<40
v00 += v05 + 2*uint64(uint32(v00))*uint64(uint32(v05))
v15 ^= v00
v15 = v15>>16 | v15<<48
v10 += v15 + 2*uint64(uint32(v10))*uint64(uint32(v15))
v05 ^= v10
v05 = v05>>63 | v05<<1
v01 += v06 + 2*uint64(uint32(v01))*uint64(uint32(v06))
v12 ^= v01
v12 = v12>>32 | v12<<32
v11 += v12 + 2*uint64(uint32(v11))*uint64(uint32(v12))
v06 ^= v11
v06 = v06>>24 | v06<<40
v01 += v06 + 2*uint64(uint32(v01))*uint64(uint32(v06))
v12 ^= v01
v12 = v12>>16 | v12<<48
v11 += v12 + 2*uint64(uint32(v11))*uint64(uint32(v12))
v06 ^= v11
v06 = v06>>63 | v06<<1
v02 += v07 + 2*uint64(uint32(v02))*uint64(uint32(v07))
v13 ^= v02
v13 = v13>>32 | v13<<32
v08 += v13 + 2*uint64(uint32(v08))*uint64(uint32(v13))
v07 ^= v08
v07 = v07>>24 | v07<<40
v02 += v07 + 2*uint64(uint32(v02))*uint64(uint32(v07))
v13 ^= v02
v13 = v13>>16 | v13<<48
v08 += v13 + 2*uint64(uint32(v08))*uint64(uint32(v13))
v07 ^= v08
v07 = v07>>63 | v07<<1
v03 += v04 + 2*uint64(uint32(v03))*uint64(uint32(v04))
v14 ^= v03
v14 = v14>>32 | v14<<32
v09 += v14 + 2*uint64(uint32(v09))*uint64(uint32(v14))
v04 ^= v09
v04 = v04>>24 | v04<<40
v03 += v04 + 2*uint64(uint32(v03))*uint64(uint32(v04))
v14 ^= v03
v14 = v14>>16 | v14<<48
v09 += v14 + 2*uint64(uint32(v09))*uint64(uint32(v14))
v04 ^= v09
v04 = v04>>63 | v04<<1
*t00, *t01, *t02, *t03 = v00, v01, v02, v03
*t04, *t05, *t06, *t07 = v04, v05, v06, v07
*t08, *t09, *t10, *t11 = v08, v09, v10, v11
*t12, *t13, *t14, *t15 = v12, v13, v14, v15
}

15
vendor/golang.org/x/crypto/argon2/blamka_ref.go generated vendored Normal file
View File

@@ -0,0 +1,15 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !amd64 appengine gccgo
package argon2
func processBlock(out, in1, in2 *block) {
processBlockGeneric(out, in1, in2, false)
}
func processBlockXOR(out, in1, in2 *block) {
processBlockGeneric(out, in1, in2, true)
}

View File

@@ -241,11 +241,11 @@ func (p *hashed) Hash() []byte {
n = 3 n = 3
} }
arr[n] = '$' arr[n] = '$'
n += 1 n++
copy(arr[n:], []byte(fmt.Sprintf("%02d", p.cost))) copy(arr[n:], []byte(fmt.Sprintf("%02d", p.cost)))
n += 2 n += 2
arr[n] = '$' arr[n] = '$'
n += 1 n++
copy(arr[n:], p.salt) copy(arr[n:], p.salt)
n += encodedSaltSize n += encodedSaltSize
copy(arr[n:], p.hash) copy(arr[n:], p.hash)

View File

@@ -39,7 +39,10 @@ var (
useSSE4 bool useSSE4 bool
) )
var errKeySize = errors.New("blake2b: invalid key size") var (
errKeySize = errors.New("blake2b: invalid key size")
errHashSize = errors.New("blake2b: invalid hash size")
)
var iv = [8]uint64{ var iv = [8]uint64{
0x6a09e667f3bcc908, 0xbb67ae8584caa73b, 0x3c6ef372fe94f82b, 0xa54ff53a5f1d36f1, 0x6a09e667f3bcc908, 0xbb67ae8584caa73b, 0x3c6ef372fe94f82b, 0xa54ff53a5f1d36f1,
@@ -83,7 +86,18 @@ func New384(key []byte) (hash.Hash, error) { return newDigest(Size384, key) }
// key turns the hash into a MAC. The key must between zero and 64 bytes long. // key turns the hash into a MAC. The key must between zero and 64 bytes long.
func New256(key []byte) (hash.Hash, error) { return newDigest(Size256, key) } func New256(key []byte) (hash.Hash, error) { return newDigest(Size256, key) }
// New returns a new hash.Hash computing the BLAKE2b checksum with a custom length.
// A non-nil key turns the hash into a MAC. The key must between zero and 64 bytes long.
// The hash size can be a value between 1 and 64 but it is highly recommended to use
// values equal or greater than:
// - 32 if BLAKE2b is used as a hash function (The key is zero bytes long).
// - 16 if BLAKE2b is used as a MAC function (The key is at least 16 bytes long).
func New(size int, key []byte) (hash.Hash, error) { return newDigest(size, key) }
func newDigest(hashSize int, key []byte) (*digest, error) { func newDigest(hashSize int, key []byte) (*digest, error) {
if hashSize < 1 || hashSize > Size {
return nil, errHashSize
}
if len(key) > Size { if len(key) > Size {
return nil, errKeySize return nil, errKeySize
} }

View File

@@ -185,7 +185,7 @@ func testHashes2X(t *testing.T) {
if n, err := h.Read(result[:]); err != nil { if n, err := h.Read(result[:]); err != nil {
t.Fatalf("#unknown length: error from Read: %v", err) t.Fatalf("#unknown length: error from Read: %v", err)
} else if n != len(result) { } else if n != len(result) {
t.Fatalf("#unknown length: Read returned %d bytes, want %d: %v", n, len(result)) t.Fatalf("#unknown length: Read returned %d bytes, want %d", n, len(result))
} }
const expected = "2a9a6977d915a2c4dd07dbcafe1918bf1682e56d9c8e567ecd19bfd7cd93528833c764d12b34a5e2a219c9fd463dab45e972c5574d73f45de5b2e23af72530d8" const expected = "2a9a6977d915a2c4dd07dbcafe1918bf1682e56d9c8e567ecd19bfd7cd93528833c764d12b34a5e2a219c9fd463dab45e972c5574d73f45de5b2e23af72530d8"

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package bn256 implements a particular bilinear group at the 128-bit security level. // Package bn256 implements a particular bilinear group.
// //
// Bilinear groups are the basis of many of the new cryptographic protocols // Bilinear groups are the basis of many of the new cryptographic protocols
// that have been proposed over the past decade. They consist of a triplet of // that have been proposed over the past decade. They consist of a triplet of
@@ -14,6 +14,10 @@
// Barreto-Naehrig curve as described in // Barreto-Naehrig curve as described in
// http://cryptojedi.org/papers/dclxvi-20100714.pdf. Its output is compatible // http://cryptojedi.org/papers/dclxvi-20100714.pdf. Its output is compatible
// with the implementation described in that paper. // with the implementation described in that paper.
//
// (This package previously claimed to operate at a 128-bit security level.
// However, recent improvements in attacks mean that is no longer true. See
// https://moderncrypto.org/mail-archive/curves/2016/000740.html.)
package bn256 // import "golang.org/x/crypto/bn256" package bn256 // import "golang.org/x/crypto/bn256"
import ( import (
@@ -49,8 +53,8 @@ func RandomG1(r io.Reader) (*big.Int, *G1, error) {
return k, new(G1).ScalarBaseMult(k), nil return k, new(G1).ScalarBaseMult(k), nil
} }
func (g *G1) String() string { func (e *G1) String() string {
return "bn256.G1" + g.p.String() return "bn256.G1" + e.p.String()
} }
// ScalarBaseMult sets e to g*k where g is the generator of the group and // ScalarBaseMult sets e to g*k where g is the generator of the group and
@@ -92,11 +96,11 @@ func (e *G1) Neg(a *G1) *G1 {
} }
// Marshal converts n to a byte slice. // Marshal converts n to a byte slice.
func (n *G1) Marshal() []byte { func (e *G1) Marshal() []byte {
n.p.MakeAffine(nil) e.p.MakeAffine(nil)
xBytes := new(big.Int).Mod(n.p.x, p).Bytes() xBytes := new(big.Int).Mod(e.p.x, p).Bytes()
yBytes := new(big.Int).Mod(n.p.y, p).Bytes() yBytes := new(big.Int).Mod(e.p.y, p).Bytes()
// Each value is a 256-bit number. // Each value is a 256-bit number.
const numBytes = 256 / 8 const numBytes = 256 / 8
@@ -166,8 +170,8 @@ func RandomG2(r io.Reader) (*big.Int, *G2, error) {
return k, new(G2).ScalarBaseMult(k), nil return k, new(G2).ScalarBaseMult(k), nil
} }
func (g *G2) String() string { func (e *G2) String() string {
return "bn256.G2" + g.p.String() return "bn256.G2" + e.p.String()
} }
// ScalarBaseMult sets e to g*k where g is the generator of the group and // ScalarBaseMult sets e to g*k where g is the generator of the group and

View File

@@ -7,7 +7,7 @@ package chacha20poly1305
import ( import (
"encoding/binary" "encoding/binary"
"golang.org/x/crypto/chacha20poly1305/internal/chacha20" "golang.org/x/crypto/internal/chacha20"
"golang.org/x/crypto/poly1305" "golang.org/x/crypto/poly1305"
) )

View File

@@ -47,7 +47,7 @@ func Sum(m []byte, key *[KeySize]byte) *[Size]byte {
// Verify checks that digest is a valid authenticator of message m under the // Verify checks that digest is a valid authenticator of message m under the
// given secret key. Verify does not leak timing information. // given secret key. Verify does not leak timing information.
func Verify(digest []byte, m []byte, key *[32]byte) bool { func Verify(digest []byte, m []byte, key *[KeySize]byte) bool {
if len(digest) != Size { if len(digest) != Size {
return false return false
} }

View File

@@ -760,7 +760,7 @@ func CreateResponse(issuer, responderCert *x509.Certificate, template Response,
} }
if template.Certificate != nil { if template.Certificate != nil {
response.Certificates = []asn1.RawValue{ response.Certificates = []asn1.RawValue{
asn1.RawValue{FullBytes: template.Certificate.Raw}, {FullBytes: template.Certificate.Raw},
} }
} }
responseDER, err := asn1.Marshal(response) responseDER, err := asn1.Marshal(response)

View File

@@ -43,11 +43,11 @@ func TestOCSPDecode(t *testing.T) {
} }
if !reflect.DeepEqual(resp.ThisUpdate, expected.ThisUpdate) { if !reflect.DeepEqual(resp.ThisUpdate, expected.ThisUpdate) {
t.Errorf("resp.ThisUpdate: got %d, want %d", resp.ThisUpdate, expected.ThisUpdate) t.Errorf("resp.ThisUpdate: got %v, want %v", resp.ThisUpdate, expected.ThisUpdate)
} }
if !reflect.DeepEqual(resp.NextUpdate, expected.NextUpdate) { if !reflect.DeepEqual(resp.NextUpdate, expected.NextUpdate) {
t.Errorf("resp.NextUpdate: got %d, want %d", resp.NextUpdate, expected.NextUpdate) t.Errorf("resp.NextUpdate: got %v, want %v", resp.NextUpdate, expected.NextUpdate)
} }
if resp.Status != expected.Status { if resp.Status != expected.Status {
@@ -218,7 +218,7 @@ func TestOCSPResponse(t *testing.T) {
extensionBytes, _ := hex.DecodeString(ocspExtensionValueHex) extensionBytes, _ := hex.DecodeString(ocspExtensionValueHex)
extensions := []pkix.Extension{ extensions := []pkix.Extension{
pkix.Extension{ {
Id: ocspExtensionOID, Id: ocspExtensionOID,
Critical: false, Critical: false,
Value: extensionBytes, Value: extensionBytes,
@@ -268,15 +268,15 @@ func TestOCSPResponse(t *testing.T) {
} }
if !reflect.DeepEqual(resp.ThisUpdate, template.ThisUpdate) { if !reflect.DeepEqual(resp.ThisUpdate, template.ThisUpdate) {
t.Errorf("resp.ThisUpdate: got %d, want %d", resp.ThisUpdate, template.ThisUpdate) t.Errorf("resp.ThisUpdate: got %v, want %v", resp.ThisUpdate, template.ThisUpdate)
} }
if !reflect.DeepEqual(resp.NextUpdate, template.NextUpdate) { if !reflect.DeepEqual(resp.NextUpdate, template.NextUpdate) {
t.Errorf("resp.NextUpdate: got %d, want %d", resp.NextUpdate, template.NextUpdate) t.Errorf("resp.NextUpdate: got %v, want %v", resp.NextUpdate, template.NextUpdate)
} }
if !reflect.DeepEqual(resp.RevokedAt, template.RevokedAt) { if !reflect.DeepEqual(resp.RevokedAt, template.RevokedAt) {
t.Errorf("resp.RevokedAt: got %d, want %d", resp.RevokedAt, template.RevokedAt) t.Errorf("resp.RevokedAt: got %v, want %v", resp.RevokedAt, template.RevokedAt)
} }
if !reflect.DeepEqual(resp.Extensions, template.ExtraExtensions) { if !reflect.DeepEqual(resp.Extensions, template.ExtraExtensions) {

View File

@@ -325,9 +325,8 @@ func ReadEntity(packets *packet.Reader) (*Entity, error) {
if e.PrivateKey, ok = p.(*packet.PrivateKey); !ok { if e.PrivateKey, ok = p.(*packet.PrivateKey); !ok {
packets.Unread(p) packets.Unread(p)
return nil, errors.StructuralError("first packet was not a public/private key") return nil, errors.StructuralError("first packet was not a public/private key")
} else {
e.PrimaryKey = &e.PrivateKey.PublicKey
} }
e.PrimaryKey = &e.PrivateKey.PublicKey
} }
if !e.PrimaryKey.PubKeyAlgo.CanSign() { if !e.PrimaryKey.PubKeyAlgo.CanSign() {

View File

@@ -155,3 +155,22 @@ func TestWithHMACSHA1(t *testing.T) {
func TestWithHMACSHA256(t *testing.T) { func TestWithHMACSHA256(t *testing.T) {
testHash(t, sha256.New, "SHA256", sha256TestVectors) testHash(t, sha256.New, "SHA256", sha256TestVectors)
} }
var sink uint8
func benchmark(b *testing.B, h func() hash.Hash) {
password := make([]byte, h().Size())
salt := make([]byte, 8)
for i := 0; i < b.N; i++ {
password = Key(password, salt, 4096, len(password), h)
}
sink += password[0]
}
func BenchmarkHMACSHA1(b *testing.B) {
benchmark(b, sha1.New)
}
func BenchmarkHMACSHA256(b *testing.B) {
benchmark(b, sha256.New)
}

View File

@@ -122,7 +122,6 @@ func (c *rc2Cipher) Encrypt(dst, src []byte) {
r3 = r3 + c.k[r2&63] r3 = r3 + c.k[r2&63]
for j <= 40 { for j <= 40 {
// mix r0 // mix r0
r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1) r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1)
r0 = rotl16(r0, 1) r0 = rotl16(r0, 1)
@@ -151,7 +150,6 @@ func (c *rc2Cipher) Encrypt(dst, src []byte) {
r3 = r3 + c.k[r2&63] r3 = r3 + c.k[r2&63]
for j <= 60 { for j <= 60 {
// mix r0 // mix r0
r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1) r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1)
r0 = rotl16(r0, 1) r0 = rotl16(r0, 1)
@@ -244,7 +242,6 @@ func (c *rc2Cipher) Decrypt(dst, src []byte) {
r0 = r0 - c.k[r3&63] r0 = r0 - c.k[r3&63]
for j >= 0 { for j >= 0 {
// unmix r3 // unmix r3
r3 = rotl16(r3, 16-5) r3 = rotl16(r3, 16-5)
r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0) r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0)

View File

@@ -11,7 +11,6 @@ import (
) )
func TestEncryptDecrypt(t *testing.T) { func TestEncryptDecrypt(t *testing.T) {
// TODO(dgryski): add the rest of the test vectors from the RFC // TODO(dgryski): add the rest of the test vectors from the RFC
var tests = []struct { var tests = []struct {
key string key string

View File

@@ -202,7 +202,7 @@ func TestSqueezing(t *testing.T) {
d1 := newShakeHash() d1 := newShakeHash()
d1.Write([]byte(testString)) d1.Write([]byte(testString))
var multiple []byte var multiple []byte
for _ = range ref { for range ref {
one := make([]byte, 1) one := make([]byte, 1)
d1.Read(one) d1.Read(one)
multiple = append(multiple, one...) multiple = append(multiple, one...)

View File

@@ -98,7 +98,7 @@ const (
agentAddIdentity = 17 agentAddIdentity = 17
agentRemoveIdentity = 18 agentRemoveIdentity = 18
agentRemoveAllIdentities = 19 agentRemoveAllIdentities = 19
agentAddIdConstrained = 25 agentAddIDConstrained = 25
// 3.3 Key-type independent requests from client to agent // 3.3 Key-type independent requests from client to agent
agentAddSmartcardKey = 20 agentAddSmartcardKey = 20
@@ -515,7 +515,7 @@ func (c *client) insertKey(s interface{}, comment string, constraints []byte) er
// if constraints are present then the message type needs to be changed. // if constraints are present then the message type needs to be changed.
if len(constraints) != 0 { if len(constraints) != 0 {
req[0] = agentAddIdConstrained req[0] = agentAddIDConstrained
} }
resp, err := c.call(req) resp, err := c.call(req)
@@ -577,11 +577,11 @@ func (c *client) Add(key AddedKey) error {
constraints = append(constraints, agentConstrainConfirm) constraints = append(constraints, agentConstrainConfirm)
} }
if cert := key.Certificate; cert == nil { cert := key.Certificate
if cert == nil {
return c.insertKey(key.PrivateKey, key.Comment, constraints) return c.insertKey(key.PrivateKey, key.Comment, constraints)
} else {
return c.insertCert(key.PrivateKey, cert, key.Comment, constraints)
} }
return c.insertCert(key.PrivateKey, cert, key.Comment, constraints)
} }
func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error { func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error {
@@ -633,7 +633,7 @@ func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string
// if constraints are present then the message type needs to be changed. // if constraints are present then the message type needs to be changed.
if len(constraints) != 0 { if len(constraints) != 0 {
req[0] = agentAddIdConstrained req[0] = agentAddIDConstrained
} }
signer, err := ssh.NewSignerFromKey(s) signer, err := ssh.NewSignerFromKey(s)

View File

@@ -148,7 +148,7 @@ func (s *server) processRequest(data []byte) (interface{}, error) {
} }
return rep, nil return rep, nil
case agentAddIdConstrained, agentAddIdentity: case agentAddIDConstrained, agentAddIdentity:
return nil, s.insertIdentity(data) return nil, s.insertIdentity(data)
} }

View File

@@ -41,6 +41,7 @@ func sshPipe() (Conn, *server, error) {
clientConf := ClientConfig{ clientConf := ClientConfig{
User: "user", User: "user",
HostKeyCallback: InsecureIgnoreHostKey(),
} }
serverConf := ServerConfig{ serverConf := ServerConfig{
NoClientAuth: true, NoClientAuth: true,

View File

@@ -340,10 +340,10 @@ func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permis
// the signature of the certificate. // the signature of the certificate.
func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
if c.IsRevoked != nil && c.IsRevoked(cert) { if c.IsRevoked != nil && c.IsRevoked(cert) {
return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial) return fmt.Errorf("ssh: certificate serial %d revoked", cert.Serial)
} }
for opt, _ := range cert.CriticalOptions { for opt := range cert.CriticalOptions {
// sourceAddressCriticalOption will be enforced by // sourceAddressCriticalOption will be enforced by
// serverAuthenticate // serverAuthenticate
if opt == sourceAddressCriticalOption { if opt == sourceAddressCriticalOption {

View File

@@ -6,10 +6,15 @@ package ssh
import ( import (
"bytes" "bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand" "crypto/rand"
"net"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"golang.org/x/crypto/ssh/testdata"
) )
// Cert generated by ssh-keygen 6.0p1 Debian-4. // Cert generated by ssh-keygen 6.0p1 Debian-4.
@@ -220,3 +225,111 @@ func TestHostKeyCert(t *testing.T) {
} }
} }
} }
func TestCertTypes(t *testing.T) {
var testVars = []struct {
name string
keys func() Signer
}{
{
name: CertAlgoECDSA256v01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["ecdsap256"])
return s
},
},
{
name: CertAlgoECDSA384v01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["ecdsap384"])
return s
},
},
{
name: CertAlgoECDSA521v01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["ecdsap521"])
return s
},
},
{
name: CertAlgoED25519v01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["ed25519"])
return s
},
},
{
name: CertAlgoRSAv01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["rsa"])
return s
},
},
{
name: CertAlgoDSAv01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["dsa"])
return s
},
},
}
k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("error generating host key: %v", err)
}
signer, err := NewSignerFromKey(k)
if err != nil {
t.Fatalf("error generating signer for ssh listener: %v", err)
}
conf := &ServerConfig{
PublicKeyCallback: func(c ConnMetadata, k PublicKey) (*Permissions, error) {
return new(Permissions), nil
},
}
conf.AddHostKey(signer)
for _, m := range testVars {
t.Run(m.name, func(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewServerConn(c1, conf)
priv := m.keys()
if err != nil {
t.Fatalf("error generating ssh pubkey: %v", err)
}
cert := &Certificate{
CertType: UserCert,
Key: priv.PublicKey(),
}
cert.SignCert(rand.Reader, priv)
certSigner, err := NewCertSigner(cert, priv)
if err != nil {
t.Fatalf("error generating cert signer: %v", err)
}
config := &ClientConfig{
User: "user",
HostKeyCallback: func(h string, r net.Addr, k PublicKey) error { return nil },
Auth: []AuthMethod{PublicKeys(certSigner)},
}
_, _, _, err = NewClientConn(c2, "", config)
if err != nil {
t.Fatalf("error connecting: %v", err)
}
})
}
}

View File

@@ -205,32 +205,32 @@ type channel struct {
// writePacket sends a packet. If the packet is a channel close, it updates // writePacket sends a packet. If the packet is a channel close, it updates
// sentClose. This method takes the lock c.writeMu. // sentClose. This method takes the lock c.writeMu.
func (c *channel) writePacket(packet []byte) error { func (ch *channel) writePacket(packet []byte) error {
c.writeMu.Lock() ch.writeMu.Lock()
if c.sentClose { if ch.sentClose {
c.writeMu.Unlock() ch.writeMu.Unlock()
return io.EOF return io.EOF
} }
c.sentClose = (packet[0] == msgChannelClose) ch.sentClose = (packet[0] == msgChannelClose)
err := c.mux.conn.writePacket(packet) err := ch.mux.conn.writePacket(packet)
c.writeMu.Unlock() ch.writeMu.Unlock()
return err return err
} }
func (c *channel) sendMessage(msg interface{}) error { func (ch *channel) sendMessage(msg interface{}) error {
if debugMux { if debugMux {
log.Printf("send(%d): %#v", c.mux.chanList.offset, msg) log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
} }
p := Marshal(msg) p := Marshal(msg)
binary.BigEndian.PutUint32(p[1:], c.remoteId) binary.BigEndian.PutUint32(p[1:], ch.remoteId)
return c.writePacket(p) return ch.writePacket(p)
} }
// WriteExtended writes data to a specific extended stream. These streams are // WriteExtended writes data to a specific extended stream. These streams are
// used, for example, for stderr. // used, for example, for stderr.
func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
if c.sentEOF { if ch.sentEOF {
return 0, io.EOF return 0, io.EOF
} }
// 1 byte message type, 4 bytes remoteId, 4 bytes data length // 1 byte message type, 4 bytes remoteId, 4 bytes data length
@@ -241,16 +241,16 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
opCode = msgChannelExtendedData opCode = msgChannelExtendedData
} }
c.writeMu.Lock() ch.writeMu.Lock()
packet := c.packetPool[extendedCode] packet := ch.packetPool[extendedCode]
// We don't remove the buffer from packetPool, so // We don't remove the buffer from packetPool, so
// WriteExtended calls from different goroutines will be // WriteExtended calls from different goroutines will be
// flagged as errors by the race detector. // flagged as errors by the race detector.
c.writeMu.Unlock() ch.writeMu.Unlock()
for len(data) > 0 { for len(data) > 0 {
space := min(c.maxRemotePayload, len(data)) space := min(ch.maxRemotePayload, len(data))
if space, err = c.remoteWin.reserve(space); err != nil { if space, err = ch.remoteWin.reserve(space); err != nil {
return n, err return n, err
} }
if want := headerLength + space; uint32(cap(packet)) < want { if want := headerLength + space; uint32(cap(packet)) < want {
@@ -262,13 +262,13 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
todo := data[:space] todo := data[:space]
packet[0] = opCode packet[0] = opCode
binary.BigEndian.PutUint32(packet[1:], c.remoteId) binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
if extendedCode > 0 { if extendedCode > 0 {
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
} }
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
copy(packet[headerLength:], todo) copy(packet[headerLength:], todo)
if err = c.writePacket(packet); err != nil { if err = ch.writePacket(packet); err != nil {
return n, err return n, err
} }
@@ -276,14 +276,14 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
data = data[len(todo):] data = data[len(todo):]
} }
c.writeMu.Lock() ch.writeMu.Lock()
c.packetPool[extendedCode] = packet ch.packetPool[extendedCode] = packet
c.writeMu.Unlock() ch.writeMu.Unlock()
return n, err return n, err
} }
func (c *channel) handleData(packet []byte) error { func (ch *channel) handleData(packet []byte) error {
headerLen := 9 headerLen := 9
isExtendedData := packet[0] == msgChannelExtendedData isExtendedData := packet[0] == msgChannelExtendedData
if isExtendedData { if isExtendedData {
@@ -303,7 +303,7 @@ func (c *channel) handleData(packet []byte) error {
if length == 0 { if length == 0 {
return nil return nil
} }
if length > c.maxIncomingPayload { if length > ch.maxIncomingPayload {
// TODO(hanwen): should send Disconnect? // TODO(hanwen): should send Disconnect?
return errors.New("ssh: incoming packet exceeds maximum payload size") return errors.New("ssh: incoming packet exceeds maximum payload size")
} }
@@ -313,21 +313,21 @@ func (c *channel) handleData(packet []byte) error {
return errors.New("ssh: wrong packet length") return errors.New("ssh: wrong packet length")
} }
c.windowMu.Lock() ch.windowMu.Lock()
if c.myWindow < length { if ch.myWindow < length {
c.windowMu.Unlock() ch.windowMu.Unlock()
// TODO(hanwen): should send Disconnect with reason? // TODO(hanwen): should send Disconnect with reason?
return errors.New("ssh: remote side wrote too much") return errors.New("ssh: remote side wrote too much")
} }
c.myWindow -= length ch.myWindow -= length
c.windowMu.Unlock() ch.windowMu.Unlock()
if extended == 1 { if extended == 1 {
c.extPending.write(data) ch.extPending.write(data)
} else if extended > 0 { } else if extended > 0 {
// discard other extended data. // discard other extended data.
} else { } else {
c.pending.write(data) ch.pending.write(data)
} }
return nil return nil
} }
@@ -384,31 +384,31 @@ func (c *channel) close() {
// responseMessageReceived is called when a success or failure message is // responseMessageReceived is called when a success or failure message is
// received on a channel to check that such a message is reasonable for the // received on a channel to check that such a message is reasonable for the
// given channel. // given channel.
func (c *channel) responseMessageReceived() error { func (ch *channel) responseMessageReceived() error {
if c.direction == channelInbound { if ch.direction == channelInbound {
return errors.New("ssh: channel response message received on inbound channel") return errors.New("ssh: channel response message received on inbound channel")
} }
if c.decided { if ch.decided {
return errors.New("ssh: duplicate response received for channel") return errors.New("ssh: duplicate response received for channel")
} }
c.decided = true ch.decided = true
return nil return nil
} }
func (c *channel) handlePacket(packet []byte) error { func (ch *channel) handlePacket(packet []byte) error {
switch packet[0] { switch packet[0] {
case msgChannelData, msgChannelExtendedData: case msgChannelData, msgChannelExtendedData:
return c.handleData(packet) return ch.handleData(packet)
case msgChannelClose: case msgChannelClose:
c.sendMessage(channelCloseMsg{PeersId: c.remoteId}) ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
c.mux.chanList.remove(c.localId) ch.mux.chanList.remove(ch.localId)
c.close() ch.close()
return nil return nil
case msgChannelEOF: case msgChannelEOF:
// RFC 4254 is mute on how EOF affects dataExt messages but // RFC 4254 is mute on how EOF affects dataExt messages but
// it is logical to signal EOF at the same time. // it is logical to signal EOF at the same time.
c.extPending.eof() ch.extPending.eof()
c.pending.eof() ch.pending.eof()
return nil return nil
} }
@@ -419,24 +419,24 @@ func (c *channel) handlePacket(packet []byte) error {
switch msg := decoded.(type) { switch msg := decoded.(type) {
case *channelOpenFailureMsg: case *channelOpenFailureMsg:
if err := c.responseMessageReceived(); err != nil { if err := ch.responseMessageReceived(); err != nil {
return err return err
} }
c.mux.chanList.remove(msg.PeersId) ch.mux.chanList.remove(msg.PeersID)
c.msg <- msg ch.msg <- msg
case *channelOpenConfirmMsg: case *channelOpenConfirmMsg:
if err := c.responseMessageReceived(); err != nil { if err := ch.responseMessageReceived(); err != nil {
return err return err
} }
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
} }
c.remoteId = msg.MyId ch.remoteId = msg.MyID
c.maxRemotePayload = msg.MaxPacketSize ch.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.MyWindow) ch.remoteWin.add(msg.MyWindow)
c.msg <- msg ch.msg <- msg
case *windowAdjustMsg: case *windowAdjustMsg:
if !c.remoteWin.add(msg.AdditionalBytes) { if !ch.remoteWin.add(msg.AdditionalBytes) {
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
} }
case *channelRequestMsg: case *channelRequestMsg:
@@ -444,12 +444,12 @@ func (c *channel) handlePacket(packet []byte) error {
Type: msg.Request, Type: msg.Request,
WantReply: msg.WantReply, WantReply: msg.WantReply,
Payload: msg.RequestSpecificData, Payload: msg.RequestSpecificData,
ch: c, ch: ch,
} }
c.incomingRequests <- &req ch.incomingRequests <- &req
default: default:
c.msg <- msg ch.msg <- msg
} }
return nil return nil
} }
@@ -488,23 +488,23 @@ func (e *extChannel) Read(data []byte) (n int, err error) {
return e.ch.ReadExtended(data, e.code) return e.ch.ReadExtended(data, e.code)
} }
func (c *channel) Accept() (Channel, <-chan *Request, error) { func (ch *channel) Accept() (Channel, <-chan *Request, error) {
if c.decided { if ch.decided {
return nil, nil, errDecidedAlready return nil, nil, errDecidedAlready
} }
c.maxIncomingPayload = channelMaxPacket ch.maxIncomingPayload = channelMaxPacket
confirm := channelOpenConfirmMsg{ confirm := channelOpenConfirmMsg{
PeersId: c.remoteId, PeersID: ch.remoteId,
MyId: c.localId, MyID: ch.localId,
MyWindow: c.myWindow, MyWindow: ch.myWindow,
MaxPacketSize: c.maxIncomingPayload, MaxPacketSize: ch.maxIncomingPayload,
} }
c.decided = true ch.decided = true
if err := c.sendMessage(confirm); err != nil { if err := ch.sendMessage(confirm); err != nil {
return nil, nil, err return nil, nil, err
} }
return c, c.incomingRequests, nil return ch, ch.incomingRequests, nil
} }
func (ch *channel) Reject(reason RejectionReason, message string) error { func (ch *channel) Reject(reason RejectionReason, message string) error {
@@ -512,7 +512,7 @@ func (ch *channel) Reject(reason RejectionReason, message string) error {
return errDecidedAlready return errDecidedAlready
} }
reject := channelOpenFailureMsg{ reject := channelOpenFailureMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
Reason: reason, Reason: reason,
Message: message, Message: message,
Language: "en", Language: "en",
@@ -541,7 +541,7 @@ func (ch *channel) CloseWrite() error {
} }
ch.sentEOF = true ch.sentEOF = true
return ch.sendMessage(channelEOFMsg{ return ch.sendMessage(channelEOFMsg{
PeersId: ch.remoteId}) PeersID: ch.remoteId})
} }
func (ch *channel) Close() error { func (ch *channel) Close() error {
@@ -550,7 +550,7 @@ func (ch *channel) Close() error {
} }
return ch.sendMessage(channelCloseMsg{ return ch.sendMessage(channelCloseMsg{
PeersId: ch.remoteId}) PeersID: ch.remoteId})
} }
// Extended returns an io.ReadWriter that sends and receives data on the given, // Extended returns an io.ReadWriter that sends and receives data on the given,
@@ -577,7 +577,7 @@ func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (boo
} }
msg := channelRequestMsg{ msg := channelRequestMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
Request: name, Request: name,
WantReply: wantReply, WantReply: wantReply,
RequestSpecificData: payload, RequestSpecificData: payload,
@@ -614,11 +614,11 @@ func (ch *channel) ackRequest(ok bool) error {
var msg interface{} var msg interface{}
if !ok { if !ok {
msg = channelRequestFailureMsg{ msg = channelRequestFailureMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
} }
} else { } else {
msg = channelRequestSuccessMsg{ msg = channelRequestSuccessMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
} }
} }
return ch.sendMessage(msg) return ch.sendMessage(msg)

View File

@@ -304,7 +304,7 @@ type gcmCipher struct {
buf []byte buf []byte
} }
func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) { func newGCMCipher(iv, key []byte) (packetCipher, error) {
c, err := aes.NewCipher(key) c, err := aes.NewCipher(key)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -372,7 +372,7 @@ func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
} }
length := binary.BigEndian.Uint32(c.prefix[:]) length := binary.BigEndian.Uint32(c.prefix[:])
if length > maxPacket { if length > maxPacket {
return nil, errors.New("ssh: max packet length exceeded.") return nil, errors.New("ssh: max packet length exceeded")
} }
if cap(c.buf) < int(length+gcmTagSize) { if cap(c.buf) < int(length+gcmTagSize) {
@@ -548,11 +548,11 @@ func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error)
c.packetData = c.packetData[:entirePacketSize] c.packetData = c.packetData[:entirePacketSize]
} }
if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil { n, err := io.ReadFull(r, c.packetData[firstBlockLength:])
if err != nil {
return nil, err return nil, err
} else {
c.oracleCamouflage -= uint32(n)
} }
c.oracleCamouflage -= uint32(n)
remainingCrypted := c.packetData[firstBlockLength:macStart] remainingCrypted := c.packetData[firstBlockLength:macStart]
c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted)

View File

@@ -21,12 +21,19 @@ func TestDefaultCiphersExist(t *testing.T) {
} }
func TestPacketCiphers(t *testing.T) { func TestPacketCiphers(t *testing.T) {
// Still test aes128cbc cipher although it's commented out. defaultMac := "hmac-sha2-256"
cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} defaultCipher := "aes128-ctr"
defer delete(cipherModes, aes128cbcID)
for cipher := range cipherModes { for cipher := range cipherModes {
t.Run("cipher="+cipher,
func(t *testing.T) { testPacketCipher(t, cipher, defaultMac) })
}
for mac := range macModes { for mac := range macModes {
t.Run("mac="+mac,
func(t *testing.T) { testPacketCipher(t, defaultCipher, mac) })
}
}
func testPacketCipher(t *testing.T, cipher, mac string) {
kr := &kexResult{Hash: crypto.SHA1} kr := &kexResult{Hash: crypto.SHA1}
algs := directionAlgorithms{ algs := directionAlgorithms{
Cipher: cipher, Cipher: cipher,
@@ -35,35 +42,29 @@ func TestPacketCiphers(t *testing.T) {
} }
client, err := newPacketCipher(clientKeys, algs, kr) client, err := newPacketCipher(clientKeys, algs, kr)
if err != nil { if err != nil {
t.Errorf("newPacketCipher(client, %q, %q): %v", cipher, mac, err) t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
continue
} }
server, err := newPacketCipher(clientKeys, algs, kr) server, err := newPacketCipher(clientKeys, algs, kr)
if err != nil { if err != nil {
t.Errorf("newPacketCipher(client, %q, %q): %v", cipher, mac, err) t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
continue
} }
want := "bla bla" want := "bla bla"
input := []byte(want) input := []byte(want)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
if err := client.writePacket(0, buf, rand.Reader, input); err != nil { if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
t.Errorf("writePacket(%q, %q): %v", cipher, mac, err) t.Fatalf("writePacket(%q, %q): %v", cipher, mac, err)
continue
} }
packet, err := server.readPacket(0, buf) packet, err := server.readPacket(0, buf)
if err != nil { if err != nil {
t.Errorf("readPacket(%q, %q): %v", cipher, mac, err) t.Fatalf("readPacket(%q, %q): %v", cipher, mac, err)
continue
} }
if string(packet) != want { if string(packet) != want {
t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want) t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want)
} }
} }
}
}
func TestCBCOracleCounterMeasure(t *testing.T) { func TestCBCOracleCounterMeasure(t *testing.T) {
cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil}

View File

@@ -9,6 +9,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"os"
"sync" "sync"
"time" "time"
) )
@@ -187,6 +188,10 @@ func Dial(network, addr string, config *ClientConfig) (*Client, error) {
// net.Conn underlying the the SSH connection. // net.Conn underlying the the SSH connection.
type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
// BannerCallback is the function type used for treat the banner sent by
// the server. A BannerCallback receives the message sent by the remote server.
type BannerCallback func(message string) error
// A ClientConfig structure is used to configure a Client. It must not be // A ClientConfig structure is used to configure a Client. It must not be
// modified after having been passed to an SSH function. // modified after having been passed to an SSH function.
type ClientConfig struct { type ClientConfig struct {
@@ -209,6 +214,12 @@ type ClientConfig struct {
// FixedHostKey can be used for simplistic host key checks. // FixedHostKey can be used for simplistic host key checks.
HostKeyCallback HostKeyCallback HostKeyCallback HostKeyCallback
// BannerCallback is called during the SSH dance to display a custom
// server's message. The client configuration can supply this callback to
// handle it as wished. The function BannerDisplayStderr can be used for
// simplistic display on Stderr.
BannerCallback BannerCallback
// ClientVersion contains the version identification string that will // ClientVersion contains the version identification string that will
// be used for the connection. If empty, a reasonable default is used. // be used for the connection. If empty, a reasonable default is used.
ClientVersion string ClientVersion string
@@ -255,3 +266,13 @@ func FixedHostKey(key PublicKey) HostKeyCallback {
hk := &fixedHostKey{key} hk := &fixedHostKey{key}
return hk.check return hk.check
} }
// BannerDisplayStderr returns a function that can be used for
// ClientConfig.BannerCallback to display banners on os.Stderr.
func BannerDisplayStderr() BannerCallback {
return func(banner string) error {
_, err := os.Stderr.WriteString(banner)
return err
}
}

View File

@@ -283,7 +283,9 @@ func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
} }
switch packet[0] { switch packet[0] {
case msgUserAuthBanner: case msgUserAuthBanner:
// TODO(gpaul): add callback to present the banner to the user if err := handleBannerResponse(c, packet); err != nil {
return false, err
}
case msgUserAuthPubKeyOk: case msgUserAuthPubKeyOk:
var msg userAuthPubKeyOkMsg var msg userAuthPubKeyOkMsg
if err := Unmarshal(packet, &msg); err != nil { if err := Unmarshal(packet, &msg); err != nil {
@@ -325,7 +327,9 @@ func handleAuthResponse(c packetConn) (bool, []string, error) {
switch packet[0] { switch packet[0] {
case msgUserAuthBanner: case msgUserAuthBanner:
// TODO: add callback to present the banner to the user if err := handleBannerResponse(c, packet); err != nil {
return false, nil, err
}
case msgUserAuthFailure: case msgUserAuthFailure:
var msg userAuthFailureMsg var msg userAuthFailureMsg
if err := Unmarshal(packet, &msg); err != nil { if err := Unmarshal(packet, &msg); err != nil {
@@ -340,6 +344,24 @@ func handleAuthResponse(c packetConn) (bool, []string, error) {
} }
} }
func handleBannerResponse(c packetConn, packet []byte) error {
var msg userAuthBannerMsg
if err := Unmarshal(packet, &msg); err != nil {
return err
}
transport, ok := c.(*handshakeTransport)
if !ok {
return nil
}
if transport.bannerCallback != nil {
return transport.bannerCallback(msg.Message)
}
return nil
}
// KeyboardInteractiveChallenge should print questions, optionally // KeyboardInteractiveChallenge should print questions, optionally
// disabling echoing (e.g. for passwords), and return all the answers. // disabling echoing (e.g. for passwords), and return all the answers.
// Challenge may be called multiple times in a single session. After // Challenge may be called multiple times in a single session. After
@@ -385,7 +407,9 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
// like handleAuthResponse, but with less options. // like handleAuthResponse, but with less options.
switch packet[0] { switch packet[0] {
case msgUserAuthBanner: case msgUserAuthBanner:
// TODO: Print banners during userauth. if err := handleBannerResponse(c, packet); err != nil {
return false, nil, err
}
continue continue
case msgUserAuthInfoRequest: case msgUserAuthInfoRequest:
// OK // OK

View File

@@ -5,39 +5,75 @@
package ssh package ssh
import ( import (
"net"
"strings" "strings"
"testing" "testing"
) )
func testClientVersion(t *testing.T, config *ClientConfig, expected string) { func TestClientVersion(t *testing.T) {
clientConn, serverConn := net.Pipe() for _, tt := range []struct {
defer clientConn.Close() name string
receivedVersion := make(chan string, 1) version string
config.HostKeyCallback = InsecureIgnoreHostKey() multiLine string
go func() { wantErr bool
version, err := readVersion(serverConn) }{
{
name: "default version",
version: packageVersion,
},
{
name: "custom version",
version: "SSH-2.0-CustomClientVersionString",
},
{
name: "good multi line version",
version: packageVersion,
multiLine: strings.Repeat("ignored\r\n", 20),
},
{
name: "bad multi line version",
version: packageVersion,
multiLine: "bad multi line version",
wantErr: true,
},
{
name: "long multi line version",
version: packageVersion,
multiLine: strings.Repeat("long multi line version\r\n", 50)[:256],
wantErr: true,
},
} {
t.Run(tt.name, func(t *testing.T) {
c1, c2, err := netPipe()
if err != nil { if err != nil {
receivedVersion <- "" t.Fatalf("netPipe: %v", err)
} else {
receivedVersion <- string(version)
} }
serverConn.Close() defer c1.Close()
defer c2.Close()
go func() {
if tt.multiLine != "" {
c1.Write([]byte(tt.multiLine))
}
NewClientConn(c1, "", &ClientConfig{
ClientVersion: tt.version,
HostKeyCallback: InsecureIgnoreHostKey(),
})
c1.Close()
}() }()
NewClientConn(clientConn, "", config) conf := &ServerConfig{NoClientAuth: true}
actual := <-receivedVersion conf.AddHostKey(testSigners["rsa"])
if actual != expected { conn, _, _, err := NewServerConn(c2, conf)
t.Fatalf("got %s; want %s", actual, expected) if err == nil == tt.wantErr {
t.Fatalf("got err %v; wantErr %t", err, tt.wantErr)
} }
if tt.wantErr {
// Don't verify the version on an expected error.
return
} }
if got := string(conn.ClientVersion()); got != tt.version {
func TestCustomClientVersion(t *testing.T) { t.Fatalf("got %q; want %q", got, tt.version)
version := "Test-Client-Version-0.0" }
testClientVersion(t, &ClientConfig{ClientVersion: version}, version) })
} }
func TestDefaultClientVersion(t *testing.T) {
testClientVersion(t, &ClientConfig{}, packageVersion)
} }
func TestHostKeyCheck(t *testing.T) { func TestHostKeyCheck(t *testing.T) {
@@ -79,3 +115,52 @@ func TestHostKeyCheck(t *testing.T) {
} }
} }
} }
func TestBannerCallback(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
serverConf := &ServerConfig{
PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
return &Permissions{}, nil
},
BannerCallback: func(conn ConnMetadata) string {
return "Hello World"
},
}
serverConf.AddHostKey(testSigners["rsa"])
go NewServerConn(c1, serverConf)
var receivedBanner string
var bannerCount int
clientConf := ClientConfig{
Auth: []AuthMethod{
Password("123"),
},
User: "user",
HostKeyCallback: InsecureIgnoreHostKey(),
BannerCallback: func(message string) error {
bannerCount++
receivedBanner = message
return nil
},
}
_, _, _, err = NewClientConn(c2, "", &clientConf)
if err != nil {
t.Fatal(err)
}
if bannerCount != 1 {
t.Errorf("got %d banners; want 1", bannerCount)
}
expected := "Hello World"
if receivedBanner != expected {
t.Fatalf("got %s; want %s", receivedBanner, expected)
}
}

View File

@@ -242,7 +242,7 @@ func (c *Config) SetDefaults() {
// buildDataSignedForAuth returns the data that is signed in order to prove // buildDataSignedForAuth returns the data that is signed in order to prove
// possession of a private key. See RFC 4252, section 7. // possession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
data := struct { data := struct {
Session []byte Session []byte
Type byte Type byte
@@ -253,7 +253,7 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK
Algo []byte Algo []byte
PubKey []byte PubKey []byte
}{ }{
sessionId, sessionID,
msgUserAuthRequest, msgUserAuthRequest,
req.User, req.User,
req.Service, req.Service,

View File

@@ -78,6 +78,11 @@ type handshakeTransport struct {
dialAddress string dialAddress string
remoteAddr net.Addr remoteAddr net.Addr
// bannerCallback is non-empty if we are the client and it has been set in
// ClientConfig. In that case it is called during the user authentication
// dance to handle a custom server's message.
bannerCallback BannerCallback
// Algorithms agreed in the last key exchange. // Algorithms agreed in the last key exchange.
algorithms *algorithms algorithms *algorithms
@@ -120,6 +125,7 @@ func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byt
t.dialAddress = dialAddr t.dialAddress = dialAddr
t.remoteAddr = addr t.remoteAddr = addr
t.hostKeyCallback = config.HostKeyCallback t.hostKeyCallback = config.HostKeyCallback
t.bannerCallback = config.BannerCallback
if config.HostKeyAlgorithms != nil { if config.HostKeyAlgorithms != nil {
t.hostKeyAlgorithms = config.HostKeyAlgorithms t.hostKeyAlgorithms = config.HostKeyAlgorithms
} else { } else {

View File

@@ -119,7 +119,7 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
return nil, err return nil, err
} }
kInt, err := group.diffieHellman(kexDHReply.Y, x) ki, err := group.diffieHellman(kexDHReply.Y, x)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -129,8 +129,8 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
writeString(h, kexDHReply.HostKey) writeString(h, kexDHReply.HostKey)
writeInt(h, X) writeInt(h, X)
writeInt(h, kexDHReply.Y) writeInt(h, kexDHReply.Y)
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
return &kexResult{ return &kexResult{
@@ -164,7 +164,7 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
} }
Y := new(big.Int).Exp(group.g, y, group.p) Y := new(big.Int).Exp(group.g, y, group.p)
kInt, err := group.diffieHellman(kexDHInit.X, y) ki, err := group.diffieHellman(kexDHInit.X, y)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -177,8 +177,8 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
writeInt(h, kexDHInit.X) writeInt(h, kexDHInit.X)
writeInt(h, Y) writeInt(h, Y)
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
H := h.Sum(nil) H := h.Sum(nil)
@@ -462,9 +462,9 @@ func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handsh
writeString(h, kp.pub[:]) writeString(h, kp.pub[:])
writeString(h, reply.EphemeralPubKey) writeString(h, reply.EphemeralPubKey)
kInt := new(big.Int).SetBytes(secret[:]) ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
return &kexResult{ return &kexResult{
@@ -510,9 +510,9 @@ func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handsh
writeString(h, kexInit.ClientPubKey) writeString(h, kexInit.ClientPubKey)
writeString(h, kp.pub[:]) writeString(h, kp.pub[:])
kInt := new(big.Int).SetBytes(secret[:]) ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
H := h.Sum(nil) H := h.Sum(nil)

View File

@@ -363,7 +363,7 @@ func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
type dsaPublicKey dsa.PublicKey type dsaPublicKey dsa.PublicKey
func (r *dsaPublicKey) Type() string { func (k *dsaPublicKey) Type() string {
return "ssh-dss" return "ssh-dss"
} }
@@ -481,12 +481,12 @@ func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
type ecdsaPublicKey ecdsa.PublicKey type ecdsaPublicKey ecdsa.PublicKey
func (key *ecdsaPublicKey) Type() string { func (k *ecdsaPublicKey) Type() string {
return "ecdsa-sha2-" + key.nistID() return "ecdsa-sha2-" + k.nistID()
} }
func (key *ecdsaPublicKey) nistID() string { func (k *ecdsaPublicKey) nistID() string {
switch key.Params().BitSize { switch k.Params().BitSize {
case 256: case 256:
return "nistp256" return "nistp256"
case 384: case 384:
@@ -499,7 +499,7 @@ func (key *ecdsaPublicKey) nistID() string {
type ed25519PublicKey ed25519.PublicKey type ed25519PublicKey ed25519.PublicKey
func (key ed25519PublicKey) Type() string { func (k ed25519PublicKey) Type() string {
return KeyAlgoED25519 return KeyAlgoED25519
} }
@@ -518,23 +518,23 @@ func parseED25519(in []byte) (out PublicKey, rest []byte, err error) {
return (ed25519PublicKey)(key), w.Rest, nil return (ed25519PublicKey)(key), w.Rest, nil
} }
func (key ed25519PublicKey) Marshal() []byte { func (k ed25519PublicKey) Marshal() []byte {
w := struct { w := struct {
Name string Name string
KeyBytes []byte KeyBytes []byte
}{ }{
KeyAlgoED25519, KeyAlgoED25519,
[]byte(key), []byte(k),
} }
return Marshal(&w) return Marshal(&w)
} }
func (key ed25519PublicKey) Verify(b []byte, sig *Signature) error { func (k ed25519PublicKey) Verify(b []byte, sig *Signature) error {
if sig.Format != key.Type() { if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
} }
edKey := (ed25519.PublicKey)(key) edKey := (ed25519.PublicKey)(k)
if ok := ed25519.Verify(edKey, b, sig.Blob); !ok { if ok := ed25519.Verify(edKey, b, sig.Blob); !ok {
return errors.New("ssh: signature did not verify") return errors.New("ssh: signature did not verify")
} }
@@ -595,9 +595,9 @@ func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
return (*ecdsaPublicKey)(key), w.Rest, nil return (*ecdsaPublicKey)(key), w.Rest, nil
} }
func (key *ecdsaPublicKey) Marshal() []byte { func (k *ecdsaPublicKey) Marshal() []byte {
// See RFC 5656, section 3.1. // See RFC 5656, section 3.1.
keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y) keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y)
// ECDSA publickey struct layout should match the struct used by // ECDSA publickey struct layout should match the struct used by
// parseECDSACert in the x/crypto/ssh/agent package. // parseECDSACert in the x/crypto/ssh/agent package.
w := struct { w := struct {
@@ -605,20 +605,20 @@ func (key *ecdsaPublicKey) Marshal() []byte {
ID string ID string
Key []byte Key []byte
}{ }{
key.Type(), k.Type(),
key.nistID(), k.nistID(),
keyBytes, keyBytes,
} }
return Marshal(&w) return Marshal(&w)
} }
func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { func (k *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != key.Type() { if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
} }
h := ecHash(key.Curve).New() h := ecHash(k.Curve).New()
h.Write(data) h.Write(data)
digest := h.Sum(nil) digest := h.Sum(nil)
@@ -635,7 +635,7 @@ func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
return err return err
} }
if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) { if ecdsa.Verify((*ecdsa.PublicKey)(k), digest, ecSig.R, ecSig.S) {
return nil return nil
} }
return errors.New("ssh: signature did not verify") return errors.New("ssh: signature did not verify")
@@ -758,7 +758,7 @@ func NewPublicKey(key interface{}) (PublicKey, error) {
return (*rsaPublicKey)(key), nil return (*rsaPublicKey)(key), nil
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
if !supportedEllipticCurve(key.Curve) { if !supportedEllipticCurve(key.Curve) {
return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported.") return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported")
} }
return (*ecdsaPublicKey)(key), nil return (*ecdsaPublicKey)(key), nil
case *dsa.PublicKey: case *dsa.PublicKey:

View File

@@ -108,8 +108,8 @@ func wildcardMatch(pat []byte, str []byte) bool {
} }
} }
func (l *hostPattern) match(a addr) bool { func (p *hostPattern) match(a addr) bool {
return wildcardMatch([]byte(l.addr.host), []byte(a.host)) && l.addr.port == a.port return wildcardMatch([]byte(p.addr.host), []byte(a.host)) && p.addr.port == a.port
} }
type keyDBLine struct { type keyDBLine struct {

View File

@@ -23,10 +23,6 @@ const (
msgUnimplemented = 3 msgUnimplemented = 3
msgDebug = 4 msgDebug = 4
msgNewKeys = 21 msgNewKeys = 21
// Standard authentication messages
msgUserAuthSuccess = 52
msgUserAuthBanner = 53
) )
// SSH messages: // SSH messages:
@@ -137,6 +133,18 @@ type userAuthFailureMsg struct {
PartialSuccess bool PartialSuccess bool
} }
// See RFC 4252, section 5.1
const msgUserAuthSuccess = 52
// See RFC 4252, section 5.4
const msgUserAuthBanner = 53
type userAuthBannerMsg struct {
Message string `sshtype:"53"`
// unused, but required to allow message parsing
Language string
}
// See RFC 4256, section 3.2 // See RFC 4256, section 3.2
const msgUserAuthInfoRequest = 60 const msgUserAuthInfoRequest = 60
const msgUserAuthInfoResponse = 61 const msgUserAuthInfoResponse = 61
@@ -154,7 +162,7 @@ const msgChannelOpen = 90
type channelOpenMsg struct { type channelOpenMsg struct {
ChanType string `sshtype:"90"` ChanType string `sshtype:"90"`
PeersId uint32 PeersID uint32
PeersWindow uint32 PeersWindow uint32
MaxPacketSize uint32 MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"` TypeSpecificData []byte `ssh:"rest"`
@@ -165,7 +173,7 @@ const msgChannelData = 94
// Used for debug print outs of packets. // Used for debug print outs of packets.
type channelDataMsg struct { type channelDataMsg struct {
PeersId uint32 `sshtype:"94"` PeersID uint32 `sshtype:"94"`
Length uint32 Length uint32
Rest []byte `ssh:"rest"` Rest []byte `ssh:"rest"`
} }
@@ -174,8 +182,8 @@ type channelDataMsg struct {
const msgChannelOpenConfirm = 91 const msgChannelOpenConfirm = 91
type channelOpenConfirmMsg struct { type channelOpenConfirmMsg struct {
PeersId uint32 `sshtype:"91"` PeersID uint32 `sshtype:"91"`
MyId uint32 MyID uint32
MyWindow uint32 MyWindow uint32
MaxPacketSize uint32 MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"` TypeSpecificData []byte `ssh:"rest"`
@@ -185,7 +193,7 @@ type channelOpenConfirmMsg struct {
const msgChannelOpenFailure = 92 const msgChannelOpenFailure = 92
type channelOpenFailureMsg struct { type channelOpenFailureMsg struct {
PeersId uint32 `sshtype:"92"` PeersID uint32 `sshtype:"92"`
Reason RejectionReason Reason RejectionReason
Message string Message string
Language string Language string
@@ -194,7 +202,7 @@ type channelOpenFailureMsg struct {
const msgChannelRequest = 98 const msgChannelRequest = 98
type channelRequestMsg struct { type channelRequestMsg struct {
PeersId uint32 `sshtype:"98"` PeersID uint32 `sshtype:"98"`
Request string Request string
WantReply bool WantReply bool
RequestSpecificData []byte `ssh:"rest"` RequestSpecificData []byte `ssh:"rest"`
@@ -204,28 +212,28 @@ type channelRequestMsg struct {
const msgChannelSuccess = 99 const msgChannelSuccess = 99
type channelRequestSuccessMsg struct { type channelRequestSuccessMsg struct {
PeersId uint32 `sshtype:"99"` PeersID uint32 `sshtype:"99"`
} }
// See RFC 4254, section 5.4. // See RFC 4254, section 5.4.
const msgChannelFailure = 100 const msgChannelFailure = 100
type channelRequestFailureMsg struct { type channelRequestFailureMsg struct {
PeersId uint32 `sshtype:"100"` PeersID uint32 `sshtype:"100"`
} }
// See RFC 4254, section 5.3 // See RFC 4254, section 5.3
const msgChannelClose = 97 const msgChannelClose = 97
type channelCloseMsg struct { type channelCloseMsg struct {
PeersId uint32 `sshtype:"97"` PeersID uint32 `sshtype:"97"`
} }
// See RFC 4254, section 5.3 // See RFC 4254, section 5.3
const msgChannelEOF = 96 const msgChannelEOF = 96
type channelEOFMsg struct { type channelEOFMsg struct {
PeersId uint32 `sshtype:"96"` PeersID uint32 `sshtype:"96"`
} }
// See RFC 4254, section 4 // See RFC 4254, section 4
@@ -255,7 +263,7 @@ type globalRequestFailureMsg struct {
const msgChannelWindowAdjust = 93 const msgChannelWindowAdjust = 93
type windowAdjustMsg struct { type windowAdjustMsg struct {
PeersId uint32 `sshtype:"93"` PeersID uint32 `sshtype:"93"`
AdditionalBytes uint32 AdditionalBytes uint32
} }

View File

@@ -278,7 +278,7 @@ func (m *mux) handleChannelOpen(packet []byte) error {
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
failMsg := channelOpenFailureMsg{ failMsg := channelOpenFailureMsg{
PeersId: msg.PeersId, PeersID: msg.PeersID,
Reason: ConnectionFailed, Reason: ConnectionFailed,
Message: "invalid request", Message: "invalid request",
Language: "en_US.UTF-8", Language: "en_US.UTF-8",
@@ -287,7 +287,7 @@ func (m *mux) handleChannelOpen(packet []byte) error {
} }
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
c.remoteId = msg.PeersId c.remoteId = msg.PeersID
c.maxRemotePayload = msg.MaxPacketSize c.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.PeersWindow) c.remoteWin.add(msg.PeersWindow)
m.incomingChannels <- c m.incomingChannels <- c
@@ -313,7 +313,7 @@ func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
PeersWindow: ch.myWindow, PeersWindow: ch.myWindow,
MaxPacketSize: ch.maxIncomingPayload, MaxPacketSize: ch.maxIncomingPayload,
TypeSpecificData: extra, TypeSpecificData: extra,
PeersId: ch.localId, PeersID: ch.localId,
} }
if err := m.sendMessage(open); err != nil { if err := m.sendMessage(open); err != nil {
return nil, err return nil, err

View File

@@ -95,6 +95,10 @@ type ServerConfig struct {
// Note that RFC 4253 section 4.2 requires that this string start with // Note that RFC 4253 section 4.2 requires that this string start with
// "SSH-2.0-". // "SSH-2.0-".
ServerVersion string ServerVersion string
// BannerCallback, if present, is called and the return string is sent to
// the client after key exchange completed but before authentication.
BannerCallback func(conn ConnMetadata) string
} }
// AddHostKey adds a private key as a host key. If an existing host // AddHostKey adds a private key as a host key. If an existing host
@@ -252,7 +256,7 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
func isAcceptableAlgo(algo string) bool { func isAcceptableAlgo(algo string) bool {
switch algo { switch algo {
case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoED25519, case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoED25519,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01:
return true return true
} }
return false return false
@@ -312,6 +316,7 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err
authFailures := 0 authFailures := 0
var authErrs []error var authErrs []error
var displayedBanner bool
userAuthLoop: userAuthLoop:
for { for {
@@ -343,6 +348,20 @@ userAuthLoop:
} }
s.user = userAuthReq.User s.user = userAuthReq.User
if !displayedBanner && config.BannerCallback != nil {
displayedBanner = true
msg := config.BannerCallback(s)
if msg != "" {
bannerMsg := &userAuthBannerMsg{
Message: msg,
}
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
return nil, err
}
}
}
perms = nil perms = nil
authErr := errors.New("no auth passed yet") authErr := errors.New("no auth passed yet")

View File

@@ -406,7 +406,7 @@ func (s *Session) Wait() error {
s.stdinPipeWriter.Close() s.stdinPipeWriter.Close()
} }
var copyError error var copyError error
for _ = range s.copyFuncs { for range s.copyFuncs {
if err := <-s.errors; err != nil && copyError == nil { if err := <-s.errors; err != nil && copyError == nil {
copyError = err copyError = err
} }

View File

@@ -617,7 +617,7 @@ func writeWithCRLF(w io.Writer, buf []byte) (n int, err error) {
if _, err = w.Write(crlf); err != nil { if _, err = w.Write(crlf); err != nil {
return n, err return n, err
} }
n += 1 n++
buf = buf[1:] buf = buf[1:]
} }
} }

View File

@@ -17,44 +17,41 @@
package terminal // import "golang.org/x/crypto/ssh/terminal" package terminal // import "golang.org/x/crypto/ssh/terminal"
import ( import (
"syscall"
"unsafe"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// State contains the state of a terminal. // State contains the state of a terminal.
type State struct { type State struct {
termios syscall.Termios termios unix.Termios
} }
// IsTerminal returns true if the given file descriptor is a terminal. // IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool { func IsTerminal(fd int) bool {
var termios syscall.Termios _, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) return err == nil
return err == 0
} }
// MakeRaw put the terminal connected to the given file descriptor into raw // MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be // mode and returns the previous state of the terminal so that it can be
// restored. // restored.
func MakeRaw(fd int) (*State, error) { func MakeRaw(fd int) (*State, error) {
var oldState State termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { if err != nil {
return nil, err return nil, err
} }
newState := oldState.termios oldState := State{termios: *termios}
// This attempts to replicate the behaviour documented for cfmakeraw in // This attempts to replicate the behaviour documented for cfmakeraw in
// the termios(3) manpage. // the termios(3) manpage.
newState.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON termios.Iflag &^= unix.IGNBRK | unix.BRKINT | unix.PARMRK | unix.ISTRIP | unix.INLCR | unix.IGNCR | unix.ICRNL | unix.IXON
newState.Oflag &^= syscall.OPOST termios.Oflag &^= unix.OPOST
newState.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN termios.Lflag &^= unix.ECHO | unix.ECHONL | unix.ICANON | unix.ISIG | unix.IEXTEN
newState.Cflag &^= syscall.CSIZE | syscall.PARENB termios.Cflag &^= unix.CSIZE | unix.PARENB
newState.Cflag |= syscall.CS8 termios.Cflag |= unix.CS8
newState.Cc[unix.VMIN] = 1 termios.Cc[unix.VMIN] = 1
newState.Cc[unix.VTIME] = 0 termios.Cc[unix.VTIME] = 0
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { if err := unix.IoctlSetTermios(fd, ioctlWriteTermios, termios); err != nil {
return nil, err return nil, err
} }
@@ -64,59 +61,55 @@ func MakeRaw(fd int) (*State, error) {
// GetState returns the current state of a terminal which may be useful to // GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal. // restore the terminal after a signal.
func GetState(fd int) (*State, error) { func GetState(fd int) (*State, error) {
var oldState State termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { if err != nil {
return nil, err return nil, err
} }
return &oldState, nil return &State{termios: *termios}, nil
} }
// Restore restores the terminal connected to the given file descriptor to a // Restore restores the terminal connected to the given file descriptor to a
// previous state. // previous state.
func Restore(fd int, state *State) error { func Restore(fd int, state *State) error {
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0); err != 0 { return unix.IoctlSetTermios(fd, ioctlWriteTermios, &state.termios)
return err
}
return nil
} }
// GetSize returns the dimensions of the given terminal. // GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) { func GetSize(fd int) (width, height int, err error) {
var dimensions [4]uint16 ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ)
if err != nil {
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), uintptr(syscall.TIOCGWINSZ), uintptr(unsafe.Pointer(&dimensions)), 0, 0, 0); err != 0 {
return -1, -1, err return -1, -1, err
} }
return int(dimensions[1]), int(dimensions[0]), nil return int(ws.Col), int(ws.Row), nil
} }
// passwordReader is an io.Reader that reads from a specific file descriptor. // passwordReader is an io.Reader that reads from a specific file descriptor.
type passwordReader int type passwordReader int
func (r passwordReader) Read(buf []byte) (int, error) { func (r passwordReader) Read(buf []byte) (int, error) {
return syscall.Read(int(r), buf) return unix.Read(int(r), buf)
} }
// ReadPassword reads a line of input from a terminal without local echo. This // ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice // is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n. // returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) { func ReadPassword(fd int) ([]byte, error) {
var oldState syscall.Termios termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0); err != 0 { if err != nil {
return nil, err return nil, err
} }
newState := oldState newState := *termios
newState.Lflag &^= syscall.ECHO newState.Lflag &^= unix.ECHO
newState.Lflag |= syscall.ICANON | syscall.ISIG newState.Lflag |= unix.ICANON | unix.ISIG
newState.Iflag |= syscall.ICRNL newState.Iflag |= unix.ICRNL
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { if err := unix.IoctlSetTermios(fd, ioctlWriteTermios, &newState); err != nil {
return nil, err return nil, err
} }
defer func() { defer func() {
syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0) unix.IoctlSetTermios(fd, ioctlWriteTermios, termios)
}() }()
return readPasswordLine(passwordReader(fd)) return readPasswordLine(passwordReader(fd))

View File

@@ -17,6 +17,8 @@
package terminal package terminal
import ( import (
"os"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
@@ -71,13 +73,6 @@ func GetSize(fd int) (width, height int, err error) {
return int(info.Size.X), int(info.Size.Y), nil return int(info.Size.X), int(info.Size.Y), nil
} }
// passwordReader is an io.Reader that reads from a specific Windows HANDLE.
type passwordReader int
func (r passwordReader) Read(buf []byte) (int, error) {
return windows.Read(windows.Handle(r), buf)
}
// ReadPassword reads a line of input from a terminal without local echo. This // ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice // is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n. // returned does not include the \n.
@@ -98,5 +93,5 @@ func ReadPassword(fd int) ([]byte, error) {
windows.SetConsoleMode(windows.Handle(fd), old) windows.SetConsoleMode(windows.Handle(fd), old)
}() }()
return readPasswordLine(passwordReader(fd)) return readPasswordLine(os.NewFile(uintptr(fd), "stdin"))
} }

32
vendor/golang.org/x/crypto/ssh/test/banner_test.go generated vendored Normal file
View File

@@ -0,0 +1,32 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd
package test
import (
"testing"
)
func TestBannerCallbackAgainstOpenSSH(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
clientConf := clientConfig()
var receivedBanner string
clientConf.BannerCallback = func(message string) error {
receivedBanner = message
return nil
}
conn := server.Dial(clientConf)
defer conn.Close()
expected := "Server Banner"
if receivedBanner != expected {
t.Fatalf("got %v; want %v", receivedBanner, expected)
}
}

View File

@@ -2,6 +2,6 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package contains integration tests for the // Package test contains integration tests for the
// golang.org/x/crypto/ssh package. // golang.org/x/crypto/ssh package.
package test // import "golang.org/x/crypto/ssh/test" package test // import "golang.org/x/crypto/ssh/test"

View File

@@ -333,6 +333,7 @@ func TestCiphers(t *testing.T) {
cipherOrder = append(cipherOrder, "aes128-cbc", "3des-cbc") cipherOrder = append(cipherOrder, "aes128-cbc", "3des-cbc")
for _, ciph := range cipherOrder { for _, ciph := range cipherOrder {
t.Run(ciph, func(t *testing.T) {
server := newServer(t) server := newServer(t)
defer server.Shutdown() defer server.Shutdown()
conf := clientConfig() conf := clientConfig()
@@ -345,9 +346,9 @@ func TestCiphers(t *testing.T) {
} else { } else {
t.Fatalf("failed for cipher %q", ciph) t.Fatalf("failed for cipher %q", ciph)
} }
})
} }
} }
func TestMACs(t *testing.T) { func TestMACs(t *testing.T) {
var config ssh.Config var config ssh.Config
config.SetDefaults() config.SetDefaults()

View File

@@ -25,8 +25,9 @@ import (
"golang.org/x/crypto/ssh/testdata" "golang.org/x/crypto/ssh/testdata"
) )
const sshd_config = ` const sshdConfig = `
Protocol 2 Protocol 2
Banner {{.Dir}}/banner
HostKey {{.Dir}}/id_rsa HostKey {{.Dir}}/id_rsa
HostKey {{.Dir}}/id_dsa HostKey {{.Dir}}/id_dsa
HostKey {{.Dir}}/id_ecdsa HostKey {{.Dir}}/id_ecdsa
@@ -50,7 +51,7 @@ HostbasedAuthentication no
PubkeyAcceptedKeyTypes=* PubkeyAcceptedKeyTypes=*
` `
var configTmpl = template.Must(template.New("").Parse(sshd_config)) var configTmpl = template.Must(template.New("").Parse(sshdConfig))
type server struct { type server struct {
t *testing.T t *testing.T
@@ -256,6 +257,8 @@ func newServer(t *testing.T) *server {
} }
f.Close() f.Close()
writeFile(filepath.Join(dir, "banner"), []byte("Server Banner"))
for k, v := range testdata.PEMBytes { for k, v := range testdata.PEMBytes {
filename := "id_" + k filename := "id_" + k
writeFile(filepath.Join(dir, filename), v) writeFile(filepath.Join(dir, filename), v)
@@ -268,7 +271,7 @@ func newServer(t *testing.T) *server {
} }
var authkeys bytes.Buffer var authkeys bytes.Buffer
for k, _ := range testdata.PEMBytes { for k := range testdata.PEMBytes {
authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k])) authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
} }
writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes()) writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())

View File

@@ -23,6 +23,27 @@ MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49
AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+ AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+
6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA== 6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA==
-----END EC PRIVATE KEY----- -----END EC PRIVATE KEY-----
`),
"ecdsap256": []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIAPCE25zK0PQSnsgVcEbM1mbKTASH4pqb5QJajplDwDZoAoGCCqGSM49
AwEHoUQDQgAEWy8TxGcIHRh5XGpO4dFVfDjeNY+VkgubQrf/eyFJZHxAn1SKraXU
qJUjTKj1z622OxYtJ5P7s9CfAEVsTzLCzg==
-----END EC PRIVATE KEY-----
`),
"ecdsap384": []byte(`-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDBWfSnMuNKq8J9rQLzzEkx3KAoEohSXqhE/4CdjEYtoU2i22HW80DDS
qQhYNHRAduygBwYFK4EEACKhZANiAAQWaDMAd0HUd8ZiXCX7mYDDnC54gwH/nG43
VhCUEYmF7HMZm/B9Yn3GjFk3qYEDEvuF/52+NvUKBKKaLbh32AWxMv0ibcoba4cz
hL9+hWYhUD9XIUlzMWiZ2y6eBE9PdRI=
-----END EC PRIVATE KEY-----
`),
"ecdsap521": []byte(`-----BEGIN EC PRIVATE KEY-----
MIHcAgEBBEIBrkYpQcy8KTVHNiAkjlFZwee90224Bu6wz94R4OBo+Ts0eoAQG7SF
iaygEDMUbx6kTgXTBcKZ0jrWPKakayNZ/kigBwYFK4EEACOhgYkDgYYABADFuvLV
UoaCDGHcw5uNfdRIsvaLKuWSpLsl48eWGZAwdNG432GDVKduO+pceuE+8XzcyJb+
uMv+D2b11Q/LQUcHJwE6fqbm8m3EtDKPsoKs0u/XUJb0JsH4J8lkZzbUTjvGYamn
FFlRjzoB3Oxu8UQgb+MWPedtH9XYBbg9biz4jJLkXQ==
-----END EC PRIVATE KEY-----
`), `),
"rsa": []byte(`-----BEGIN RSA PRIVATE KEY----- "rsa": []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2 MIICXAIBAAKBgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2

View File

@@ -6,6 +6,7 @@ package ssh
import ( import (
"bufio" "bufio"
"bytes"
"errors" "errors"
"io" "io"
"log" "log"
@@ -76,17 +77,17 @@ type connectionState struct {
// both directions are triggered by reading and writing a msgNewKey packet // both directions are triggered by reading and writing a msgNewKey packet
// respectively. // respectively.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil { ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
if err != nil {
return err return err
} else { }
t.reader.pendingKeyChange <- ciph t.reader.pendingKeyChange <- ciph
}
if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil { ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult)
if err != nil {
return err return err
} else {
t.writer.pendingKeyChange <- ciph
} }
t.writer.pendingKeyChange <- ciph
return nil return nil
} }
@@ -139,7 +140,7 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
case cipher := <-s.pendingKeyChange: case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher s.packetCipher = cipher
default: default:
return nil, errors.New("ssh: got bogus newkeys message.") return nil, errors.New("ssh: got bogus newkeys message")
} }
case msgDisconnect: case msgDisconnect:
@@ -254,7 +255,7 @@ func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (pac
iv, key, macKey := generateKeys(d, algs, kex) iv, key, macKey := generateKeys(d, algs, kex)
if algs.Cipher == gcmCipherID { if algs.Cipher == gcmCipherID {
return newGCMCipher(iv, key, macKey) return newGCMCipher(iv, key)
} }
if algs.Cipher == aes128cbcID { if algs.Cipher == aes128cbcID {
@@ -342,7 +343,7 @@ func readVersion(r io.Reader) ([]byte, error) {
var ok bool var ok bool
var buf [1]byte var buf [1]byte
for len(versionString) < maxVersionStringBytes { for length := 0; length < maxVersionStringBytes; length++ {
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { if err != nil {
return nil, err return nil, err
@@ -350,6 +351,13 @@ func readVersion(r io.Reader) ([]byte, error) {
// The RFC says that the version should be terminated with \r\n // The RFC says that the version should be terminated with \r\n
// but several SSH servers actually only send a \n. // but several SSH servers actually only send a \n.
if buf[0] == '\n' { if buf[0] == '\n' {
if !bytes.HasPrefix(versionString, []byte("SSH-")) {
// RFC 4253 says we need to ignore all version string lines
// except the one containing the SSH version (provided that
// all the lines do not exceed 255 bytes in total).
versionString = versionString[:0]
continue
}
ok = true ok = true
break break
} }

View File

@@ -13,11 +13,13 @@ import (
) )
func TestReadVersion(t *testing.T) { func TestReadVersion(t *testing.T) {
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253] longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253]
multiLineVersion := strings.Repeat("ignored\r\n", 20) + "SSH-2.0-bla\r\n"
cases := map[string]string{ cases := map[string]string{
"SSH-2.0-bla\r\n": "SSH-2.0-bla", "SSH-2.0-bla\r\n": "SSH-2.0-bla",
"SSH-2.0-bla\n": "SSH-2.0-bla", "SSH-2.0-bla\n": "SSH-2.0-bla",
longversion + "\r\n": longversion, multiLineVersion: "SSH-2.0-bla",
longVersion + "\r\n": longVersion,
} }
for in, want := range cases { for in, want := range cases {
@@ -33,9 +35,11 @@ func TestReadVersion(t *testing.T) {
} }
func TestReadVersionError(t *testing.T) { func TestReadVersionError(t *testing.T) {
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253] longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253]
multiLineVersion := strings.Repeat("ignored\r\n", 50) + "SSH-2.0-bla\r\n"
cases := []string{ cases := []string{
longversion + "too-long\r\n", longVersion + "too-long\r\n",
multiLineVersion,
} }
for _, in := range cases { for _, in := range cases {
if _, err := readVersion(bytes.NewBufferString(in)); err == nil { if _, err := readVersion(bytes.NewBufferString(in)); err == nil {
@@ -60,7 +64,7 @@ func TestExchangeVersionsBasic(t *testing.T) {
func TestExchangeVersions(t *testing.T) { func TestExchangeVersions(t *testing.T) {
cases := []string{ cases := []string{
"not\x000allowed", "not\x000allowed",
"not allowed\n", "not allowed\x01\r\n",
} }
for _, c := range cases { for _, c := range cases {
buf := bytes.NewBufferString("SSH-2.0-bla\r\n") buf := bytes.NewBufferString("SSH-2.0-bla\r\n")

View File

@@ -5,7 +5,6 @@
// Package tea implements the TEA algorithm, as defined in Needham and // Package tea implements the TEA algorithm, as defined in Needham and
// Wheeler's 1994 technical report, “TEA, a Tiny Encryption Algorithm”. See // Wheeler's 1994 technical report, “TEA, a Tiny Encryption Algorithm”. See
// http://www.cix.co.uk/~klockstone/tea.pdf for details. // http://www.cix.co.uk/~klockstone/tea.pdf for details.
package tea package tea
import ( import (

View File

@@ -69,7 +69,7 @@ func initCipher(c *Cipher, key []byte) {
// Precalculate the table // Precalculate the table
const delta = 0x9E3779B9 const delta = 0x9E3779B9
var sum uint32 = 0 var sum uint32
// Two rounds of XTEA applied per loop // Two rounds of XTEA applied per loop
for i := 0; i < numRounds; { for i := 0; i < numRounds; {

View File

@@ -25,3 +25,7 @@ func Clearenv() {
func Environ() []string { func Environ() []string {
return syscall.Environ() return syscall.Environ()
} }
func Unsetenv(key string) error {
return syscall.Unsetenv(key)
}

View File

@@ -1,14 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.4
package plan9
import "syscall"
func Unsetenv(key string) error {
// This was added in Go 1.4.
return syscall.Unsetenv(key)
}

Some files were not shown because too many files have changed in this diff Show More