Compare commits

...

33 Commits

Author SHA1 Message Date
henry.chen
fb66b6871e release v1.4.3 2018-02-09 16:15:34 +08:00
henry.chen
5ae76f243e fixed #6,发布文章异步提交,随机 session key等 2018-02-09 13:50:34 +08:00
deepzz0
051b034e51 1. 修复编辑专题:按钮显示"新增专题"错误 2. 编辑专题链接移动到专题名称 2018-02-04 12:39:35 +08:00
Deepzz
27439ecc71 Update install.md 2018-02-01 21:24:00 +08:00
henry.chen
d02c838447 fix archive page bug 2018-01-25 23:09:59 +08:00
Deepzz
d17acf5325 Update amusing.md 2018-01-17 19:16:08 +08:00
deepzz0
b278ca377f update changelog.md 2018-01-14 13:53:32 +08:00
deepzz0
93131441e4 update 2018-01-14 13:38:26 +08:00
deepzz0
ddcc6c2d2e auto archiving by year when the month great than 12 2018-01-14 13:12:59 +08:00
henry.chen
ef63ae9598 fix page archive unable auto update 2018-01-14 02:40:11 +08:00
henry.chen
2ed9db5c7b code logical adjust 2018-01-14 02:02:12 +08:00
deepzz0
06a12bc6f9 update vendor 2018-01-13 18:23:03 +08:00
deepzz0
6524b45751 adjust the code 2018-01-13 18:19:54 +08:00
henry.chen
ceb9e2690b 添加 disqus thread 创建接口 2018-01-13 02:56:35 +08:00
deepzz0
405fbaf24f fix can delete blogroll and about page & fix delete and readd article bug 2018-01-07 20:30:14 +08:00
deepzz0
3245c0e0d3 update vendor & fix upload file url & fix judge file type 2018-01-06 23:24:27 +08:00
Deepzz
badc62e3f0 Update README.md 2018-01-06 11:47:20 +08:00
deepzz0
a5561f257b comment docker-compose.yml backup 2018-01-02 20:21:45 +08:00
deepzz0
eb37b83ebd update README.md 2018-01-01 19:03:16 +08:00
deepzz0
b2fab703fc Merge branch 'master' of github.com:eiblog/eiblog 2018-01-01 18:59:30 +08:00
deepzz0
37deb390d9 docker-compose.yml 添加数据库备份镜像 2018-01-01 18:59:10 +08:00
Deepzz
6fa5088352 更新 ct 服务器地址 2017-12-30 13:50:19 +08:00
Deepzz
e023a33786 Update app.yml
移除 disqus 评论及 Google 分析私人信息配置
2017-12-08 12:19:01 +08:00
henry.chen
6f818c4b5d fix search.html <no value> 2017-12-05 15:08:32 +08:00
henry.chen
9ad22fb2d9 don't use dynamic link: CGO_ENABLED=0 2017-11-30 10:04:54 +08:00
henry.chen
fc37d5e093 fix page:admin/write-post autocomplete tag 2017-11-29 16:17:58 +08:00
henry.chen
61024bfebd update 2017-11-27 18:34:03 +08:00
henry.chen
f20c4a6063 fix docker image: exec user process caused "no such file or directory" 2017-11-27 18:17:41 +08:00
henry.chen
c24e6bf7bd update .travis.yml 2017-11-27 16:43:30 +08:00
henry.chen
ade94168d3 update .travis.yml 2017-11-27 16:32:39 +08:00
henry.chen
552d010650 fix background turn page 2017-11-27 15:21:28 +08:00
deepzz0
1c3106cbb0 update vendor 2017-11-24 22:58:59 +08:00
henry.chen
168937f1b2 fix gopkg.in/mgo import conflict 2017-11-23 13:57:20 +08:00
255 changed files with 6374 additions and 1370 deletions

View File

@@ -1,36 +1,26 @@
sudo: required # 超级权限 sudo: required # 超级权限
dist: trusty # 在ubuntu:trusty dist: trusty # 在ubuntu:trusty
language: go # 声明构建语言环境 language: go # 声明构建语言环境
go: # 只构建最新版本 go: # 只构建最新版本
- tip - tip
services: # docker环境
services: # docker环境
- docker - docker
branches: # 限定项目分支 branches: # 限定项目分支
only: only:
- /^v[0-9](\.[0-9]){2}(-rc[1-9])?$/ - /^v[0-9](\.[0-9]){2}(-rc[1-9])?$/
install: install:
- curl https://glide.sh/get | sh # 安装glide包管理 - curl https://glide.sh/get | sh # 安装glide包管理
script: script:
- glide up - glide up
- GOOS=linux GOARCH=amd64 go build # 编译版本 - GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build # 编译版本
- docker build -t registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog:$TRAVIS_TAG . # 构建镜像 - docker build -t registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog . # 构建镜像
after_success: after_success:
# - if [ "$TRAVIS_BRANCH" =~ ^v[0-9](\.[0-9])+.*$ ]; then - docker login -u="$DOCKER_USERNAME" -p="$DOCKER_PASSWORD" registry.cn-hangzhou.aliyuncs.com
# docker login -u="$DOCKER_USERNAME" -p="$DOCKER_PASSWORD" registry.cn-hangzhou.aliyuncs.com; - docker push registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog
# docker push registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog; - docker tag registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog:$TRAVIS_TAG
# fi
- docker push registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog:$TRAVIS_TAG - docker push registry.cn-hangzhou.aliyuncs.com/deepzz/eiblog:$TRAVIS_TAG
before_deploy: before_deploy:
- ./dist.sh - ./dist.sh
deploy: deploy:
provider: releases provider: releases
api_key: api_key:

View File

@@ -1,5 +1,44 @@
# Eiblog Changelog # Eiblog Changelog
## v1.4.2 (2018-02-09)
* 修复博客初始化后about 页面不能够评论 #6
* 修复编辑专题,按钮显示“添加专题”错误
* 优化“添加文章”从同步改为异步推送feedesdisqus。速度显著提升
* **重要*)头像图片从 avatar.jpg 改为 avatar.png透明
* docker-compose.yml mongodb 去掉端口映射,防止用户将端口暴露至外网
* session key 每次重启随机生成等一些细节的修复
## v1.4.1 (2018-01-14)
* 修复创建新文章disqus 不收录bug
* 修复创建新文章归档页面不刷新bug
* 修复能够删除关于页面和友情链接页面bug
* 修复重复添加文章错误
* 注释掉 docker-compose.yml 自动备份内容,请自行解开
* 添加当月数大于12归档页面使用年份归档
* 优化代码逻辑
## v1.4.0 (2018-01-01)
* fix 搜索页面 bug
* CGO_ENABLED=0 关闭 cgo
* 更新Makefile ct log 服务器
* 数据库数据终于可以备份了
## v1.3.4 (2017-11-29)
* fix page:admin/write-post autocomplete tag
## v1.3.3 (2017-11-27)
* fix docker image: exec user process caused "no such file or directory"
## v1.3.2 (2017-11-17)
* 修复文章自动保存引起的发布文章不成功的bug
## v1.3.1 (2017-11-05)
* 修复调整 关于、友情链接 创建时间出现文章乱序
* 修复评论时间计算错误
* 调整acme文件验证路径
* 更改七牛SDK包为github包。
* 调整七牛配置文件名称app.yml: kodo -> qiniuname -> bucket请提高静态文件版本 staticversion
## v1.3.0 (2017-07-13) ## v1.3.0 (2017-07-13)
* 更改 app.yml 配置项,将大部分配置归在 general 常规配置下。注意,部署时请先更新 app.yml。 * 更改 app.yml 配置项,将大部分配置归在 general 常规配置下。注意,部署时请先更新 app.yml。
* 静态文件采用动态渲染,即用户不再需要管理 view、static 目录。 * 静态文件采用动态渲染,即用户不再需要管理 view、static 目录。

View File

@@ -7,4 +7,4 @@ ADD static/tzdata/Shanghai /etc/localtime
COPY . /eiblog COPY . /eiblog
EXPOSE 9000 EXPOSE 9000
WORKDIR /eiblog WORKDIR /eiblog
CMD ["./eiblog"] CMD ["sh","-c","/eiblog/eiblog"]

View File

@@ -40,7 +40,7 @@ gencert:makedir
@echo "generate rsa cert..." @echo "generate rsa cert..."
@$(acme.sh) --force --issue --dns dns_ali $(sans) --log \ @$(acme.sh) --force --issue --dns dns_ali $(sans) --log \
--renew-hook "ct-submit ctlog.api.venafi.com < $(config)/ssl/domain.rsa.pem > $(config)/scts/rsa/venafi.sct \ --renew-hook "ct-submit ctlog-gen2.api.venafi.com < $(config)/ssl/domain.rsa.pem > $(config)/scts/rsa/venafi.sct \
&& ct-submit ctlog.wosign.com < $(config)/ssl/domain.rsa.pem > $(config)/scts/rsa/wosign.sct" && ct-submit ctlog.wosign.com < $(config)/ssl/domain.rsa.pem > $(config)/scts/rsa/wosign.sct"
@$(acme.sh) --install-cert -d $(cn) \ @$(acme.sh) --install-cert -d $(cn) \
--key-file $(config)/ssl/domain.rsa.key \ --key-file $(config)/ssl/domain.rsa.key \
@@ -49,7 +49,7 @@ gencert:makedir
@echo "generate ecc cert..." @echo "generate ecc cert..."
@$(acme.sh) --force --issue --dns dns_ali $(sans) -k ec-256 --log \ @$(acme.sh) --force --issue --dns dns_ali $(sans) -k ec-256 --log \
--renew-hook "ct-submit ctlog.api.venafi.com < $(config)/ssl/domain.ecc.pem > $(config)/scts/ecc/venafi.sct \ --renew-hook "ct-submit ctlog-gen2.api.venafi.com < $(config)/ssl/domain.ecc.pem > $(config)/scts/ecc/venafi.sct \
&& ct-submit ctlog.wosign.com < $(config)/ssl/domain.ecc.pem > $(config)/scts/ecc/wosign.sct" && ct-submit ctlog.wosign.com < $(config)/ssl/domain.ecc.pem > $(config)/scts/ecc/wosign.sct"
@$(acme.sh) --install-cert -d $(cn) --ecc \ @$(acme.sh) --install-cert -d $(cn) --ecc \
--key-file $(config)/ssl/domain.ecc.key \ --key-file $(config)/ssl/domain.ecc.key \

View File

@@ -74,6 +74,7 @@
8. 开源 `Typecho` 完整后台系统,全功能 `markdown` 编辑器,让你体验什么是简洁清爽。 8. 开源 `Typecho` 完整后台系统,全功能 `markdown` 编辑器,让你体验什么是简洁清爽。
9. 博客后台直接对接 `七牛 SDK`,实现后台上传文件和删除文件的简单功能。 9. 博客后台直接对接 `七牛 SDK`,实现后台上传文件和删除文件的简单功能。
10. 采用 `elasticsearch` 作为站内搜索,添加 `google opensearch` 功能,搜索更加自然。 10. 采用 `elasticsearch` 作为站内搜索,添加 `google opensearch` 功能,搜索更加自然。
11. 自动备份数据库数据到七牛云。
### 文档 ### 文档
@@ -81,10 +82,10 @@
* [安装部署](https://github.com/eiblog/eiblog/blob/master/docs/install.md) * [安装部署](https://github.com/eiblog/eiblog/blob/master/docs/install.md)
* [写作需知](https://github.com/eiblog/eiblog/blob/master/docs/writing.md) * [写作需知](https://github.com/eiblog/eiblog/blob/master/docs/writing.md)
* [好玩的功能](https://github.com/eiblog/eiblog/blob/master/docs/amusing.md) * [好玩的功能](https://github.com/eiblog/eiblog/blob/master/docs/amusing.md)
* [关于备份](https://github.com/eiblog/backup)
### 成功搭建者博客 ### 成功搭建者博客
* [https://razeen.me](https://razeen.me) - Razeen's Blog * [https://razeen.me](https://razeen.me) - Razeen's Blog
* [https://mxthd.me](https://mxthd.me) - 梦醒逃荒岛
如果你的博客使用`Eiblog`搭建,你可以在 [这里](https://github.com/eiblog/eiblog/issues/1) 提交网址。 如果你的博客使用`Eiblog`搭建,你可以在 [这里](https://github.com/eiblog/eiblog/issues/1) 提交网址。

196
api.go
View File

@@ -4,15 +4,14 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"sort"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/eiblog/eiblog/setting" "github.com/eiblog/eiblog/setting"
"github.com/eiblog/utils/logd" "github.com/eiblog/utils/logd"
"github.com/eiblog/utils/mgo"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gopkg.in/mgo.v2/bson"
) )
const ( const (
@@ -56,6 +55,7 @@ func init() {
APIs["file-delete"] = apiFileDelete APIs["file-delete"] = apiFileDelete
} }
// 更新账号信息Email、PhoneNumber、Address
func apiAccount(c *gin.Context) { func apiAccount(c *gin.Context) {
e := c.PostForm("email") e := c.PostForm("email")
pn := c.PostForm("phoneNumber") pn := c.PostForm("phoneNumber")
@@ -66,8 +66,9 @@ func apiAccount(c *gin.Context) {
return return
} }
err := UpdateAccountField(bson.M{"$set": bson.M{"email": e, "phonen": pn, "address": ad}}) err := UpdateAccountField(mgo.M{"$set": mgo.M{"email": e, "phonen": pn, "address": ad}})
if err != nil { if err != nil {
logd.Error(err)
responseNotice(c, NOTICE_NOTICE, err.Error(), "") responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return return
} }
@@ -77,6 +78,7 @@ func apiAccount(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "更新成功", "") responseNotice(c, NOTICE_SUCCESS, "更新成功", "")
} }
// 更新博客信息
func apiBlog(c *gin.Context) { func apiBlog(c *gin.Context) {
bn := c.PostForm("blogName") bn := c.PostForm("blogName")
bt := c.PostForm("bTitle") bt := c.PostForm("bTitle")
@@ -89,8 +91,11 @@ func apiBlog(c *gin.Context) {
return return
} }
err := UpdateAccountField(bson.M{"$set": bson.M{"blogger.blogname": bn, "blogger.btitle": bt, "blogger.beian": ba, "blogger.subtitle": st, "blogger.seriessay": ss, "blogger.archivessay": as}}) err := UpdateAccountField(mgo.M{"$set": mgo.M{"blogger.blogname": bn,
"blogger.btitle": bt, "blogger.beian": ba, "blogger.subtitle": st,
"blogger.seriessay": ss, "blogger.archivessay": as}})
if err != nil { if err != nil {
logd.Error(err)
responseNotice(c, NOTICE_NOTICE, err.Error(), "") responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return return
} }
@@ -105,6 +110,7 @@ func apiBlog(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "更新成功", "") responseNotice(c, NOTICE_SUCCESS, "更新成功", "")
} }
// 更新密码
func apiPassword(c *gin.Context) { func apiPassword(c *gin.Context) {
logd.Debug(c.Request.PostForm.Encode()) logd.Debug(c.Request.PostForm.Encode())
od := c.PostForm("old") od := c.PostForm("old")
@@ -124,8 +130,9 @@ func apiPassword(c *gin.Context) {
} }
newPwd := EncryptPasswd(Ei.Username, nw) newPwd := EncryptPasswd(Ei.Username, nw)
err := UpdateAccountField(bson.M{"$set": bson.M{"password": newPwd}}) err := UpdateAccountField(mgo.M{"$set": mgo.M{"password": newPwd}})
if err != nil { if err != nil {
logd.Error(err)
responseNotice(c, NOTICE_NOTICE, err.Error(), "") responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return return
} }
@@ -133,46 +140,39 @@ func apiPassword(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "更新成功", "") responseNotice(c, NOTICE_SUCCESS, "更新成功", "")
} }
// 删除文章,软删除:移入到回收箱
func apiPostDelete(c *gin.Context) { func apiPostDelete(c *gin.Context) {
var err error
defer func() {
if err != nil {
logd.Error(err)
responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return
}
responseNotice(c, NOTICE_SUCCESS, "删除成功", "")
}()
err = c.Request.ParseForm()
if err != nil {
return
}
var ids []int32 var ids []int32
var i int for _, v := range c.PostFormArray("cid[]") {
for _, v := range c.Request.PostForm["cid[]"] { i, err := strconv.Atoi(v)
i, err = strconv.Atoi(v) if err != nil || int32(i) < setting.Conf.General.StartID {
if err != nil || i < 1 { responseNotice(c, NOTICE_NOTICE, "参数错误", "")
err = errors.New("参数错误")
return return
} }
ids = append(ids, int32(i)) ids = append(ids, int32(i))
} }
err = DelArticles(ids...) err := DelArticles(ids...)
if err != nil { if err != nil {
logd.Error(err)
responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return return
} }
// elasticsearch 删除索引
// elasticsearch
err = ElasticDelIndex(ids) err = ElasticDelIndex(ids)
if err != nil { if err != nil {
return logd.Error(err)
} }
// TODO disqus delete
responseNotice(c, NOTICE_SUCCESS, "删除成功", "")
} }
func apiPostAdd(c *gin.Context) { func apiPostAdd(c *gin.Context) {
var err error var (
var do string err error
var cid int do string
cid int
)
defer func() { defer func() {
switch do { switch do {
case "auto": // 自动保存 case "auto": // 自动保存
@@ -181,18 +181,16 @@ func apiPostAdd(c *gin.Context) {
return return
} }
c.JSON(http.StatusOK, gin.H{"success": SUCCESS, "time": time.Now().Format("15:04:05 PM"), "cid": cid}) c.JSON(http.StatusOK, gin.H{"success": SUCCESS, "time": time.Now().Format("15:04:05 PM"), "cid": cid})
case "save": // 保存草稿 case "save", "publish": // 草稿,发布
if err != nil { if err != nil {
responseNotice(c, NOTICE_NOTICE, err.Error(), "") responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return return
} }
c.Redirect(http.StatusFound, "/admin/manage-draft") uri := "/admin/manage-draft"
case "publish": // 发布 if do == "publish" {
if err != nil { uri = "/admin/manage-posts"
responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return
} }
c.Redirect(http.StatusFound, "/admin/manage-posts") c.Redirect(http.StatusFound, uri)
} }
}() }()
@@ -200,11 +198,11 @@ func apiPostAdd(c *gin.Context) {
slug := c.PostForm("slug") slug := c.PostForm("slug")
title := c.PostForm("title") title := c.PostForm("title")
text := c.PostForm("text") text := c.PostForm("text")
date := c.PostForm("date") date := CheckDate(c.PostForm("date"))
serie := c.PostForm("serie") serie := c.PostForm("serie")
tag := c.PostForm("tags") tag := c.PostForm("tags")
update := c.PostForm("update") update := c.PostForm("update")
if title == "" || text == "" || slug == "" { if slug == "" || title == "" || text == "" {
err = errors.New("参数错误") err = errors.New("参数错误")
return return
} }
@@ -217,13 +215,14 @@ func apiPostAdd(c *gin.Context) {
Title: title, Title: title,
Content: text, Content: text,
Slug: slug, Slug: slug,
CreateTime: CheckDate(date), CreateTime: date,
IsDraft: do != "publish", IsDraft: do != "publish",
Author: Ei.Username, Author: Ei.Username,
SerieID: serieid, SerieID: serieid,
Tags: tags, Tags: tags,
} }
cid, err = strconv.Atoi(c.PostForm("cid")) cid, err = strconv.Atoi(c.PostForm("cid"))
// 新文章
if err != nil || cid < 1 { if err != nil || cid < 1 {
err = AddArticle(artc) err = AddArticle(artc)
if err != nil { if err != nil {
@@ -232,58 +231,55 @@ func apiPostAdd(c *gin.Context) {
} }
cid = int(artc.ID) cid = int(artc.ID)
if !artc.IsDraft { if !artc.IsDraft {
ElasticIndex(artc) // 异步执行,快
DoPings(slug) go func() {
// elastic
ElasticIndex(artc)
// rss
DoPings(slug)
// disqus
ThreadCreate(artc)
}()
} }
return return
} }
// 旧文章
artc.ID = int32(cid) artc.ID = int32(cid)
i, a := GetArticle(artc.ID) _, a := GetArticle(artc.ID)
if a != nil { if a != nil {
artc.IsDraft = false artc.IsDraft = false
artc.Count = a.Count artc.Count = a.Count
artc.UpdateTime = a.UpdateTime artc.UpdateTime = a.UpdateTime
Ei.Articles = append(Ei.Articles[0:i], Ei.Articles[i+1:]...)
DelFromLinkedList(a)
ManageTagsArticle(a, false, DELETE)
ManageSeriesArticle(a, false, DELETE)
ManageArchivesArticle(a, false, DELETE)
delete(Ei.MapArticles, a.Slug)
a = nil
} }
if CheckBool(update) { if CheckBool(update) {
artc.UpdateTime = time.Now() artc.UpdateTime = time.Now()
} }
err = UpdateArticle(bson.M{"id": artc.ID}, artc) // 数据库更新
err = UpdateArticle(mgo.M{"id": artc.ID}, artc)
if err != nil { if err != nil {
logd.Error(err) logd.Error(err)
return return
} }
if !artc.IsDraft { if !artc.IsDraft {
Ei.MapArticles[artc.Slug] = artc ReplaceArticle(a, artc)
Ei.Articles = append(Ei.Articles, artc) // 异步执行,快
sort.Sort(Ei.Articles) go func() {
GenerateExcerptAndRender(artc) // elastic
// elasticsearch 索引 ElasticIndex(artc)
ElasticIndex(artc) // rss
DoPings(slug) DoPings(slug)
if artc.ID >= setting.Conf.General.StartID { // disqus
ManageTagsArticle(artc, true, ADD) if a == nil {
ManageSeriesArticle(artc, true, ADD) ThreadCreate(artc)
ManageArchivesArticle(artc, true, ADD) }
AddToLinkedList(artc.ID) }()
}
} }
} }
// 只能逐一删除,专题下不能有文章
func apiSerieDelete(c *gin.Context) { func apiSerieDelete(c *gin.Context) {
err := c.Request.ParseForm() for _, v := range c.PostFormArray("mid[]") {
if err != nil {
responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return
}
// 只能逐一删除
for _, v := range c.Request.PostForm["mid[]"] {
id, err := strconv.Atoi(v) id, err := strconv.Atoi(v)
if err != nil || id < 1 { if err != nil || id < 1 {
responseNotice(c, NOTICE_NOTICE, err.Error(), "") responseNotice(c, NOTICE_NOTICE, err.Error(), "")
@@ -299,6 +295,7 @@ func apiSerieDelete(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "删除成功", "") responseNotice(c, NOTICE_SUCCESS, "删除成功", "")
} }
// 添加专题,如果专题有提交 mid 即更新专题
func apiSerieAdd(c *gin.Context) { func apiSerieAdd(c *gin.Context) {
name := c.PostForm("name") name := c.PostForm("name")
slug := c.PostForm("slug") slug := c.PostForm("slug")
@@ -335,24 +332,15 @@ func apiSerieAdd(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "操作成功", "") responseNotice(c, NOTICE_SUCCESS, "操作成功", "")
} }
// 暂未启用 // NOTE 排序专题,暂未实现
func apiSerieSort(c *gin.Context) { func apiSerieSort(c *gin.Context) {
err := c.Request.ParseForm() v := c.PostFormArray("mid[]")
if err != nil {
responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return
}
v := c.Request.PostForm["mid[]"]
logd.Debug(v) logd.Debug(v)
} }
// 删除草稿箱,物理删除
func apiDraftDelete(c *gin.Context) { func apiDraftDelete(c *gin.Context) {
err := c.Request.ParseForm() for _, v := range c.PostFormArray("mid[]") {
if err != nil {
responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return
}
for _, v := range c.Request.PostForm["mid[]"] {
i, err := strconv.Atoi(v) i, err := strconv.Atoi(v)
if err != nil || i < 1 { if err != nil || i < 1 {
responseNotice(c, NOTICE_NOTICE, "参数错误", "") responseNotice(c, NOTICE_NOTICE, "参数错误", "")
@@ -367,15 +355,9 @@ func apiDraftDelete(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "删除成功", "") responseNotice(c, NOTICE_SUCCESS, "删除成功", "")
} }
// 删除垃圾箱,物理删除
func apiTrashDelete(c *gin.Context) { func apiTrashDelete(c *gin.Context) {
logd.Debug(c.PostForm("key")) for _, v := range c.PostFormArray("mid[]") {
logd.Debug(c.Request.PostForm)
err := c.Request.ParseForm()
if err != nil {
responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return
}
for _, v := range c.Request.PostForm["mid[]"] {
i, err := strconv.Atoi(v) i, err := strconv.Atoi(v)
if err != nil || i < 1 { if err != nil || i < 1 {
responseNotice(c, NOTICE_NOTICE, "参数错误", "") responseNotice(c, NOTICE_NOTICE, "参数错误", "")
@@ -390,15 +372,9 @@ func apiTrashDelete(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "删除成功", "") responseNotice(c, NOTICE_SUCCESS, "删除成功", "")
} }
// 从垃圾箱恢复到草稿箱
func apiTrashRecover(c *gin.Context) { func apiTrashRecover(c *gin.Context) {
logd.Debug(c.PostForm("key")) for _, v := range c.PostFormArray("mid[]") {
logd.Debug(c.Request.PostForm)
err := c.Request.ParseForm()
if err != nil {
responseNotice(c, NOTICE_NOTICE, err.Error(), "")
return
}
for _, v := range c.Request.PostForm["mid[]"] {
i, err := strconv.Atoi(v) i, err := strconv.Atoi(v)
if err != nil || i < 1 { if err != nil || i < 1 {
responseNotice(c, NOTICE_NOTICE, "参数错误", "") responseNotice(c, NOTICE_NOTICE, "参数错误", "")
@@ -414,6 +390,7 @@ func apiTrashRecover(c *gin.Context) {
responseNotice(c, NOTICE_SUCCESS, "恢复成功", "") responseNotice(c, NOTICE_SUCCESS, "恢复成功", "")
} }
// 上传文件到 qiniu 云
func apiFileUpload(c *gin.Context) { func apiFileUpload(c *gin.Context) {
type Size interface { type Size interface {
Size() int64 Size() int64
@@ -437,7 +414,7 @@ func apiFileUpload(c *gin.Context) {
c.String(http.StatusBadRequest, err.Error()) c.String(http.StatusBadRequest, err.Error())
return return
} }
typ := c.Request.Header.Get("Content-Type") typ := header.Header.Get("Content-Type")
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"title": filename, "title": filename,
"isImage": typ[:5] == "image", "isImage": typ[:5] == "image",
@@ -446,20 +423,19 @@ func apiFileUpload(c *gin.Context) {
}) })
} }
// 删除七牛 CDN 文件
func apiFileDelete(c *gin.Context) { func apiFileDelete(c *gin.Context) {
var err error defer c.String(http.StatusOK, "删掉了吗?鬼知道。。。")
defer func() {
if err != nil {
logd.Error(err)
}
c.String(http.StatusOK, "删掉了吗?鬼知道。。。")
}()
name := c.PostForm("title") name := c.PostForm("title")
if name == "" { if name == "" {
err = errors.New("参数错误") logd.Error("参数错误")
return return
} }
err = FileDelete(name) err := FileDelete(name)
if err != nil {
logd.Error(err)
}
} }
func responseNotice(c *gin.Context, typ, content, hl string) { func responseNotice(c *gin.Context, typ, content, hl string) {

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"html/template" "html/template"
"net/http" "net/http"
@@ -11,9 +12,9 @@ import (
"github.com/eiblog/eiblog/setting" "github.com/eiblog/eiblog/setting"
"github.com/eiblog/utils/logd" "github.com/eiblog/utils/logd"
"github.com/eiblog/utils/mgo"
"github.com/gin-gonic/contrib/sessions" "github.com/gin-gonic/contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gopkg.in/mgo.v2/bson"
) )
// 是否登录 // 是否登录
@@ -73,7 +74,7 @@ func HandleLoginPost(c *gin.Context) {
session.Save() session.Save()
Ei.LoginIP = c.ClientIP() Ei.LoginIP = c.ClientIP()
Ei.LoginTime = time.Now() Ei.LoginTime = time.Now()
UpdateAccountField(bson.M{"$set": bson.M{"loginip": Ei.LoginIP, "logintime": Ei.LoginTime}}) UpdateAccountField(mgo.M{"$set": mgo.M{"loginip": Ei.LoginIP, "logintime": Ei.LoginTime}})
c.Redirect(http.StatusFound, "/admin/profile") c.Redirect(http.StatusFound, "/admin/profile")
} }
@@ -118,7 +119,8 @@ func HandlePost(c *gin.Context) {
for tag, _ := range Ei.Tags { for tag, _ := range Ei.Tags {
tags = append(tags, T{tag, tag}) tags = append(tags, T{tag, tag})
} }
h["Tags"] = tags str, _ := json.Marshal(tags)
h["Tags"] = string(str)
c.Status(http.StatusOK) c.Status(http.StatusOK)
RenderHTMLBack(c, "admin-post", h) RenderHTMLBack(c, "admin-post", h)
} }

View File

@@ -35,13 +35,14 @@ general:
clean: 1 clean: 1
# 评论相关 # 评论相关
disqus: disqus:
shortname: deepzz shortname: xxxxxx
publickey: wdSgxRm9rdGAlLKFcFdToBe3GT4SibmV7Y8EjJQ0r4GWXeKtxpopMAeIeoI2dTEg publickey: wdSgxRm9rdGAlLKFcFdToBe3GT4SibmV7Y8EjJQ0r4GWXeKtxpopMAeIeoI2dTEg
accesstoken: 50023908f39f4607957e909b495326af accesstoken: 50023908f39f4607957e909b495326af
postscount: https://disqus.com/api/3.0/threads/set.json postscount: https://disqus.com/api/3.0/threads/set.json
postslist: https://disqus.com/api/3.0/threads/listPosts.json postslist: https://disqus.com/api/3.0/threads/listPosts.json
postcreate: https://disqus.com/api/3.0/posts/create.json postcreate: https://disqus.com/api/3.0/posts/create.json
postapprove: https://disqus.com/api/3.0/posts/approve.json postapprove: https://disqus.com/api/3.0/posts/approve.json
threadcreate: https://disqus.com/api/3.0/threads/create.json
# disqus.js 文件名 # disqus.js 文件名
embed: disqus_7d3cf2.js embed: disqus_7d3cf2.js
# 获取评论数量间隔 # 获取评论数量间隔
@@ -49,7 +50,7 @@ disqus:
# 谷歌统计 # 谷歌统计
google: google:
url: https://www.google-analytics.com/collect url: https://www.google-analytics.com/collect
tid: UA-77251712-1 tid: UA-xxxxxx-1
v: "1" v: "1"
t: pageview t: pageview
# 七牛CDN # 七牛CDN

402
db.go
View File

@@ -13,9 +13,7 @@ import (
"github.com/eiblog/blackfriday" "github.com/eiblog/blackfriday"
"github.com/eiblog/eiblog/setting" "github.com/eiblog/eiblog/setting"
"github.com/eiblog/utils/logd" "github.com/eiblog/utils/logd"
db "github.com/eiblog/utils/mgo" "github.com/eiblog/utils/mgo"
"gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
) )
// 数据库及表名 // 数据库及表名
@@ -62,44 +60,24 @@ var (
func init() { func init() {
// 数据库加索引 // 数据库加索引
ms, c := db.Connect(DB, COLLECTION_ACCOUNT) err := mgo.Index(DB, COLLECTION_ACCOUNT, []string{"username"})
index := mgo.Index{ if err != nil {
Key: []string{"username"},
Unique: true,
DropDups: true,
Background: true,
Sparse: true,
}
if err := c.EnsureIndex(index); err != nil {
logd.Fatal(err) logd.Fatal(err)
} }
ms.Close()
ms, c = db.Connect(DB, COLLECTION_ARTICLE) err = mgo.Index(DB, COLLECTION_ARTICLE, []string{"id"})
index = mgo.Index{ if err != nil {
Key: []string{"id"},
Unique: true,
DropDups: true,
Background: true,
Sparse: true,
}
if err := c.EnsureIndex(index); err != nil {
logd.Fatal(err) logd.Fatal(err)
} }
index = mgo.Index{
Key: []string{"slug"}, err = mgo.Index(DB, COLLECTION_ARTICLE, []string{"slug"})
Unique: true, if err != nil {
DropDups: true,
Background: true,
Sparse: true,
}
if err := c.EnsureIndex(index); err != nil {
logd.Fatal(err) logd.Fatal(err)
} }
ms.Close()
// 读取帐号信息 // 读取帐号信息
Ei = loadAccount() loadAccount()
// 获取文章数据 // 获取文章数据
Ei.Articles = loadArticles() loadArticles()
// 生成markdown文档 // 生成markdown文档
go generateMarkdown() go generateMarkdown()
// 启动定时器 // 启动定时器
@@ -109,12 +87,13 @@ func init() {
} }
// 读取或初始化帐号信息 // 读取或初始化帐号信息
func loadAccount() (a *Account) { func loadAccount() {
a = &Account{} Ei = &Account{}
err := db.FindOne(DB, COLLECTION_ACCOUNT, bson.M{"username": setting.Conf.Account.Username}, a) err := mgo.FindOne(DB, COLLECTION_ACCOUNT, mgo.M{"username": setting.Conf.Account.Username}, Ei)
// 初始化用户数据 // 初始化用户数据
if err == mgo.ErrNotFound { if err == mgo.ErrNotFound {
a = &Account{ logd.Printf("Initializing account: %s\n", setting.Conf.Account.Username)
Ei = &Account{
Username: setting.Conf.Account.Username, Username: setting.Conf.Account.Username,
Password: EncryptPasswd(setting.Conf.Account.Username, setting.Conf.Account.Password), Password: EncryptPasswd(setting.Conf.Account.Username, setting.Conf.Account.Password),
Email: setting.Conf.Account.Email, Email: setting.Conf.Account.Email,
@@ -122,29 +101,28 @@ func loadAccount() (a *Account) {
Address: setting.Conf.Account.Address, Address: setting.Conf.Account.Address,
CreateTime: time.Now(), CreateTime: time.Now(),
} }
a.BlogName = setting.Conf.Blogger.BlogName Ei.BlogName = setting.Conf.Blogger.BlogName
a.SubTitle = setting.Conf.Blogger.SubTitle Ei.SubTitle = setting.Conf.Blogger.SubTitle
a.BeiAn = setting.Conf.Blogger.BeiAn Ei.BeiAn = setting.Conf.Blogger.BeiAn
a.BTitle = setting.Conf.Blogger.BTitle Ei.BTitle = setting.Conf.Blogger.BTitle
a.Copyright = setting.Conf.Blogger.Copyright Ei.Copyright = setting.Conf.Blogger.Copyright
err = db.Insert(DB, COLLECTION_ACCOUNT, a) err = mgo.Insert(DB, COLLECTION_ACCOUNT, Ei)
generateTopic() generateTopic()
} else if err != nil { } else if err != nil {
logd.Fatal(err) logd.Fatal(err)
} }
a.CH = make(chan string, 2) Ei.CH = make(chan string, 2)
a.MapArticles = make(map[string]*Article) Ei.MapArticles = make(map[string]*Article)
a.Tags = make(map[string]SortArticles) Ei.Tags = make(map[string]SortArticles)
return
} }
func loadArticles() (artcs SortArticles) { func loadArticles() {
err := db.FindAll(DB, COLLECTION_ARTICLE, bson.M{"isdraft": false, "deletetime": bson.M{"$eq": time.Time{}}}, &artcs) err := mgo.FindAll(DB, COLLECTION_ARTICLE, mgo.M{"isdraft": false, "deletetime": mgo.M{"$eq": time.Time{}}}, &Ei.Articles)
if err != nil { if err != nil {
logd.Fatal(err) logd.Fatal(err)
} }
sort.Sort(artcs) sort.Sort(Ei.Articles)
for i, v := range artcs { for i, v := range Ei.Articles {
// 渲染文章 // 渲染文章
GenerateExcerptAndRender(v) GenerateExcerptAndRender(v)
Ei.MapArticles[v.Slug] = v Ei.MapArticles[v.Slug] = v
@@ -153,18 +131,15 @@ func loadArticles() (artcs SortArticles) {
continue continue
} }
if i > 0 { if i > 0 {
v.Prev = artcs[i-1] v.Prev = Ei.Articles[i-1]
} }
if artcs[i+1].ID >= setting.Conf.General.StartID { if Ei.Articles[i+1].ID >= setting.Conf.General.StartID {
v.Next = artcs[i+1] v.Next = Ei.Articles[i+1]
} }
ManageTagsArticle(v, false, ADD) upArticle(v, false)
ManageSeriesArticle(v, false, ADD)
ManageArchivesArticle(v, false, ADD)
} }
Ei.CH <- SERIES_MD Ei.CH <- SERIES_MD
Ei.CH <- ARCHIVE_MD Ei.CH <- ARCHIVE_MD
return
} }
// generate series,archive markdown // generate series,archive markdown
@@ -183,7 +158,8 @@ func generateMarkdown() {
buffer.WriteString("\n\n") buffer.WriteString("\n\n")
for _, artc := range serie.Articles { for _, artc := range serie.Articles {
//eg. * [标题一](/post/hello-world.html) <span class="date">(Man 02, 2006)</span> //eg. * [标题一](/post/hello-world.html) <span class="date">(Man 02, 2006)</span>
buffer.WriteString("* [" + artc.Title + "](/post/" + artc.Slug + ".html) <span class=\"date\">(" + artc.CreateTime.Format("Jan 02, 2006") + ")</span>\n") buffer.WriteString("* [" + artc.Title + "](/post/" + artc.Slug +
".html) <span class=\"date\">(" + artc.CreateTime.Format("Jan 02, 2006") + ")</span>\n")
} }
buffer.WriteByte('\n') buffer.WriteByte('\n')
} }
@@ -191,15 +167,31 @@ func generateMarkdown() {
case ARCHIVE_MD: case ARCHIVE_MD:
sort.Sort(Ei.Archives) sort.Sort(Ei.Archives)
var buffer bytes.Buffer var buffer bytes.Buffer
buffer.WriteString(Ei.ArchivesSay) buffer.WriteString(Ei.ArchivesSay + "\n")
buffer.WriteString("\n\n")
var (
currentYear string
gt12Month = len(Ei.Archives) > 12
)
for _, archive := range Ei.Archives { for _, archive := range Ei.Archives {
buffer.WriteString(fmt.Sprintf("### %s", archive.Time.Format("2006年01月"))) if gt12Month {
buffer.WriteString("\n\n") year := archive.Time.Format("2006 年")
for _, artc := range archive.Articles { if currentYear != year {
buffer.WriteString("* [" + artc.Title + "](/post/" + artc.Slug + ".html) <span class=\"date\">(" + artc.CreateTime.Format("Jan 02, 2006") + ")</span>\n") currentYear = year
buffer.WriteString(fmt.Sprintf("\n### %s\n\n", archive.Time.Format("2006 年")))
}
} else {
buffer.WriteString(fmt.Sprintf("\n### %s\n\n", archive.Time.Format("2006年1月")))
}
for i, artc := range archive.Articles {
if i == 0 && gt12Month {
buffer.WriteString("* *[" + artc.Title + "](/post/" + artc.Slug +
".html) <span class=\"date\">(" + artc.CreateTime.Format("Jan 02, 2006") + ")</span>*\n")
} else {
buffer.WriteString("* [" + artc.Title + "](/post/" + artc.Slug +
".html) <span class=\"date\">(" + artc.CreateTime.Format("Jan 02, 2006") + ")</span>\n")
}
} }
buffer.WriteByte('\n')
} }
Ei.PageArchives = string(renderPage(buffer.Bytes())) Ei.PageArchives = string(renderPage(buffer.Bytes()))
} }
@@ -209,26 +201,29 @@ func generateMarkdown() {
// init account: generate blogroll and about page // init account: generate blogroll and about page
func generateTopic() { func generateTopic() {
about := &Article{ about := &Article{
ID: db.NextVal(DB, COUNTER_ARTICLE), ID: mgo.NextVal(DB, COUNTER_ARTICLE),
Author: setting.Conf.Account.Username, Author: setting.Conf.Account.Username,
Title: "关于", Title: "关于",
Slug: "about", Slug: "about",
CreateTime: time.Time{}, CreateTime: time.Time{},
UpdateTime: time.Time{}, UpdateTime: time.Time{},
} }
// 推送到 disqus
go func() { ThreadCreate(about) }()
blogroll := &Article{ blogroll := &Article{
ID: db.NextVal(DB, COUNTER_ARTICLE), ID: mgo.NextVal(DB, COUNTER_ARTICLE),
Author: setting.Conf.Account.Username, Author: setting.Conf.Account.Username,
Title: "友情链接", Title: "友情链接",
Slug: "blogroll", Slug: "blogroll",
CreateTime: time.Time{}, CreateTime: time.Time{},
UpdateTime: time.Time{}, UpdateTime: time.Time{},
} }
err := db.Insert(DB, COLLECTION_ARTICLE, blogroll) err := mgo.Insert(DB, COLLECTION_ARTICLE, blogroll)
if err != nil { if err != nil {
logd.Fatal(err) logd.Fatal(err)
} }
err = db.Insert(DB, COLLECTION_ARTICLE, about) err = mgo.Insert(DB, COLLECTION_ARTICLE, about)
if err != nil { if err != nil {
logd.Fatal(err) logd.Fatal(err)
} }
@@ -273,97 +268,6 @@ func PageList(p, n int) (prev int, next int, artcs []*Article) {
return return
} }
// 管理 tag
func ManageTagsArticle(artc *Article, s bool, do string) {
switch do {
case ADD:
for _, tag := range artc.Tags {
Ei.Tags[tag] = append(Ei.Tags[tag], artc)
if s {
sort.Sort(Ei.Tags[tag])
}
}
case DELETE:
for _, tag := range artc.Tags {
for i, v := range Ei.Tags[tag] {
if v == artc {
Ei.Tags[tag] = append(Ei.Tags[tag][0:i], Ei.Tags[tag][i+1:]...)
if len(Ei.Tags[tag]) == 0 {
delete(Ei.Tags, tag)
}
return
}
}
}
}
}
// 管理专题
func ManageSeriesArticle(artc *Article, s bool, do string) {
switch do {
case ADD:
for i, serie := range Ei.Series {
if serie.ID == artc.SerieID {
Ei.Series[i].Articles = append(Ei.Series[i].Articles, artc)
if s {
sort.Sort(Ei.Series[i].Articles)
Ei.CH <- SERIES_MD
return
}
}
}
case DELETE:
for i, serie := range Ei.Series {
if serie.ID == artc.SerieID {
for j, v := range serie.Articles {
if v == artc {
Ei.Series[i].Articles = append(Ei.Series[i].Articles[0:j], Ei.Series[i].Articles[j+1:]...)
Ei.CH <- SERIES_MD
return
}
}
}
}
}
}
// 管理归档
func ManageArchivesArticle(artc *Article, s bool, do string) {
switch do {
case ADD:
add := false
y, m, _ := artc.CreateTime.Date()
for i, archive := range Ei.Archives {
ay, am, _ := archive.Time.Date()
if y == ay && m == am {
add = true
Ei.Archives[i].Articles = append(Ei.Archives[i].Articles, artc)
if s {
sort.Sort(Ei.Archives[i].Articles)
Ei.CH <- ARCHIVE_MD
break
}
}
}
if !add {
Ei.Archives = append(Ei.Archives, &Archive{Time: artc.CreateTime, Articles: SortArticles{artc}})
}
case DELETE:
for i, archive := range Ei.Archives {
ay, am, _ := archive.Time.Date()
if y, m, _ := artc.CreateTime.Date(); ay == y && am == m {
for j, v := range archive.Articles {
if v == artc {
Ei.Archives[i].Articles = append(Ei.Archives[i].Articles[0:j], Ei.Archives[i].Articles[j+1:]...)
Ei.CH <- ARCHIVE_MD
return
}
}
}
}
}
}
// 渲染markdown操作和截取摘要操作 // 渲染markdown操作和截取摘要操作
var reg = regexp.MustCompile(setting.Conf.General.Identifier) var reg = regexp.MustCompile(setting.Conf.General.Identifier)
@@ -401,23 +305,128 @@ func GenerateExcerptAndRender(artc *Article) {
// 读取草稿箱 // 读取草稿箱
func LoadDraft() (artcs SortArticles, err error) { func LoadDraft() (artcs SortArticles, err error) {
err = db.FindAll(DB, COLLECTION_ARTICLE, bson.M{"isdraft": true}, &artcs) err = mgo.FindAll(DB, COLLECTION_ARTICLE, mgo.M{"isdraft": true}, &artcs)
sort.Sort(artcs) sort.Sort(artcs)
return return
} }
// 读取回收箱 // 读取回收箱
func LoadTrash() (artcs SortArticles, err error) { func LoadTrash() (artcs SortArticles, err error) {
err = db.FindAll(DB, COLLECTION_ARTICLE, bson.M{"deletetime": bson.M{"$ne": time.Time{}}}, &artcs) err = mgo.FindAll(DB, COLLECTION_ARTICLE, mgo.M{"deletetime": mgo.M{"$ne": time.Time{}}}, &artcs)
sort.Sort(artcs) sort.Sort(artcs)
return return
} }
// 添加文章到tag、serie、archive
func upArticle(artc *Article, needSort bool) {
// tag
for _, tag := range artc.Tags {
Ei.Tags[tag] = append(Ei.Tags[tag], artc)
if needSort {
sort.Sort(Ei.Tags[tag])
}
}
// serie
for i, serie := range Ei.Series {
if serie.ID == artc.SerieID {
Ei.Series[i].Articles = append(Ei.Series[i].Articles, artc)
if needSort {
sort.Sort(Ei.Series[i].Articles)
Ei.CH <- SERIES_MD
}
break
}
}
// archive
y, m, _ := artc.CreateTime.Date()
for i, archive := range Ei.Archives {
if ay, am, _ := archive.Time.Date(); y == ay && m == am {
Ei.Archives[i].Articles = append(Ei.Archives[i].Articles, artc)
if needSort {
sort.Sort(Ei.Archives[i].Articles)
Ei.CH <- ARCHIVE_MD
}
return
}
}
Ei.Archives = append(Ei.Archives, &Archive{Time: artc.CreateTime,
Articles: SortArticles{artc}})
if needSort {
Ei.CH <- ARCHIVE_MD
}
}
// 删除文章从tag、serie、archive
func dropArticle(artc *Article) {
// tag
for _, tag := range artc.Tags {
for i, v := range Ei.Tags[tag] {
if v == artc {
Ei.Tags[tag] = append(Ei.Tags[tag][0:i], Ei.Tags[tag][i+1:]...)
if len(Ei.Tags[tag]) == 0 {
delete(Ei.Tags, tag)
}
}
}
}
// serie
for i, serie := range Ei.Series {
if serie.ID == artc.SerieID {
for j, v := range serie.Articles {
if v == artc {
Ei.Series[i].Articles = append(Ei.Series[i].Articles[0:j],
Ei.Series[i].Articles[j+1:]...)
Ei.CH <- SERIES_MD
break
}
}
}
}
// archive
for i, archive := range Ei.Archives {
ay, am, _ := archive.Time.Date()
if y, m, _ := artc.CreateTime.Date(); ay == y && am == m {
for j, v := range archive.Articles {
if v == artc {
Ei.Archives[i].Articles = append(Ei.Archives[i].Articles[0:j],
Ei.Archives[i].Articles[j+1:]...)
if len(Ei.Archives[i].Articles) == 0 {
Ei.Archives = append(Ei.Archives[:i], Ei.Archives[i+1:]...)
}
Ei.CH <- ARCHIVE_MD
break
}
}
}
}
}
// 替换文章
func ReplaceArticle(oldArtc *Article, newArtc *Article) {
if oldArtc != nil {
i, artc := GetArticle(oldArtc.ID)
DelFromLinkedList(artc)
Ei.Articles = append(Ei.Articles[:i], Ei.Articles[i+1:]...)
delete(Ei.MapArticles, artc.Slug)
dropArticle(oldArtc)
}
Ei.MapArticles[newArtc.Slug] = newArtc
Ei.Articles = append(Ei.Articles, newArtc)
sort.Sort(Ei.Articles)
GenerateExcerptAndRender(newArtc)
AddToLinkedList(newArtc.ID)
upArticle(newArtc, true)
}
// 添加文章 // 添加文章
func AddArticle(artc *Article) error { func AddArticle(artc *Article) error {
// 分配ID, 占位至起始id // 分配ID, 占位至起始id
for { for {
if id := db.NextVal(DB, COUNTER_ARTICLE); id < setting.Conf.General.StartID { if id := mgo.NextVal(DB, COUNTER_ARTICLE); id < setting.Conf.General.StartID {
continue continue
} else { } else {
artc.ID = id artc.ID = id
@@ -425,22 +434,22 @@ func AddArticle(artc *Article) error {
} }
} }
err := mgo.Insert(DB, COLLECTION_ARTICLE, artc)
if err != nil {
return err
}
// 正式发布文章
if !artc.IsDraft { if !artc.IsDraft {
// 正式发布文章
defer GenerateExcerptAndRender(artc) defer GenerateExcerptAndRender(artc)
Ei.MapArticles[artc.Slug] = artc Ei.MapArticles[artc.Slug] = artc
Ei.Articles = append([]*Article{artc}, Ei.Articles...) Ei.Articles = append([]*Article{artc}, Ei.Articles...)
sort.Sort(Ei.Articles) sort.Sort(Ei.Articles)
AddToLinkedList(artc.ID) AddToLinkedList(artc.ID)
ManageTagsArticle(artc, true, ADD)
ManageSeriesArticle(artc, true, ADD) upArticle(artc, true)
ManageArchivesArticle(artc, true, ADD)
Ei.CH <- ARCHIVE_MD
if artc.SerieID > 0 {
Ei.CH <- SERIES_MD
}
} }
return db.Insert(DB, COLLECTION_ARTICLE, artc) return nil
} }
// 删除文章,移入回收箱 // 删除文章,移入回收箱
@@ -452,17 +461,13 @@ func DelArticles(ids ...int32) error {
DelFromLinkedList(artc) DelFromLinkedList(artc)
Ei.Articles = append(Ei.Articles[:i], Ei.Articles[i+1:]...) Ei.Articles = append(Ei.Articles[:i], Ei.Articles[i+1:]...)
delete(Ei.MapArticles, artc.Slug) delete(Ei.MapArticles, artc.Slug)
ManageTagsArticle(artc, false, DELETE)
ManageSeriesArticle(artc, false, DELETE) err := UpdateArticle(mgo.M{"id": id}, mgo.M{"$set": mgo.M{"deletetime": time.Now()}})
ManageArchivesArticle(artc, false, DELETE)
err := UpdateArticle(bson.M{"id": id}, bson.M{"$set": bson.M{"deletetime": time.Now()}})
if err != nil { if err != nil {
return err return err
} }
artc = nil dropArticle(artc)
} }
Ei.CH <- ARCHIVE_MD
Ei.CH <- SERIES_MD
return nil return nil
} }
@@ -509,34 +514,36 @@ func timer() {
delT := time.NewTicker(time.Duration(setting.Conf.General.Clean) * time.Hour) delT := time.NewTicker(time.Duration(setting.Conf.General.Clean) * time.Hour)
for { for {
<-delT.C <-delT.C
db.Remove(DB, COLLECTION_ARTICLE, bson.M{"deletetime": bson.M{"$gt": time.Time{}, "$lt": time.Now().Add(time.Duration(setting.Conf.General.Trash) * time.Hour)}}) mgo.Remove(DB, COLLECTION_ARTICLE, mgo.M{"deletetime": mgo.M{"$gt": time.Time{},
"$lt": time.Now().Add(time.Duration(setting.Conf.General.Trash) * time.Hour)}})
} }
} }
// 操作帐号字段 // 操作帐号字段
func UpdateAccountField(M bson.M) error { func UpdateAccountField(M mgo.M) error {
return db.Update(DB, COLLECTION_ACCOUNT, bson.M{"username": Ei.Username}, M) return mgo.Update(DB, COLLECTION_ACCOUNT, mgo.M{"username": Ei.Username}, M)
} }
// 删除草稿箱或回收箱,永久删除 // 删除草稿箱或回收箱,永久删除
func RemoveArticle(id int32) error { func RemoveArticle(id int32) error {
return db.Remove(DB, COLLECTION_ARTICLE, bson.M{"id": id}) return mgo.Remove(DB, COLLECTION_ARTICLE, mgo.M{"id": id})
} }
// 恢复删除文章到草稿箱 // 恢复删除文章到草稿箱
func RecoverArticle(id int32) error { func RecoverArticle(id int32) error {
return db.Update(DB, COLLECTION_ARTICLE, bson.M{"id": id}, bson.M{"$set": bson.M{"deletetime": time.Time{}, "isdraft": true}}) return mgo.Update(DB, COLLECTION_ARTICLE, mgo.M{"id": id},
mgo.M{"$set": mgo.M{"deletetime": time.Time{}, "isdraft": true}})
} }
// 更新文章 // 更新文章
func UpdateArticle(query, update interface{}) error { func UpdateArticle(query, update interface{}) error {
return db.Update(DB, COLLECTION_ARTICLE, query, update) return mgo.Update(DB, COLLECTION_ARTICLE, query, update)
} }
// 编辑文档 // 编辑文档
func QueryArticle(id int32) *Article { func QueryArticle(id int32) *Article {
artc := &Article{} artc := &Article{}
if err := db.FindOne(DB, COLLECTION_ARTICLE, bson.M{"id": id}, artc); err != nil { if err := mgo.FindOne(DB, COLLECTION_ARTICLE, mgo.M{"id": id}, artc); err != nil {
return nil return nil
} }
return artc return artc
@@ -544,17 +551,18 @@ func QueryArticle(id int32) *Article {
// 添加专题 // 添加专题
func AddSerie(name, slug, desc string) error { func AddSerie(name, slug, desc string) error {
serie := &Serie{db.NextVal(DB, COUNTER_SERIE), name, slug, desc, time.Now(), nil} serie := &Serie{mgo.NextVal(DB, COUNTER_SERIE), name, slug, desc, time.Now(), nil}
Ei.Series = append(Ei.Series, serie) Ei.Series = append(Ei.Series, serie)
sort.Sort(Ei.Series) sort.Sort(Ei.Series)
Ei.CH <- SERIES_MD Ei.CH <- SERIES_MD
return UpdateAccountField(bson.M{"$addToSet": bson.M{"blogger.series": serie}}) return UpdateAccountField(mgo.M{"$addToSet": mgo.M{"blogger.series": serie}})
} }
// 更新专题 // 更新专题
func UpdateSerie(serie *Serie) error { func UpdateSerie(serie *Serie) error {
Ei.CH <- SERIES_MD Ei.CH <- SERIES_MD
return db.Update(DB, COLLECTION_ACCOUNT, bson.M{"username": Ei.Username, "blogger.series.id": serie.ID}, bson.M{"$set": bson.M{"blogger.series.$": serie}}) return mgo.Update(DB, COLLECTION_ACCOUNT, mgo.M{"username": Ei.Username,
"blogger.series.id": serie.ID}, mgo.M{"$set": mgo.M{"blogger.series.$": serie}})
} }
// 删除专题 // 删除专题
@@ -564,7 +572,7 @@ func DelSerie(id int32) error {
if len(serie.Articles) > 0 { if len(serie.Articles) > 0 {
return fmt.Errorf("请删除该专题下的所有文章") return fmt.Errorf("请删除该专题下的所有文章")
} }
err := UpdateAccountField(bson.M{"$pull": bson.M{"blogger.series": bson.M{"id": id}}}) err := UpdateAccountField(mgo.M{"$pull": mgo.M{"blogger.series": mgo.M{"id": id}}})
if err != nil { if err != nil {
return err return err
} }
@@ -588,24 +596,24 @@ func QuerySerie(id int32) *Serie {
// 后台分页 // 后台分页
func PageListBack(se int, kw string, draft, del bool, p, n int) (max int, artcs []*Article) { func PageListBack(se int, kw string, draft, del bool, p, n int) (max int, artcs []*Article) {
M := bson.M{} M := mgo.M{}
if draft { if draft {
M["isdraft"] = true M["isdraft"] = true
} else if del { } else if del {
M["deletetime"] = bson.M{"$ne": time.Time{}} M["deletetime"] = mgo.M{"$ne": time.Time{}}
} else { } else {
M["isdraft"] = false M["isdraft"] = false
M["deletetime"] = bson.M{"$eq": time.Time{}} M["deletetime"] = mgo.M{"$eq": time.Time{}}
if se > 0 { if se > 0 {
M["serieid"] = se M["serieid"] = se
} }
if kw != "" { if kw != "" {
M["title"] = bson.M{"$regex": kw, "$options": "$i"} M["title"] = mgo.M{"$regex": kw, "$options": "$i"}
} }
} }
ms, c := db.Connect(DB, COLLECTION_ARTICLE) ms, c := mgo.Connect(DB, COLLECTION_ARTICLE)
defer ms.Close() defer ms.Close()
err := c.Find(M).Select(bson.M{"content": 0}).Sort("-createtime").Limit(n).Skip((p - 1) * n).All(&artcs) err := c.Find(M).Select(mgo.M{"content": 0}).Sort("-createtime").Limit(n).Skip((p - 1) * n).All(&artcs)
if err != nil { if err != nil {
logd.Error(err) logd.Error(err)
} }

147
disqus.go
View File

@@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
@@ -17,6 +18,12 @@ import (
var ErrDisqusConfig = errors.New("disqus config incorrect") var ErrDisqusConfig = errors.New("disqus config incorrect")
func correctDisqusConfig() bool {
return setting.Conf.Disqus.PostsCount != "" &&
setting.Conf.Disqus.PublicKey != "" &&
setting.Conf.Disqus.ShortName != ""
}
// 定时获取所有文章评论数量 // 定时获取所有文章评论数量
type postsCountResp struct { type postsCountResp struct {
Code int Code int
@@ -28,9 +35,7 @@ type postsCountResp struct {
} }
func PostsCount() error { func PostsCount() error {
if setting.Conf.Disqus.PostsCount == "" || if !correctDisqusConfig() {
setting.Conf.Disqus.PublicKey == "" ||
setting.Conf.Disqus.ShortName == "" {
return ErrDisqusConfig return ErrDisqusConfig
} }
@@ -41,20 +46,19 @@ func PostsCount() error {
} }
}) })
baseUrl := setting.Conf.Disqus.PostsCount + vals := url.Values{}
"?api_key=" + setting.Conf.Disqus.PublicKey + vals.Set("api_key", setting.Conf.Disqus.PublicKey)
"&forum=" + setting.Conf.Disqus.ShortName + "&" vals.Set("forum", setting.Conf.Disqus.ShortName)
var count, index int var count, index int
for index < len(Ei.Articles) { for index < len(Ei.Articles) {
var threads []string
for ; index < len(Ei.Articles) && count < 50; index++ { for ; index < len(Ei.Articles) && count < 50; index++ {
artc := Ei.Articles[index] artc := Ei.Articles[index]
threads = append(threads, fmt.Sprintf("thread:ident=post-%s", artc.Slug)) vals.Add("thread:ident", "post-"+artc.Slug)
count++ count++
} }
count = 0 count = 0
url := baseUrl + strings.Join(threads, "&") resp, err := http.Get(setting.Conf.Disqus.PostsCount + "?" + vals.Encode())
resp, err := http.Get(url)
if err != nil { if err != nil {
return err return err
} }
@@ -113,16 +117,18 @@ type postDetail struct {
} }
func PostsList(slug, cursor string) (*postsListResp, error) { func PostsList(slug, cursor string) (*postsListResp, error) {
if setting.Conf.Disqus.PostsList == "" || if !correctDisqusConfig() {
setting.Conf.Disqus.PublicKey == "" ||
setting.Conf.Disqus.ShortName == "" {
return nil, ErrDisqusConfig return nil, ErrDisqusConfig
} }
url := setting.Conf.Disqus.PostsList + "?limit=50&api_key=" + vals := url.Values{}
setting.Conf.Disqus.PublicKey + "&forum=" + setting.Conf.Disqus.ShortName + vals.Set("api_key", setting.Conf.Disqus.PublicKey)
"&cursor=" + cursor + "&thread:ident=post-" + slug vals.Set("forum", setting.Conf.Disqus.ShortName)
resp, err := http.Get(url) vals.Set("thread:ident", "post-"+slug)
vals.Set("cursor", cursor)
vals.Set("limit", "50")
resp, err := http.Get(setting.Conf.Disqus.PostsList + "?" + vals.Encode())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -144,15 +150,15 @@ func PostsList(slug, cursor string) (*postsListResp, error) {
return result, nil return result, nil
} }
type PostCreate struct { type PostComment struct {
Message string `json:"message"` Message string
Parent string `json:"parent"` Parent string
Thread string `json:"thread"` Thread string
AuthorEmail string `json:"author_email"` AuthorEmail string
AuthorName string `json:"autor_name"` AuthorName string
IpAddress string `json:"ip_address"` IpAddress string
Identifier string `json:"identifier"` Identifier string
UserAgent string `json:"user_agent"` UserAgent string
} }
type postCreateResp struct { type postCreateResp struct {
@@ -161,19 +167,21 @@ type postCreateResp struct {
} }
// 评论文章 // 评论文章
func PostComment(pc *PostCreate) (*postCreateResp, error) { func PostCreate(pc *PostComment) (*postCreateResp, error) {
if setting.Conf.Disqus.PostsList == "" || if !correctDisqusConfig() {
setting.Conf.Disqus.PublicKey == "" ||
setting.Conf.Disqus.ShortName == "" {
return nil, ErrDisqusConfig return nil, ErrDisqusConfig
} }
url := setting.Conf.Disqus.PostCreate +
"?api_key=E8Uh5l5fHZ6gD8U3KycjAIAk46f68Zw7C6eW8WSjZvCLXebZ7p0r1yrYDrLilk2F" +
"&message=" + pc.Message + "&parent=" + pc.Parent +
"&thread=" + pc.Thread + "&author_email=" + pc.AuthorEmail +
"&author_name=" + pc.AuthorName
request, err := http.NewRequest("POST", url, nil) vals := url.Values{}
vals.Set("api_key", "E8Uh5l5fHZ6gD8U3KycjAIAk46f68Zw7C6eW8WSjZvCLXebZ7p0r1yrYDrLilk2F")
vals.Set("message", pc.Message)
vals.Set("parent", pc.Parent)
vals.Set("thread", pc.Thread)
vals.Set("author_email", pc.AuthorEmail)
vals.Set("author_name", pc.AuthorName)
// vals.Set("state", "approved")
request, err := http.NewRequest("POST", setting.Conf.Disqus.PostCreate, strings.NewReader(vals.Encode()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -201,24 +209,23 @@ func PostComment(pc *PostCreate) (*postCreateResp, error) {
// 批准评论通过 // 批准评论通过
type approvedResp struct { type approvedResp struct {
Code int `json:"code"` Code int
Response []struct { Response []struct {
Id string `json:"id"` Id string
} `json:"response"` }
} }
func PostApprove(post string) error { func PostApprove(post string) error {
if setting.Conf.Disqus.PostsList == "" || if !correctDisqusConfig() {
setting.Conf.Disqus.PublicKey == "" ||
setting.Conf.Disqus.ShortName == "" {
return ErrDisqusConfig return ErrDisqusConfig
} }
url := setting.Conf.Disqus.PostApprove + vals := url.Values{}
"?api_key=" + setting.Conf.Disqus.PublicKey + vals.Set("api_key", setting.Conf.Disqus.PublicKey)
"&access_token=" + setting.Conf.Disqus.AccessToken + vals.Set("access_token", setting.Conf.Disqus.AccessToken)
"&post=" + post vals.Set("post", post)
request, err := http.NewRequest("POST", url, nil)
request, err := http.NewRequest("POST", setting.Conf.Disqus.PostApprove, strings.NewReader(vals.Encode()))
if err != nil { if err != nil {
return err return err
} }
@@ -246,3 +253,49 @@ func PostApprove(post string) error {
return nil return nil
} }
// 创建thread
type threadCreateResp struct {
Code int
Response struct {
Id string
}
}
func ThreadCreate(artc *Article) error {
if !correctDisqusConfig() {
return ErrDisqusConfig
}
vals := url.Values{}
vals.Set("api_key", setting.Conf.Disqus.PublicKey)
vals.Set("access_token", setting.Conf.Disqus.AccessToken)
vals.Set("forum", setting.Conf.Disqus.ShortName)
vals.Set("title", artc.Title+" | "+Ei.BTitle)
vals.Set("identifier", "post-"+artc.Slug)
urlPath := fmt.Sprintf("https://%s/post/%s.html", setting.Conf.Mode.Domain, artc.Slug)
vals.Set("url", urlPath)
resp, err := http.PostForm(setting.Conf.Disqus.ThreadCreate, vals)
if err != nil {
return err
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return errors.New(string(b))
}
result := &threadCreateResp{}
err = json.Unmarshal(b, result)
if err != nil {
return err
}
artc.Thread = result.Response.Id
return nil
}

View File

@@ -8,18 +8,29 @@ func TestDisqus(t *testing.T) {
PostsCount() PostsCount()
} }
func TestPostComment(t *testing.T) { func TestPostCreate(t *testing.T) {
pc := &PostCreate{ pc := &PostComment{
Message: "hahahaha", Message: "hahahaha",
Thread: "52799014", Thread: "52799014",
AuthorEmail: "deepzz.qi@gmail.com", AuthorEmail: "deepzz.qi@gmail.com",
AuthorName: "deepzz", AuthorName: "deepzz",
} }
id, err := PostComment(pc) id, err := PostCreate(pc)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
t.Log("post success", id) t.Log("post success", id)
} }
func TestThreadCreate(t *testing.T) {
tc := &Article{
Title: "测试test7",
Slug: "test7",
}
err := ThreadCreate(tc)
if err != nil {
t.Fatal(err)
}
}

View File

@@ -6,8 +6,6 @@ services:
volumes: volumes:
- /data/eiblog/mgodb:/data/db - /data/eiblog/mgodb:/data/db
restart: always restart: always
ports:
- 27017:27017
elasticsearch: elasticsearch:
image: elasticsearch:2.4.1 image: elasticsearch:2.4.1
container_name: eisearch container_name: eisearch
@@ -35,3 +33,14 @@ services:
ports: ports:
- "9000:9000" - "9000:9000"
restart: always restart: always
# backup:
# image: registry.cn-hangzhou.aliyuncs.com/deepzz/backup
# container_name: backup
# links:
# - mongodb
# environment:
# - QINIU_BUCKET=xxxx
# - QINIU_DOMAIN=xx.example.com
# - ACCESS_KEY=xxxxxxxxxx
# - SECRECT_KEY=xxxxxxxxxx
# restart: always

View File

@@ -18,3 +18,7 @@ twitter:
![twitter-card2](http://7xokm2.com1.z0.glb.clouddn.com/img/twitter-pub2.png) ![twitter-card2](http://7xokm2.com1.z0.glb.clouddn.com/img/twitter-pub2.png)
可以看到``之前是没有内容的,该内容是我们文章的描述。 可以看到``之前是没有内容的,该内容是我们文章的描述。
### Google OpenSearch
在 Chrome 浏览器上,你可以在输入网站后按 TAB 键进入搜索模式,如:
![opensearch](http://7xokm2.com1.z0.glb.clouddn.com/opensearch.gif)

View File

@@ -90,7 +90,7 @@ $ docker run -d --name eisearch \
| ------------------ | ---------------------------------------- | ---------------------------------------- | | ------------------ | ---------------------------------------- | ---------------------------------------- |
| favicon.ico | st.example.com/static/img/favicon.ico | cdn 中的文件名为 `static/img/favicon.ico`。你也可以复制 favicon.ico 到 static 文件夹下,通过 example.com/favicon.ico 也是能够访问到。docker 用户可能需要重新打包镜像。 | | favicon.ico | st.example.com/static/img/favicon.ico | cdn 中的文件名为 `static/img/favicon.ico`。你也可以复制 favicon.ico 到 static 文件夹下,通过 example.com/favicon.ico 也是能够访问到。docker 用户可能需要重新打包镜像。 |
| bg04.jpg | st.example.com/static/img/bg04.jpg | 首页左侧的大背景图,需要更名请到 views/st_blog.css 修改。 | | bg04.jpg | st.example.com/static/img/bg04.jpg | 首页左侧的大背景图,需要更名请到 views/st_blog.css 修改。 |
| avatar.jpg | st.example.com/static/img/avatar.jpg | 头像 | | avatar.png | st.example.com/static/img/avatar.png | 头像 |
| blank.gif | st.example.com/static/img/blank.gif | 空白图片,[下载](https://st.deepzz.com/static/img/blank.gif) | | blank.gif | st.example.com/static/img/blank.gif | 空白图片,[下载](https://st.deepzz.com/static/img/blank.gif) |
| default_avatar.png | st.example.com/static/img/default_avatar.png | disqus 默认图片,[下载](https://st.deepzz.com/static/img/default_avatar.png) | | default_avatar.png | st.example.com/static/img/default_avatar.png | disqus 默认图片,[下载](https://st.deepzz.com/static/img/default_avatar.png) |
| disqus.js | st.example.com/static/js/disqus_xxx.js | disqus 文件,你可以通过 https://short_name.disqus.com/embed.js 下载你的专属文件,并上传到七牛。更新配置文件 app.yml。 | | disqus.js | st.example.com/static/js/disqus_xxx.js | disqus 文件,你可以通过 https://short_name.disqus.com/embed.js 下载你的专属文件,并上传到七牛。更新配置文件 app.yml。 |

View File

@@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"regexp" "regexp"
"strings" "strings"
@@ -23,10 +24,20 @@ const (
ES_DATE = `{"range":{"date":{"gte":"%s","lte": "%s","format": "yyyy-MM-dd||yyyy-MM||yyyy"}}}` // 2016-10||/M ES_DATE = `{"range":{"date":{"gte":"%s","lte": "%s","format": "yyyy-MM-dd||yyyy-MM||yyyy"}}}` // 2016-10||/M
) )
var es *ElasticService var (
ErrUninitializedES = errors.New("uninitialized elasticsearch")
es *ElasticService
)
// 初始化 Elasticsearch 服务器 // 初始化 Elasticsearch 服务器
func init() { func init() {
_, err := net.LookupIP("elasticsearch")
if err != nil {
logd.Info(err)
return
}
es = &ElasticService{url: "http://elasticsearch:9200", c: new(http.Client)} es = &ElasticService{url: "http://elasticsearch:9200", c: new(http.Client)}
initIndex() initIndex()
} }
@@ -41,7 +52,11 @@ func initIndex() {
} }
// 查询 // 查询
func Elasticsearch(qStr string, size, from int) *ESSearchResult { func Elasticsearch(qStr string, size, from int) (*ESSearchResult, error) {
if es == nil {
return nil, ErrUninitializedES
}
// 分析查询字符串 // 分析查询字符串
reg := regexp.MustCompile(`(tag|slug|date):`) reg := regexp.MustCompile(`(tag|slug|date):`)
indexs := reg.FindAllStringIndex(qStr, -1) indexs := reg.FindAllStringIndex(qStr, -1)
@@ -92,14 +107,17 @@ func Elasticsearch(qStr string, size, from int) *ESSearchResult {
} }
docs, err := IndexQueryDSL(INDEX, TYPE, size, from, []byte(dsl)) docs, err := IndexQueryDSL(INDEX, TYPE, size, from, []byte(dsl))
if err != nil { if err != nil {
logd.Error(err) return nil, err
return nil
} }
return docs return docs, nil
} }
// 添加或更新索引 // 添加或更新索引
func ElasticIndex(artc *Article) error { func ElasticIndex(artc *Article) error {
if es == nil {
return ErrUninitializedES
}
img := PickFirstImage(artc.Content) img := PickFirstImage(artc.Content)
mapping := map[string]interface{}{ mapping := map[string]interface{}{
"title": artc.Title, "title": artc.Title,
@@ -115,6 +133,10 @@ func ElasticIndex(artc *Article) error {
// 删除索引 // 删除索引
func ElasticDelIndex(ids []int32) error { func ElasticDelIndex(ids []int32) error {
if es == nil {
return ErrUninitializedES
}
var target []string var target []string
for _, id := range ids { for _, id := range ids {
target = append(target, fmt.Sprint(id)) target = append(target, fmt.Sprint(id))

View File

@@ -199,10 +199,12 @@ func HandleSearchPage(c *gin.Context) {
start = 1 start = 1
} }
h["Word"] = q h["Word"] = q
var result *ESSearchResult
vals := c.Request.URL.Query() vals := c.Request.URL.Query()
result = Elasticsearch(q, setting.Conf.General.PageNum, start-1) result, err := Elasticsearch(q, setting.Conf.General.PageNum, start-1)
if result != nil { if err != nil {
logd.Error(err)
} else {
result.Took /= 1000 result.Took /= 1000
for i, v := range result.Hits.Hits { for i, v := range result.Hits.Hits {
if artc := Ei.MapArticles[result.Hits.Hits[i].Source.Slug]; len(v.Highlight.Content) == 0 && artc != nil { if artc := Ei.MapArticles[result.Hits.Hits[i].Source.Slug]; len(v.Highlight.Content) == 0 && artc != nil {
@@ -379,7 +381,8 @@ func HandleDisqus(c *gin.Context) {
} }
// 发表评论 // 发表评论
// [thread:[5279901489] parent:[] identifier:[post-troubleshooting-https] next:[] author_name:[你好] author_email:[chenqijing2@163.com] message:[fdsfdsf]] // [thread:[5279901489] parent:[] identifier:[post-troubleshooting-https]
// next:[] author_name:[你好] author_email:[chenqijing2@163.com] message:[fdsfdsf]]
type DisqusCreate struct { type DisqusCreate struct {
ErrNo int `json:"errno"` ErrNo int `json:"errno"`
ErrMsg string `json:"errmsg"` ErrMsg string `json:"errmsg"`
@@ -400,7 +403,7 @@ func HandleDisqusCreate(c *gin.Context) {
resp.ErrMsg = "参数错误" resp.ErrMsg = "参数错误"
return return
} }
pc := &PostCreate{ pc := &PostComment{
Message: msg, Message: msg,
Parent: c.PostForm("parent"), Parent: c.PostForm("parent"),
Thread: thread, Thread: thread,
@@ -410,7 +413,7 @@ func HandleDisqusCreate(c *gin.Context) {
IpAddress: c.ClientIP(), IpAddress: c.ClientIP(),
} }
postDetail, err := PostComment(pc) postDetail, err := PostCreate(pc)
if err != nil { if err != nil {
logd.Error(err) logd.Error(err)
resp.ErrNo = FAIL resp.ErrNo = FAIL

18
glide.lock generated
View File

@@ -1,28 +1,28 @@
hash: c733fa4abeda21b59b001578b37a168bd33038d337b61198cc5fd94be8bfdf77 hash: c733fa4abeda21b59b001578b37a168bd33038d337b61198cc5fd94be8bfdf77
updated: 2017-11-05T12:08:01.167405372+08:00 updated: 2018-01-13T18:22:28.620808+08:00
imports: imports:
- name: github.com/boj/redistore - name: github.com/boj/redistore
version: 4562487a4bee9a7c272b72bfaeda4917d0a47ab9 version: 4562487a4bee9a7c272b72bfaeda4917d0a47ab9
- name: github.com/deepzz0/logd - name: github.com/deepzz0/logd
version: 2bbe53d047054777f3a171cdfc6dca7aa9f8af78 version: f91dd8c6316f0e156e93895a96739b67577b6a63
- name: github.com/eiblog/blackfriday - name: github.com/eiblog/blackfriday
version: c0ec111761ae784fe31cc076f2fa0e2d2216d623 version: c0ec111761ae784fe31cc076f2fa0e2d2216d623
- name: github.com/eiblog/utils - name: github.com/eiblog/utils
version: ddfd888542f9a093000f71c3709009c1440a0789 version: e8f16268dae939f920ddc55f1c9e46a97a5e3559
subpackages: subpackages:
- logd - logd
- mgo - mgo
- tmpl - tmpl
- uuid - uuid
- name: github.com/garyburd/redigo - name: github.com/garyburd/redigo
version: 47dc60e71eed504e3ef8e77ee3c6fe720f3be57f version: d1ed5c67e5794de818ea85e6b522fda02623a484
subpackages: subpackages:
- internal - internal
- redis - redis
- name: github.com/gin-gonic/autotls - name: github.com/gin-gonic/autotls
version: 8ca25fbde72bb72a00466215b94b489c71fcb815 version: 8ca25fbde72bb72a00466215b94b489c71fcb815
- name: github.com/gin-gonic/contrib - name: github.com/gin-gonic/contrib
version: 5aa1e38d1d932e45fa5032bd1b8739e1a548e596 version: 88aede40372d4bcb11e45168a8c30d99e44cf617
subpackages: subpackages:
- sessions - sessions
- name: github.com/gin-gonic/gin - name: github.com/gin-gonic/gin
@@ -43,7 +43,7 @@ imports:
- name: github.com/manucorporat/sse - name: github.com/manucorporat/sse
version: ee05b128a739a0fb76c7ebd3ae4810c1de808d6d version: ee05b128a739a0fb76c7ebd3ae4810c1de808d6d
- name: github.com/mattn/go-isatty - name: github.com/mattn/go-isatty
version: a5cdd64afdee435007ee3e9f6ed4684af949d568 version: 6ca4dbf54d38eea1a992b3c722a76a5d1c4cb25c
- name: github.com/qiniu/api.v7 - name: github.com/qiniu/api.v7
version: b7c7d6a2ce0aff8e5e7d14c39c3cde867efa1123 version: b7c7d6a2ce0aff8e5e7d14c39c3cde867efa1123
subpackages: subpackages:
@@ -61,7 +61,7 @@ imports:
- name: github.com/shurcooL/sanitized_anchor_name - name: github.com/shurcooL/sanitized_anchor_name
version: 86672fcb3f950f35f2e675df2240550f2a50762f version: 86672fcb3f950f35f2e675df2240550f2a50762f
- name: golang.org/x/crypto - name: golang.org/x/crypto
version: bd6f299fb381e4c3393d1c4b1f0b94f5e77650c8 version: 13931e22f9e72ea58bb73048bc752b48c6d4d4ac
subpackages: subpackages:
- acme - acme
- acme/autocert - acme/autocert
@@ -70,7 +70,7 @@ imports:
subpackages: subpackages:
- context - context
- name: golang.org/x/sys - name: golang.org/x/sys
version: 8eb05f94d449fdf134ec24630ce69ada5b469c1c version: 810d7000345868fc619eb81f46307107118f4ae1
subpackages: subpackages:
- unix - unix
- name: gopkg.in/go-playground/validator.v8 - name: gopkg.in/go-playground/validator.v8
@@ -83,7 +83,7 @@ imports:
- internal/sasl - internal/sasl
- internal/scram - internal/scram
- name: gopkg.in/yaml.v2 - name: gopkg.in/yaml.v2
version: eb3733d160e74a9c7e442f435eb3bea458e1d19f version: d670f9405373e636a5a2765eea47fac0c9bc91a4
- name: qiniupkg.com/x - name: qiniupkg.com/x
version: 946c4a16076d6d98aeb78619e2bd4012357f7228 version: 946c4a16076d6d98aeb78619e2bd4012357f7228
subpackages: subpackages:

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"bytes" "bytes"
"encoding/xml" "encoding/xml"
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
@@ -64,7 +65,7 @@ func (p *pingRPC) PingFunc(slug string) {
if len(setting.Conf.PingRPCs) == 0 { if len(setting.Conf.PingRPCs) == 0 {
return return
} }
p.Params.Param[1].Value = "https://" + setting.Conf.Mode.Domain + "/post/" + slug + ".html" p.Params.Param[1].Value = fmt.Sprintf("https://%s/post/%s.html", setting.Conf.Mode.Domain, slug)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
buf.WriteString(xml.Header) buf.WriteString(xml.Header)
enc := xml.NewEncoder(buf) enc := xml.NewEncoder(buf)

View File

@@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/url"
"path/filepath" "path/filepath"
"github.com/eiblog/eiblog/setting" "github.com/eiblog/eiblog/setting"
@@ -12,18 +11,6 @@ import (
"github.com/qiniu/api.v7/storage" "github.com/qiniu/api.v7/storage"
) )
type bucket struct {
name string
domain string
accessKey string
secretKey string
}
type PutRet struct {
Hash string `json:"hash"`
Key string `json:"key"`
}
// 进度条 // 进度条
func onProgress(fsize, uploaded int64) { func onProgress(fsize, uploaded int64) {
d := int(float64(uploaded) / float64(fsize) * 100) d := int(float64(uploaded) / float64(fsize) * 100)
@@ -41,10 +28,6 @@ func FileUpload(name string, size int64, data io.Reader) (string, error) {
} }
key := getKey(name) key := getKey(name)
if key == "" {
return "", errors.New("不支持的文件类型")
}
mac := qbox.NewMac(setting.Conf.Qiniu.AccessKey, setting.Conf.Qiniu.SecretKey) mac := qbox.NewMac(setting.Conf.Qiniu.AccessKey, setting.Conf.Qiniu.SecretKey)
// 设置上传的策略 // 设置上传的策略
putPolicy := &storage.PutPolicy{ putPolicy := &storage.PutPolicy{
@@ -63,22 +46,20 @@ func FileUpload(name string, size int64, data io.Reader) (string, error) {
// uploader // uploader
uploader := storage.NewFormUploader(cfg) uploader := storage.NewFormUploader(cfg)
ret := new(storage.PutRet) ret := new(storage.PutRet)
putExtra := &storage.PutExtra{OnProgress: onProgress} putExtra := &storage.PutExtra{}
err := uploader.Put(nil, ret, upToken, key, data, size, putExtra) err := uploader.Put(nil, ret, upToken, key, data, size, putExtra)
if err != nil { if err != nil {
return "", err return "", err
} }
url := "https://" + setting.Conf.Qiniu.Domain + "/" + url.QueryEscape(key) url := "https://" + setting.Conf.Qiniu.Domain + "/" + key
return url, nil return url, nil
} }
// 删除文件 // 删除文件
func FileDelete(name string) error { func FileDelete(name string) error {
key := getKey(name) key := getKey(name)
if key == "" {
return errors.New("不支持的文件类型")
}
mac := qbox.NewMac(setting.Conf.Qiniu.AccessKey, setting.Conf.Qiniu.SecretKey) mac := qbox.NewMac(setting.Conf.Qiniu.AccessKey, setting.Conf.Qiniu.SecretKey)
// 上传配置 // 上传配置
@@ -105,9 +86,12 @@ func getKey(name string) string {
key = "blog/img/" + name key = "blog/img/" + name
case ".mov", ".mp4": case ".mov", ".mp4":
key = "blog/video/" + name key = "blog/video/" + name
case ".go", ".js", ".css", ".cpp", ".php", ".rb", ".java", ".py", ".sql", ".lua", ".html", ".sh", ".xml", ".cs": case ".go", ".js", ".css", ".cpp", ".php", ".rb",
".java", ".py", ".sql", ".lua", ".html",
".sh", ".xml", ".cs":
key = "blog/code/" + name key = "blog/code/" + name
case ".txt", ".md", ".ini", ".yaml", ".yml", ".doc", ".ppt", ".pdf": case ".txt", ".md", ".ini", ".yaml", ".yml",
".doc", ".ppt", ".pdf":
key = "blog/document/" + name key = "blog/document/" + name
case ".zip", ".rar", ".tar", ".gz": case ".zip", ".rar", ".tar", ".gz":
key = "blog/archive/" + name key = "blog/archive/" + name

View File

@@ -2,8 +2,9 @@
package main package main
import ( import (
"crypto/rand"
"fmt" "fmt"
"html/template" "text/template"
"time" "time"
"github.com/eiblog/eiblog/setting" "github.com/eiblog/eiblog/setting"
@@ -27,7 +28,12 @@ func init() {
} }
router = gin.Default() router = gin.Default()
store := sessions.NewCookieStore([]byte("eiblog321")) b := make([]byte, 16)
_, err := rand.Read(b)
if err != nil {
logd.Fatal(err)
}
store := sessions.NewCookieStore(b)
store.Options(sessions.Options{ store.Options(sessions.Options{
MaxAge: 86400 * 7, MaxAge: 86400 * 7,
Path: "/", Path: "/",
@@ -43,7 +49,7 @@ func init() {
} }
return false return false
}) })
_, err := Tmpl.ParseFiles(files...) _, err = Tmpl.ParseFiles(files...)
if err != nil { if err != nil {
logd.Fatal(err) logd.Fatal(err)
} }

View File

@@ -35,15 +35,16 @@ type Config struct {
Clean int // 清理回收箱频率 Clean int // 清理回收箱频率
} }
Disqus struct { // 获取文章数量相关 Disqus struct { // 获取文章数量相关
ShortName string ShortName string
PublicKey string PublicKey string
AccessToken string AccessToken string
PostsCount string PostsCount string
PostsList string PostsList string
PostCreate string PostCreate string
PostApprove string PostApprove string
Embed string ThreadCreate string
Interval int Embed string
Interval int
} }
Google struct { // 谷歌统计 Google struct { // 谷歌统计
URL string URL string

View File

@@ -67,10 +67,20 @@ type LogOption struct {
Mails Emailer // 告警邮件 Mails Emailer // 告警邮件
} }
func osSep() string {
var sep string
if os.IsPathSeparator('\\') {
sep = "\\"
} else {
sep = "/"
}
return sep
}
// 新建日志打印器 // 新建日志打印器
func New(option LogOption) *Logger { func New(option LogOption) *Logger {
wd, _ := os.Getwd() wd, _ := os.Getwd()
index := strings.LastIndex(wd, "/") index := strings.LastIndex(wd, osSep())
logger := &Logger{ logger := &Logger{
obj: wd[index+1:], obj: wd[index+1:],
out: option.Out, out: option.Out,

View File

@@ -29,6 +29,10 @@ import (
"time" "time"
) )
var (
_ ConnWithTimeout = (*conn)(nil)
)
// conn is the low-level implementation of Conn // conn is the low-level implementation of Conn
type conn struct { type conn struct {
// Shared // Shared
@@ -72,6 +76,7 @@ type DialOption struct {
type dialOptions struct { type dialOptions struct {
readTimeout time.Duration readTimeout time.Duration
writeTimeout time.Duration writeTimeout time.Duration
dialer *net.Dialer
dial func(network, addr string) (net.Conn, error) dial func(network, addr string) (net.Conn, error)
db int db int
password string password string
@@ -94,17 +99,27 @@ func DialWriteTimeout(d time.Duration) DialOption {
}} }}
} }
// DialConnectTimeout specifies the timeout for connecting to the Redis server. // DialConnectTimeout specifies the timeout for connecting to the Redis server when
// no DialNetDial option is specified.
func DialConnectTimeout(d time.Duration) DialOption { func DialConnectTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) { return DialOption{func(do *dialOptions) {
dialer := net.Dialer{Timeout: d} do.dialer.Timeout = d
do.dial = dialer.Dial }}
}
// DialKeepAlive specifies the keep-alive period for TCP connections to the Redis server
// when no DialNetDial option is specified.
// If zero, keep-alives are not enabled. If no DialKeepAlive option is specified then
// the default of 5 minutes is used to ensure that half-closed TCP sessions are detected.
func DialKeepAlive(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.dialer.KeepAlive = d
}} }}
} }
// DialNetDial specifies a custom dial function for creating TCP // DialNetDial specifies a custom dial function for creating TCP
// connections. If this option is left out, then net.Dial is // connections, otherwise a net.Dialer customized via the other options is used.
// used. DialNetDial overrides DialConnectTimeout. // DialNetDial overrides DialConnectTimeout and DialKeepAlive.
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption { func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
return DialOption{func(do *dialOptions) { return DialOption{func(do *dialOptions) {
do.dial = dial do.dial = dial
@@ -154,11 +169,16 @@ func DialUseTLS(useTLS bool) DialOption {
// address using the specified options. // address using the specified options.
func Dial(network, address string, options ...DialOption) (Conn, error) { func Dial(network, address string, options ...DialOption) (Conn, error) {
do := dialOptions{ do := dialOptions{
dial: net.Dial, dialer: &net.Dialer{
KeepAlive: time.Minute * 5,
},
} }
for _, option := range options { for _, option := range options {
option.f(&do) option.f(&do)
} }
if do.dial == nil {
do.dial = do.dialer.Dial
}
netConn, err := do.dial(network, address) netConn, err := do.dial(network, address)
if err != nil { if err != nil {
@@ -166,7 +186,12 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
} }
if do.useTLS { if do.useTLS {
tlsConfig := cloneTLSClientConfig(do.tlsConfig, do.skipVerify) var tlsConfig *tls.Config
if do.tlsConfig == nil {
tlsConfig = &tls.Config{InsecureSkipVerify: do.skipVerify}
} else {
tlsConfig = cloneTLSConfig(do.tlsConfig)
}
if tlsConfig.ServerName == "" { if tlsConfig.ServerName == "" {
host, _, err := net.SplitHostPort(address) host, _, err := net.SplitHostPort(address)
if err != nil { if err != nil {
@@ -555,10 +580,17 @@ func (c *conn) Flush() error {
return nil return nil
} }
func (c *conn) Receive() (reply interface{}, err error) { func (c *conn) Receive() (interface{}, error) {
if c.readTimeout != 0 { return c.ReceiveWithTimeout(c.readTimeout)
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) }
func (c *conn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
var deadline time.Time
if timeout != 0 {
deadline = time.Now().Add(timeout)
} }
c.conn.SetReadDeadline(deadline)
if reply, err = c.readReply(); err != nil { if reply, err = c.readReply(); err != nil {
return nil, c.fatal(err) return nil, c.fatal(err)
} }
@@ -581,6 +613,10 @@ func (c *conn) Receive() (reply interface{}, err error) {
} }
func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
return c.DoWithTimeout(c.readTimeout, cmd, args...)
}
func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
c.mu.Lock() c.mu.Lock()
pending := c.pending pending := c.pending
c.pending = 0 c.pending = 0
@@ -604,9 +640,11 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
return nil, c.fatal(err) return nil, c.fatal(err)
} }
if c.readTimeout != 0 { var deadline time.Time
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) if readTimeout != 0 {
deadline = time.Now().Add(readTimeout)
} }
c.conn.SetReadDeadline(deadline)
if cmd == "" { if cmd == "" {
reply := make([]interface{}, pending) reply := make([]interface{}, pending)

View File

@@ -34,14 +34,16 @@ import (
type testConn struct { type testConn struct {
io.Reader io.Reader
io.Writer io.Writer
readDeadline time.Time
writeDeadline time.Time
} }
func (*testConn) Close() error { return nil } func (*testConn) Close() error { return nil }
func (*testConn) LocalAddr() net.Addr { return nil } func (*testConn) LocalAddr() net.Addr { return nil }
func (*testConn) RemoteAddr() net.Addr { return nil } func (*testConn) RemoteAddr() net.Addr { return nil }
func (*testConn) SetDeadline(t time.Time) error { return nil } func (c *testConn) SetDeadline(t time.Time) error { c.readDeadline = t; c.writeDeadline = t; return nil }
func (*testConn) SetReadDeadline(t time.Time) error { return nil } func (c *testConn) SetReadDeadline(t time.Time) error { c.readDeadline = t; return nil }
func (*testConn) SetWriteDeadline(t time.Time) error { return nil } func (c *testConn) SetWriteDeadline(t time.Time) error { c.writeDeadline = t; return nil }
func dialTestConn(r string, w io.Writer) redis.DialOption { func dialTestConn(r string, w io.Writer) redis.DialOption {
return redis.DialNetDial(func(network, addr string) (net.Conn, error) { return redis.DialNetDial(func(network, addr string) (net.Conn, error) {
@@ -821,3 +823,45 @@ Bjqn3yoLHaoZVvbWOi0C2TCN4FjXjaLNZGifQPbIcaA=
clientTLSConfig.RootCAs = x509.NewCertPool() clientTLSConfig.RootCAs = x509.NewCertPool()
clientTLSConfig.RootCAs.AddCert(certificate) clientTLSConfig.RootCAs.AddCert(certificate)
} }
func TestWithTimeout(t *testing.T) {
for _, recv := range []bool{true, false} {
for _, defaultTimout := range []time.Duration{0, time.Minute} {
var buf bytes.Buffer
nc := &testConn{Reader: strings.NewReader("+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n+OK\r\n"), Writer: &buf}
c, _ := redis.Dial("", "", redis.DialReadTimeout(defaultTimout), redis.DialNetDial(func(network, addr string) (net.Conn, error) { return nc, nil }))
for i := 0; i < 4; i++ {
var minDeadline, maxDeadline time.Time
// Alternate between default and specified timeout.
if i%2 == 0 {
if defaultTimout != 0 {
minDeadline = time.Now().Add(defaultTimout)
}
if recv {
c.Receive()
} else {
c.Do("PING")
}
if defaultTimout != 0 {
maxDeadline = time.Now().Add(defaultTimout)
}
} else {
timeout := 10 * time.Minute
minDeadline = time.Now().Add(timeout)
if recv {
redis.ReceiveWithTimeout(c, timeout)
} else {
redis.DoWithTimeout(c, timeout, "PING")
}
maxDeadline = time.Now().Add(timeout)
}
// Expect set deadline in expected range.
if nc.readDeadline.Before(minDeadline) || nc.readDeadline.After(maxDeadline) {
t.Errorf("recv %v, %d: do deadline error: %v, %v, %v", recv, i, minDeadline, nc.readDeadline, maxDeadline)
}
}
}
}
}

View File

@@ -4,11 +4,7 @@ package redis
import "crypto/tls" import "crypto/tls"
// similar cloneTLSClientConfig in the stdlib, but also honor skipVerify for the nil case func cloneTLSConfig(cfg *tls.Config) *tls.Config {
func cloneTLSClientConfig(cfg *tls.Config, skipVerify bool) *tls.Config {
if cfg == nil {
return &tls.Config{InsecureSkipVerify: skipVerify}
}
return &tls.Config{ return &tls.Config{
Rand: cfg.Rand, Rand: cfg.Rand,
Time: cfg.Time, Time: cfg.Time,

View File

@@ -1,14 +1,10 @@
// +build go1.7 // +build go1.7,!go1.8
package redis package redis
import "crypto/tls" import "crypto/tls"
// similar cloneTLSClientConfig in the stdlib, but also honor skipVerify for the nil case func cloneTLSConfig(cfg *tls.Config) *tls.Config {
func cloneTLSClientConfig(cfg *tls.Config, skipVerify bool) *tls.Config {
if cfg == nil {
return &tls.Config{InsecureSkipVerify: skipVerify}
}
return &tls.Config{ return &tls.Config{
Rand: cfg.Rand, Rand: cfg.Rand,
Time: cfg.Time, Time: cfg.Time,

9
vendor/github.com/garyburd/redigo/redis/go18.go generated vendored Normal file
View File

@@ -0,0 +1,9 @@
// +build go1.8
package redis
import "crypto/tls"
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
return cfg.Clone()
}

View File

@@ -18,6 +18,11 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"log" "log"
"time"
)
var (
_ ConnWithTimeout = (*loggingConn)(nil)
) )
// NewLoggingConn returns a logging wrapper around a connection. // NewLoggingConn returns a logging wrapper around a connection.
@@ -104,6 +109,12 @@ func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{},
return reply, err return reply, err
} }
func (c *loggingConn) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (interface{}, error) {
reply, err := DoWithTimeout(c.Conn, timeout, commandName, args...)
c.print("DoWithTimeout", commandName, args, reply, err)
return reply, err
}
func (c *loggingConn) Send(commandName string, args ...interface{}) error { func (c *loggingConn) Send(commandName string, args ...interface{}) error {
err := c.Conn.Send(commandName, args...) err := c.Conn.Send(commandName, args...)
c.print("Send", commandName, args, nil, err) c.print("Send", commandName, args, nil, err)
@@ -115,3 +126,9 @@ func (c *loggingConn) Receive() (interface{}, error) {
c.print("Receive", "", nil, reply, err) c.print("Receive", "", nil, reply, err)
return reply, err return reply, err
} }
func (c *loggingConn) ReceiveWithTimeout(timeout time.Duration) (interface{}, error) {
reply, err := ReceiveWithTimeout(c.Conn, timeout)
c.print("ReceiveWithTimeout", "", nil, reply, err)
return reply, err
}

View File

@@ -28,6 +28,11 @@ import (
"github.com/garyburd/redigo/internal" "github.com/garyburd/redigo/internal"
) )
var (
_ ConnWithTimeout = (*pooledConnection)(nil)
_ ConnWithTimeout = (*errorConnection)(nil)
)
var nowFunc = time.Now // for testing var nowFunc = time.Now // for testing
// ErrPoolExhausted is returned from a pool connection method (Do, Send, // ErrPoolExhausted is returned from a pool connection method (Do, Send,
@@ -96,7 +101,7 @@ var (
// return nil, err // return nil, err
// } // }
// return c, nil // return c, nil
// } // },
// } // }
// //
// Use the TestOnBorrow function to check the health of an idle connection // Use the TestOnBorrow function to check the health of an idle connection
@@ -418,6 +423,16 @@ func (pc *pooledConnection) Do(commandName string, args ...interface{}) (reply i
return pc.c.Do(commandName, args...) return pc.c.Do(commandName, args...)
} }
func (pc *pooledConnection) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error) {
cwt, ok := pc.c.(ConnWithTimeout)
if !ok {
return nil, errTimeoutNotSupported
}
ci := internal.LookupCommandInfo(commandName)
pc.state = (pc.state | ci.Set) &^ ci.Clear
return cwt.DoWithTimeout(timeout, commandName, args...)
}
func (pc *pooledConnection) Send(commandName string, args ...interface{}) error { func (pc *pooledConnection) Send(commandName string, args ...interface{}) error {
ci := internal.LookupCommandInfo(commandName) ci := internal.LookupCommandInfo(commandName)
pc.state = (pc.state | ci.Set) &^ ci.Clear pc.state = (pc.state | ci.Set) &^ ci.Clear
@@ -432,11 +447,23 @@ func (pc *pooledConnection) Receive() (reply interface{}, err error) {
return pc.c.Receive() return pc.c.Receive()
} }
func (pc *pooledConnection) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
cwt, ok := pc.c.(ConnWithTimeout)
if !ok {
return nil, errTimeoutNotSupported
}
return cwt.ReceiveWithTimeout(timeout)
}
type errorConnection struct{ err error } type errorConnection struct{ err error }
func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err } func (ec errorConnection) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err }
func (ec errorConnection) Send(string, ...interface{}) error { return ec.err } func (ec errorConnection) DoWithTimeout(time.Duration, string, ...interface{}) (interface{}, error) {
func (ec errorConnection) Err() error { return ec.err } return nil, ec.err
func (ec errorConnection) Close() error { return ec.err } }
func (ec errorConnection) Flush() error { return ec.err } func (ec errorConnection) Send(string, ...interface{}) error { return ec.err }
func (ec errorConnection) Receive() (interface{}, error) { return nil, ec.err } func (ec errorConnection) Err() error { return ec.err }
func (ec errorConnection) Close() error { return nil }
func (ec errorConnection) Flush() error { return ec.err }
func (ec errorConnection) Receive() (interface{}, error) { return nil, ec.err }
func (ec errorConnection) ReceiveWithTimeout(time.Duration) (interface{}, error) { return nil, ec.err }

View File

@@ -14,7 +14,10 @@
package redis package redis
import "errors" import (
"errors"
"time"
)
// Subscription represents a subscribe or unsubscribe notification. // Subscription represents a subscribe or unsubscribe notification.
type Subscription struct { type Subscription struct {
@@ -103,7 +106,17 @@ func (c PubSubConn) Ping(data string) error {
// or error. The return value is intended to be used directly in a type switch // or error. The return value is intended to be used directly in a type switch
// as illustrated in the PubSubConn example. // as illustrated in the PubSubConn example.
func (c PubSubConn) Receive() interface{} { func (c PubSubConn) Receive() interface{} {
reply, err := Values(c.Conn.Receive()) return c.receiveInternal(c.Conn.Receive())
}
// ReceiveWithTimeout is like Receive, but it allows the application to
// override the connection's default timeout.
func (c PubSubConn) ReceiveWithTimeout(timeout time.Duration) interface{} {
return c.receiveInternal(ReceiveWithTimeout(c.Conn, timeout))
}
func (c PubSubConn) receiveInternal(replyArg interface{}, errArg error) interface{} {
reply, err := Values(replyArg, errArg)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -0,0 +1,165 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// +build go1.7
package redis_test
import (
"context"
"fmt"
"time"
"github.com/garyburd/redigo/redis"
)
// listenPubSubChannels listens for messages on Redis pubsub channels. The
// onStart function is called after the channels are subscribed. The onMessage
// function is called for each message.
func listenPubSubChannels(ctx context.Context, redisServerAddr string,
onStart func() error,
onMessage func(channel string, data []byte) error,
channels ...string) error {
// A ping is set to the server with this period to test for the health of
// the connection and server.
const healthCheckPeriod = time.Minute
c, err := redis.Dial("tcp", redisServerAddr,
// Read timeout on server should be greater than ping period.
redis.DialReadTimeout(healthCheckPeriod+10*time.Second),
redis.DialWriteTimeout(10*time.Second))
if err != nil {
return err
}
defer c.Close()
psc := redis.PubSubConn{Conn: c}
if err := psc.Subscribe(redis.Args{}.AddFlat(channels)...); err != nil {
return err
}
done := make(chan error, 1)
// Start a goroutine to receive notifications from the server.
go func() {
for {
switch n := psc.Receive().(type) {
case error:
done <- n
return
case redis.Message:
if err := onMessage(n.Channel, n.Data); err != nil {
done <- err
return
}
case redis.Subscription:
switch n.Count {
case len(channels):
// Notify application when all channels are subscribed.
if err := onStart(); err != nil {
done <- err
return
}
case 0:
// Return from the goroutine when all channels are unsubscribed.
done <- nil
return
}
}
}
}()
ticker := time.NewTicker(healthCheckPeriod)
defer ticker.Stop()
loop:
for err == nil {
select {
case <-ticker.C:
// Send ping to test health of connection and server. If
// corresponding pong is not received, then receive on the
// connection will timeout and the receive goroutine will exit.
if err = psc.Ping(""); err != nil {
break loop
}
case <-ctx.Done():
break loop
case err := <-done:
// Return error from the receive goroutine.
return err
}
}
// Signal the receiving goroutine to exit by unsubscribing from all channels.
psc.Unsubscribe()
// Wait for goroutine to complete.
return <-done
}
func publish() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("PUBLISH", "c1", "hello")
c.Do("PUBLISH", "c2", "world")
c.Do("PUBLISH", "c1", "goodbye")
}
// This example shows how receive pubsub notifications with cancelation and
// health checks.
func ExamplePubSubConn() {
redisServerAddr, err := serverAddr()
if err != nil {
fmt.Println(err)
return
}
ctx, cancel := context.WithCancel(context.Background())
err = listenPubSubChannels(ctx,
redisServerAddr,
func() error {
// The start callback is a good place to backfill missed
// notifications. For the purpose of this example, a goroutine is
// started to send notifications.
go publish()
return nil
},
func(channel string, message []byte) error {
fmt.Printf("channel: %s, message: %s\n", channel, message)
// For the purpose of this example, cancel the listener's context
// after receiving last message sent by publish().
if string(message) == "goodbye" {
cancel()
}
return nil
},
"c1", "c2")
if err != nil {
fmt.Println(err)
return
}
// Output:
// channel: c1, message: hello
// channel: c2, message: world
// channel: c1, message: goodbye
}

View File

@@ -15,93 +15,13 @@
package redis_test package redis_test
import ( import (
"fmt"
"reflect" "reflect"
"sync"
"testing" "testing"
"time"
"github.com/garyburd/redigo/redis" "github.com/garyburd/redigo/redis"
) )
func publish(channel, value interface{}) {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("PUBLISH", channel, value)
}
// Applications can receive pushed messages from one goroutine and manage subscriptions from another goroutine.
func ExamplePubSubConn() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
var wg sync.WaitGroup
wg.Add(2)
psc := redis.PubSubConn{Conn: c}
// This goroutine receives and prints pushed notifications from the server.
// The goroutine exits when the connection is unsubscribed from all
// channels or there is an error.
go func() {
defer wg.Done()
for {
switch n := psc.Receive().(type) {
case redis.Message:
fmt.Printf("Message: %s %s\n", n.Channel, n.Data)
case redis.PMessage:
fmt.Printf("PMessage: %s %s %s\n", n.Pattern, n.Channel, n.Data)
case redis.Subscription:
fmt.Printf("Subscription: %s %s %d\n", n.Kind, n.Channel, n.Count)
if n.Count == 0 {
return
}
case error:
fmt.Printf("error: %v\n", n)
return
}
}
}()
// This goroutine manages subscriptions for the connection.
go func() {
defer wg.Done()
psc.Subscribe("example")
psc.PSubscribe("p*")
// The following function calls publish a message using another
// connection to the Redis server.
publish("example", "hello")
publish("example", "world")
publish("pexample", "foo")
publish("pexample", "bar")
// Unsubscribe from all connections. This will cause the receiving
// goroutine to exit.
psc.Unsubscribe()
psc.PUnsubscribe()
}()
wg.Wait()
// Output:
// Subscription: subscribe example 1
// Subscription: psubscribe p* 2
// Message: example hello
// Message: example world
// PMessage: p* pexample foo
// PMessage: p* pexample bar
// Subscription: unsubscribe example 1
// Subscription: punsubscribe p* 0
}
func expectPushed(t *testing.T, c redis.PubSubConn, message string, expected interface{}) { func expectPushed(t *testing.T, c redis.PubSubConn, message string, expected interface{}) {
actual := c.Receive() actual := c.Receive()
if !reflect.DeepEqual(actual, expected) { if !reflect.DeepEqual(actual, expected) {
@@ -145,4 +65,10 @@ func TestPushed(t *testing.T) {
c.Conn.Send("PING") c.Conn.Send("PING")
c.Conn.Flush() c.Conn.Flush()
expectPushed(t, c, `Send("PING")`, redis.Pong{}) expectPushed(t, c, `Send("PING")`, redis.Pong{})
c.Ping("timeout")
got := c.ReceiveWithTimeout(time.Minute)
if want := (redis.Pong{Data: "timeout"}); want != got {
t.Errorf("recv /w timeout got %v, want %v", got, want)
}
} }

View File

@@ -14,6 +14,11 @@
package redis package redis
import (
"errors"
"time"
)
// Error represents an error returned in a command reply. // Error represents an error returned in a command reply.
type Error string type Error string
@@ -59,3 +64,54 @@ type Scanner interface {
// loss of information. // loss of information.
RedisScan(src interface{}) error RedisScan(src interface{}) error
} }
// ConnWithTimeout is an optional interface that allows the caller to override
// a connection's default read timeout. This interface is useful for executing
// the BLPOP, BRPOP, BRPOPLPUSH, XREAD and other commands that block at the
// server.
//
// A connection's default read timeout is set with the DialReadTimeout dial
// option. Applications should rely on the default timeout for commands that do
// not block at the server.
//
// All of the Conn implementations in this package satisfy the ConnWithTimeout
// interface.
//
// Use the DoWithTimeout and ReceiveWithTimeout helper functions to simplify
// use of this interface.
type ConnWithTimeout interface {
Conn
// Do sends a command to the server and returns the received reply.
// The timeout overrides the read timeout set when dialing the
// connection.
DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error)
// Receive receives a single reply from the Redis server. The timeout
// overrides the read timeout set when dialing the connection.
ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error)
}
var errTimeoutNotSupported = errors.New("redis: connection does not support ConnWithTimeout")
// DoWithTimeout executes a Redis command with the specified read timeout. If
// the connection does not satisfy the ConnWithTimeout interface, then an error
// is returned.
func DoWithTimeout(c Conn, timeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
cwt, ok := c.(ConnWithTimeout)
if !ok {
return nil, errTimeoutNotSupported
}
return cwt.DoWithTimeout(timeout, cmd, args...)
}
// ReceiveWithTimeout receives a reply with the specified read timeout. If the
// connection does not satisfy the ConnWithTimeout interface, then an error is
// returned.
func ReceiveWithTimeout(c Conn, timeout time.Duration) (interface{}, error) {
cwt, ok := c.(ConnWithTimeout)
if !ok {
return nil, errTimeoutNotSupported
}
return cwt.ReceiveWithTimeout(timeout)
}

71
vendor/github.com/garyburd/redigo/redis/redis_test.go generated vendored Normal file
View File

@@ -0,0 +1,71 @@
// Copyright 2017 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package redis_test
import (
"testing"
"time"
"github.com/garyburd/redigo/redis"
)
type timeoutTestConn int
func (tc timeoutTestConn) Do(string, ...interface{}) (interface{}, error) {
return time.Duration(-1), nil
}
func (tc timeoutTestConn) DoWithTimeout(timeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
return timeout, nil
}
func (tc timeoutTestConn) Receive() (interface{}, error) {
return time.Duration(-1), nil
}
func (tc timeoutTestConn) ReceiveWithTimeout(timeout time.Duration) (interface{}, error) {
return timeout, nil
}
func (tc timeoutTestConn) Send(string, ...interface{}) error { return nil }
func (tc timeoutTestConn) Err() error { return nil }
func (tc timeoutTestConn) Close() error { return nil }
func (tc timeoutTestConn) Flush() error { return nil }
func testTimeout(t *testing.T, c redis.Conn) {
r, err := c.Do("PING")
if r != time.Duration(-1) || err != nil {
t.Errorf("Do() = %v, %v, want %v, %v", r, err, time.Duration(-1), nil)
}
r, err = redis.DoWithTimeout(c, time.Minute, "PING")
if r != time.Minute || err != nil {
t.Errorf("DoWithTimeout() = %v, %v, want %v, %v", r, err, time.Minute, nil)
}
r, err = c.Receive()
if r != time.Duration(-1) || err != nil {
t.Errorf("Receive() = %v, %v, want %v, %v", r, err, time.Duration(-1), nil)
}
r, err = redis.ReceiveWithTimeout(c, time.Minute)
if r != time.Minute || err != nil {
t.Errorf("ReceiveWithTimeout() = %v, %v, want %v, %v", r, err, time.Minute, nil)
}
}
func TestConnTimeout(t *testing.T) {
testTimeout(t, timeoutTestConn(0))
}
func TestPoolConnTimeout(t *testing.T) {
p := &redis.Pool{Dial: func() (redis.Conn, error) { return timeoutTestConn(0), nil }}
testTimeout(t, p.Get())
}

View File

@@ -140,6 +140,11 @@ func dial() (redis.Conn, error) {
return redis.DialDefaultServer() return redis.DialDefaultServer()
} }
// serverAddr wraps DefaultServerAddr() with a more suitable function name for examples.
func serverAddr() (string, error) {
return redis.DefaultServerAddr()
}
func ExampleBool() { func ExampleBool() {
c, err := dial() c, err := dial()
if err != nil { if err != nil {

View File

@@ -38,6 +38,7 @@ var (
ErrNegativeInt = errNegativeInt ErrNegativeInt = errNegativeInt
serverPath = flag.String("redis-server", "redis-server", "Path to redis server binary") serverPath = flag.String("redis-server", "redis-server", "Path to redis server binary")
serverAddress = flag.String("redis-address", "127.0.0.1", "The address of the server")
serverBasePort = flag.Int("redis-port", 16379, "Beginning of port range for test servers") serverBasePort = flag.Int("redis-port", 16379, "Beginning of port range for test servers")
serverLogName = flag.String("redis-log", "", "Write Redis server logs to `filename`") serverLogName = flag.String("redis-log", "", "Write Redis server logs to `filename`")
serverLog = ioutil.Discard serverLog = ioutil.Discard
@@ -126,28 +127,32 @@ func stopDefaultServer() {
} }
} }
// startDefaultServer starts the default server if not already running. // DefaultServerAddr starts the test server if not already started and returns
func startDefaultServer() error { // the address of that server.
func DefaultServerAddr() (string, error) {
defaultServerMu.Lock() defaultServerMu.Lock()
defer defaultServerMu.Unlock() defer defaultServerMu.Unlock()
addr := fmt.Sprintf("%v:%d", *serverAddress, *serverBasePort)
if defaultServer != nil || defaultServerErr != nil { if defaultServer != nil || defaultServerErr != nil {
return defaultServerErr return addr, defaultServerErr
} }
defaultServer, defaultServerErr = NewServer( defaultServer, defaultServerErr = NewServer(
"default", "default",
"--port", strconv.Itoa(*serverBasePort), "--port", strconv.Itoa(*serverBasePort),
"--bind", *serverAddress,
"--save", "", "--save", "",
"--appendonly", "no") "--appendonly", "no")
return defaultServerErr return addr, defaultServerErr
} }
// DialDefaultServer starts the test server if not already started and dials a // DialDefaultServer starts the test server if not already started and dials a
// connection to the server. // connection to the server.
func DialDefaultServer() (Conn, error) { func DialDefaultServer() (Conn, error) {
if err := startDefaultServer(); err != nil { addr, err := DefaultServerAddr()
if err != nil {
return nil, err return nil, err
} }
c, err := Dial("tcp", fmt.Sprintf(":%d", *serverBasePort), DialReadTimeout(1*time.Second), DialWriteTimeout(1*time.Second)) c, err := Dial("tcp", addr, DialReadTimeout(1*time.Second), DialWriteTimeout(1*time.Second))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -34,3 +34,5 @@ Each author is responsible of maintaining his own code, although if you submit a
+ [gin-oauth2](https://github.com/zalando/gin-oauth2) - for working with OAuth2 + [gin-oauth2](https://github.com/zalando/gin-oauth2) - for working with OAuth2
+ [static](https://github.com/hyperboloide/static) An alternative static assets handler for the gin framework. + [static](https://github.com/hyperboloide/static) An alternative static assets handler for the gin framework.
+ [xss-mw](https://github.com/dvwright/xss-mw) - XssMw is a middleware designed to "auto remove XSS" from user submitted input + [xss-mw](https://github.com/dvwright/xss-mw) - XssMw is a middleware designed to "auto remove XSS" from user submitted input
+ [gin-helmet](https://github.com/danielkov/gin-helmet) - Collection of simple security middleware.
+ [gin-jwt-session](https://github.com/ScottHuangZL/gin-jwt-session) - middleware to provide JWT/Session/Flashes, easy to use while also provide options for adjust if necessary. Provide sample too.

View File

@@ -10,6 +10,10 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
type loggerEntryWithFields interface {
WithFields(fields logrus.Fields) *logrus.Entry
}
// Ginrus returns a gin.HandlerFunc (middleware) that logs requests using logrus. // Ginrus returns a gin.HandlerFunc (middleware) that logs requests using logrus.
// //
// Requests with errors are logged using logrus.Error(). // Requests with errors are logged using logrus.Error().
@@ -18,7 +22,7 @@ import (
// It receives: // It receives:
// 1. A time package format string (e.g. time.RFC3339). // 1. A time package format string (e.g. time.RFC3339).
// 2. A boolean stating whether to use UTC time zone or local. // 2. A boolean stating whether to use UTC time zone or local.
func Ginrus(logger *logrus.Logger, timeFormat string, utc bool) gin.HandlerFunc { func Ginrus(logger loggerEntryWithFields, timeFormat string, utc bool) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
start := time.Now() start := time.Now()
// some evil middlewares modify this values // some evil middlewares modify this values

View File

@@ -2,6 +2,10 @@ language: go
go: go:
- tip - tip
os:
- linux
- osx
before_install: before_install:
- go get github.com/mattn/goveralls - go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover - go get golang.org/x/tools/cmd/cover

View File

@@ -3,7 +3,7 @@
package isatty package isatty
// IsCygwinTerminal() return true if the file descriptor is a cygwin or msys2 // IsCygwinTerminal return true if the file descriptor is a cygwin or msys2
// terminal. This is also always false on this environment. // terminal. This is also always false on this environment.
func IsCygwinTerminal(fd uintptr) bool { func IsCygwinTerminal(fd uintptr) bool {
return false return false

View File

@@ -946,7 +946,7 @@ func TestNonce_add(t *testing.T) {
c.addNonce(http.Header{"Replay-Nonce": {}}) c.addNonce(http.Header{"Replay-Nonce": {}})
c.addNonce(http.Header{"Replay-Nonce": {"nonce"}}) c.addNonce(http.Header{"Replay-Nonce": {"nonce"}})
nonces := map[string]struct{}{"nonce": struct{}{}} nonces := map[string]struct{}{"nonce": {}}
if !reflect.DeepEqual(c.nonces, nonces) { if !reflect.DeepEqual(c.nonces, nonces) {
t.Errorf("c.nonces = %q; want %q", c.nonces, nonces) t.Errorf("c.nonces = %q; want %q", c.nonces, nonces)
} }

View File

@@ -24,7 +24,9 @@ import (
"fmt" "fmt"
"io" "io"
mathrand "math/rand" mathrand "math/rand"
"net"
"net/http" "net/http"
"path"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -80,8 +82,9 @@ func defaultHostPolicy(context.Context, string) error {
} }
// Manager is a stateful certificate manager built on top of acme.Client. // Manager is a stateful certificate manager built on top of acme.Client.
// It obtains and refreshes certificates automatically, // It obtains and refreshes certificates automatically using "tls-sni-01",
// as well as providing them to a TLS server via tls.Config. // "tls-sni-02" and "http-01" challenge types, as well as providing them
// to a TLS server via tls.Config.
// //
// You must specify a cache implementation, such as DirCache, // You must specify a cache implementation, such as DirCache,
// to reuse obtained certificates across program restarts. // to reuse obtained certificates across program restarts.
@@ -150,15 +153,26 @@ type Manager struct {
stateMu sync.Mutex stateMu sync.Mutex
state map[string]*certState // keyed by domain name state map[string]*certState // keyed by domain name
// tokenCert is keyed by token domain name, which matches server name
// of ClientHello. Keys always have ".acme.invalid" suffix.
tokenCertMu sync.RWMutex
tokenCert map[string]*tls.Certificate
// renewal tracks the set of domains currently running renewal timers. // renewal tracks the set of domains currently running renewal timers.
// It is keyed by domain name. // It is keyed by domain name.
renewalMu sync.Mutex renewalMu sync.Mutex
renewal map[string]*domainRenewal renewal map[string]*domainRenewal
// tokensMu guards the rest of the fields: tryHTTP01, certTokens and httpTokens.
tokensMu sync.RWMutex
// tryHTTP01 indicates whether the Manager should try "http-01" challenge type
// during the authorization flow.
tryHTTP01 bool
// httpTokens contains response body values for http-01 challenges
// and is keyed by the URL path at which a challenge response is expected
// to be provisioned.
// The entries are stored for the duration of the authorization flow.
httpTokens map[string][]byte
// certTokens contains temporary certificates for tls-sni challenges
// and is keyed by token domain name, which matches server name of ClientHello.
// Keys always have ".acme.invalid" suffix.
// The entries are stored for the duration of the authorization flow.
certTokens map[string]*tls.Certificate
} }
// GetCertificate implements the tls.Config.GetCertificate hook. // GetCertificate implements the tls.Config.GetCertificate hook.
@@ -185,14 +199,16 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
return nil, errors.New("acme/autocert: server name contains invalid character") return nil, errors.New("acme/autocert: server name contains invalid character")
} }
// In the worst-case scenario, the timeout needs to account for caching, host policy,
// domain ownership verification and certificate issuance.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel() defer cancel()
// check whether this is a token cert requested for TLS-SNI challenge // check whether this is a token cert requested for TLS-SNI challenge
if strings.HasSuffix(name, ".acme.invalid") { if strings.HasSuffix(name, ".acme.invalid") {
m.tokenCertMu.RLock() m.tokensMu.RLock()
defer m.tokenCertMu.RUnlock() defer m.tokensMu.RUnlock()
if cert := m.tokenCert[name]; cert != nil { if cert := m.certTokens[name]; cert != nil {
return cert, nil return cert, nil
} }
if cert, err := m.cacheGet(ctx, name); err == nil { if cert, err := m.cacheGet(ctx, name); err == nil {
@@ -224,6 +240,68 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
return cert, nil return cert, nil
} }
// HTTPHandler configures the Manager to provision ACME "http-01" challenge responses.
// It returns an http.Handler that responds to the challenges and must be
// running on port 80. If it receives a request that is not an ACME challenge,
// it delegates the request to the optional fallback handler.
//
// If fallback is nil, the returned handler redirects all GET and HEAD requests
// to the default TLS port 443 with 302 Found status code, preserving the original
// request path and query. It responds with 400 Bad Request to all other HTTP methods.
// The fallback is not protected by the optional HostPolicy.
//
// Because the fallback handler is run with unencrypted port 80 requests,
// the fallback should not serve TLS-only requests.
//
// If HTTPHandler is never called, the Manager will only use TLS SNI
// challenges for domain verification.
func (m *Manager) HTTPHandler(fallback http.Handler) http.Handler {
m.tokensMu.Lock()
defer m.tokensMu.Unlock()
m.tryHTTP01 = true
if fallback == nil {
fallback = http.HandlerFunc(handleHTTPRedirect)
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, "/.well-known/acme-challenge/") {
fallback.ServeHTTP(w, r)
return
}
// A reasonable context timeout for cache and host policy only,
// because we don't wait for a new certificate issuance here.
ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
defer cancel()
if err := m.hostPolicy()(ctx, r.Host); err != nil {
http.Error(w, err.Error(), http.StatusForbidden)
return
}
data, err := m.httpToken(ctx, r.URL.Path)
if err != nil {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
w.Write(data)
})
}
func handleHTTPRedirect(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" && r.Method != "HEAD" {
http.Error(w, "Use HTTPS", http.StatusBadRequest)
return
}
target := "https://" + stripPort(r.Host) + r.URL.RequestURI()
http.Redirect(w, r, target, http.StatusFound)
}
func stripPort(hostport string) string {
host, _, err := net.SplitHostPort(hostport)
if err != nil {
return hostport
}
return net.JoinHostPort(host, "443")
}
// cert returns an existing certificate either from m.state or cache. // cert returns an existing certificate either from m.state or cache.
// If a certificate is found in cache but not in m.state, the latter will be filled // If a certificate is found in cache but not in m.state, the latter will be filled
// with the cached value. // with the cached value.
@@ -442,13 +520,14 @@ func (m *Manager) certState(domain string) (*certState, error) {
// authorizedCert starts the domain ownership verification process and requests a new cert upon success. // authorizedCert starts the domain ownership verification process and requests a new cert upon success.
// The key argument is the certificate private key. // The key argument is the certificate private key.
func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain string) (der [][]byte, leaf *x509.Certificate, err error) { func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain string) (der [][]byte, leaf *x509.Certificate, err error) {
if err := m.verify(ctx, domain); err != nil {
return nil, nil, err
}
client, err := m.acmeClient(ctx) client, err := m.acmeClient(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if err := m.verify(ctx, client, domain); err != nil {
return nil, nil, err
}
csr, err := certRequest(key, domain) csr, err := certRequest(key, domain)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -464,98 +543,171 @@ func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain
return der, leaf, nil return der, leaf, nil
} }
// verify starts a new identifier (domain) authorization flow. // verify runs the identifier (domain) authorization flow
// It prepares a challenge response and then blocks until the authorization // using each applicable ACME challenge type.
// is marked as "completed" by the CA (either succeeded or failed). func (m *Manager) verify(ctx context.Context, client *acme.Client, domain string) error {
// // The list of challenge types we'll try to fulfill
// verify returns nil iff the verification was successful. // in this specific order.
func (m *Manager) verify(ctx context.Context, domain string) error { challengeTypes := []string{"tls-sni-02", "tls-sni-01"}
client, err := m.acmeClient(ctx) m.tokensMu.RLock()
if err != nil { if m.tryHTTP01 {
return err challengeTypes = append(challengeTypes, "http-01")
} }
m.tokensMu.RUnlock()
// start domain authorization and get the challenge var nextTyp int // challengeType index of the next challenge type to try
authz, err := client.Authorize(ctx, domain) for {
if err != nil { // Start domain authorization and get the challenge.
return err authz, err := client.Authorize(ctx, domain)
} if err != nil {
// maybe don't need to at all return err
if authz.Status == acme.StatusValid {
return nil
}
// pick a challenge: prefer tls-sni-02 over tls-sni-01
// TODO: consider authz.Combinations
var chal *acme.Challenge
for _, c := range authz.Challenges {
if c.Type == "tls-sni-02" {
chal = c
break
} }
if c.Type == "tls-sni-01" { // No point in accepting challenges if the authorization status
chal = c // is in a final state.
switch authz.Status {
case acme.StatusValid:
return nil // already authorized
case acme.StatusInvalid:
return fmt.Errorf("acme/autocert: invalid authorization %q", authz.URI)
}
// Pick the next preferred challenge.
var chal *acme.Challenge
for chal == nil && nextTyp < len(challengeTypes) {
chal = pickChallenge(challengeTypes[nextTyp], authz.Challenges)
nextTyp++
}
if chal == nil {
return fmt.Errorf("acme/autocert: unable to authorize %q; tried %q", domain, challengeTypes)
}
cleanup, err := m.fulfill(ctx, client, chal)
if err != nil {
continue
}
defer cleanup()
if _, err := client.Accept(ctx, chal); err != nil {
continue
}
// A challenge is fulfilled and accepted: wait for the CA to validate.
if _, err := client.WaitAuthorization(ctx, authz.URI); err == nil {
return nil
} }
} }
if chal == nil {
return errors.New("acme/autocert: no supported challenge type found")
}
// create a token cert for the challenge response
var (
cert tls.Certificate
name string
)
switch chal.Type {
case "tls-sni-01":
cert, name, err = client.TLSSNI01ChallengeCert(chal.Token)
case "tls-sni-02":
cert, name, err = client.TLSSNI02ChallengeCert(chal.Token)
default:
err = fmt.Errorf("acme/autocert: unknown challenge type %q", chal.Type)
}
if err != nil {
return err
}
m.putTokenCert(ctx, name, &cert)
defer func() {
// verification has ended at this point
// don't need token cert anymore
go m.deleteTokenCert(name)
}()
// ready to fulfill the challenge
if _, err := client.Accept(ctx, chal); err != nil {
return err
}
// wait for the CA to validate
_, err = client.WaitAuthorization(ctx, authz.URI)
return err
} }
// putTokenCert stores the cert under the named key in both m.tokenCert map // fulfill provisions a response to the challenge chal.
// and m.Cache. // The cleanup is non-nil only if provisioning succeeded.
func (m *Manager) putTokenCert(ctx context.Context, name string, cert *tls.Certificate) { func (m *Manager) fulfill(ctx context.Context, client *acme.Client, chal *acme.Challenge) (cleanup func(), err error) {
m.tokenCertMu.Lock() switch chal.Type {
defer m.tokenCertMu.Unlock() case "tls-sni-01":
if m.tokenCert == nil { cert, name, err := client.TLSSNI01ChallengeCert(chal.Token)
m.tokenCert = make(map[string]*tls.Certificate) if err != nil {
return nil, err
}
m.putCertToken(ctx, name, &cert)
return func() { go m.deleteCertToken(name) }, nil
case "tls-sni-02":
cert, name, err := client.TLSSNI02ChallengeCert(chal.Token)
if err != nil {
return nil, err
}
m.putCertToken(ctx, name, &cert)
return func() { go m.deleteCertToken(name) }, nil
case "http-01":
resp, err := client.HTTP01ChallengeResponse(chal.Token)
if err != nil {
return nil, err
}
p := client.HTTP01ChallengePath(chal.Token)
m.putHTTPToken(ctx, p, resp)
return func() { go m.deleteHTTPToken(p) }, nil
} }
m.tokenCert[name] = cert return nil, fmt.Errorf("acme/autocert: unknown challenge type %q", chal.Type)
}
func pickChallenge(typ string, chal []*acme.Challenge) *acme.Challenge {
for _, c := range chal {
if c.Type == typ {
return c
}
}
return nil
}
// putCertToken stores the cert under the named key in both m.certTokens map
// and m.Cache.
func (m *Manager) putCertToken(ctx context.Context, name string, cert *tls.Certificate) {
m.tokensMu.Lock()
defer m.tokensMu.Unlock()
if m.certTokens == nil {
m.certTokens = make(map[string]*tls.Certificate)
}
m.certTokens[name] = cert
m.cachePut(ctx, name, cert) m.cachePut(ctx, name, cert)
} }
// deleteTokenCert removes the token certificate for the specified domain name // deleteCertToken removes the token certificate for the specified domain name
// from both m.tokenCert map and m.Cache. // from both m.certTokens map and m.Cache.
func (m *Manager) deleteTokenCert(name string) { func (m *Manager) deleteCertToken(name string) {
m.tokenCertMu.Lock() m.tokensMu.Lock()
defer m.tokenCertMu.Unlock() defer m.tokensMu.Unlock()
delete(m.tokenCert, name) delete(m.certTokens, name)
if m.Cache != nil { if m.Cache != nil {
m.Cache.Delete(context.Background(), name) m.Cache.Delete(context.Background(), name)
} }
} }
// httpToken retrieves an existing http-01 token value from an in-memory map
// or the optional cache.
func (m *Manager) httpToken(ctx context.Context, tokenPath string) ([]byte, error) {
m.tokensMu.RLock()
defer m.tokensMu.RUnlock()
if v, ok := m.httpTokens[tokenPath]; ok {
return v, nil
}
if m.Cache == nil {
return nil, fmt.Errorf("acme/autocert: no token at %q", tokenPath)
}
return m.Cache.Get(ctx, httpTokenCacheKey(tokenPath))
}
// putHTTPToken stores an http-01 token value using tokenPath as key
// in both in-memory map and the optional Cache.
//
// It ignores any error returned from Cache.Put.
func (m *Manager) putHTTPToken(ctx context.Context, tokenPath, val string) {
m.tokensMu.Lock()
defer m.tokensMu.Unlock()
if m.httpTokens == nil {
m.httpTokens = make(map[string][]byte)
}
b := []byte(val)
m.httpTokens[tokenPath] = b
if m.Cache != nil {
m.Cache.Put(ctx, httpTokenCacheKey(tokenPath), b)
}
}
// deleteHTTPToken removes an http-01 token value from both in-memory map
// and the optional Cache, ignoring any error returned from the latter.
//
// If m.Cache is non-nil, it blocks until Cache.Delete returns without a timeout.
func (m *Manager) deleteHTTPToken(tokenPath string) {
m.tokensMu.Lock()
defer m.tokensMu.Unlock()
delete(m.httpTokens, tokenPath)
if m.Cache != nil {
m.Cache.Delete(context.Background(), httpTokenCacheKey(tokenPath))
}
}
// httpTokenCacheKey returns a key at which an http-01 token value may be stored
// in the Manager's optional Cache.
func httpTokenCacheKey(tokenPath string) string {
return "http-01-" + path.Base(tokenPath)
}
// renew starts a cert renewal timer loop, one per domain. // renew starts a cert renewal timer loop, one per domain.
// //
// The loop is scheduled in two cases: // The loop is scheduled in two cases:

View File

@@ -23,6 +23,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -48,6 +49,16 @@ var authzTmpl = template.Must(template.New("authz").Parse(`{
"uri": "{{.}}/challenge/2", "uri": "{{.}}/challenge/2",
"type": "tls-sni-02", "type": "tls-sni-02",
"token": "token-02" "token": "token-02"
},
{
"uri": "{{.}}/challenge/dns-01",
"type": "dns-01",
"token": "token-dns-01"
},
{
"uri": "{{.}}/challenge/http-01",
"type": "http-01",
"token": "token-http-01"
} }
] ]
}`)) }`))
@@ -419,6 +430,146 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
} }
func TestVerifyHTTP01(t *testing.T) {
var (
http01 http.Handler
authzCount int // num. of created authorizations
didAcceptHTTP01 bool
)
verifyHTTPToken := func() {
r := httptest.NewRequest("GET", "/.well-known/acme-challenge/token-http-01", nil)
w := httptest.NewRecorder()
http01.ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Errorf("http token: w.Code = %d; want %d", w.Code, http.StatusOK)
}
if v := string(w.Body.Bytes()); !strings.HasPrefix(v, "token-http-01.") {
t.Errorf("http token value = %q; want 'token-http-01.' prefix", v)
}
}
// ACME CA server stub, only the needed bits.
// TODO: Merge this with startACMEServerStub, making it a configurable CA for testing.
var ca *httptest.Server
ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Replay-Nonce", "nonce")
if r.Method == "HEAD" {
// a nonce request
return
}
switch r.URL.Path {
// Discovery.
case "/":
if err := discoTmpl.Execute(w, ca.URL); err != nil {
t.Errorf("discoTmpl: %v", err)
}
// Client key registration.
case "/new-reg":
w.Write([]byte("{}"))
// New domain authorization.
case "/new-authz":
authzCount++
w.Header().Set("Location", fmt.Sprintf("%s/authz/%d", ca.URL, authzCount))
w.WriteHeader(http.StatusCreated)
if err := authzTmpl.Execute(w, ca.URL); err != nil {
t.Errorf("authzTmpl: %v", err)
}
// Accept tls-sni-02.
case "/challenge/2":
w.Write([]byte("{}"))
// Reject tls-sni-01.
case "/challenge/1":
http.Error(w, "won't accept tls-sni-01", http.StatusBadRequest)
// Should not accept dns-01.
case "/challenge/dns-01":
t.Errorf("dns-01 challenge was accepted")
http.Error(w, "won't accept dns-01", http.StatusBadRequest)
// Accept http-01.
case "/challenge/http-01":
didAcceptHTTP01 = true
verifyHTTPToken()
w.Write([]byte("{}"))
// Authorization statuses.
// Make tls-sni-xxx invalid.
case "/authz/1", "/authz/2":
w.Write([]byte(`{"status": "invalid"}`))
case "/authz/3", "/authz/4":
w.Write([]byte(`{"status": "valid"}`))
default:
http.NotFound(w, r)
t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path)
}
}))
defer ca.Close()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatal(err)
}
m := &Manager{
Client: &acme.Client{
Key: key,
DirectoryURL: ca.URL,
},
}
http01 = m.HTTPHandler(nil)
if err := m.verify(context.Background(), m.Client, "example.org"); err != nil {
t.Errorf("m.verify: %v", err)
}
// Only tls-sni-01, tls-sni-02 and http-01 must be accepted
// The dns-01 challenge is unsupported.
if authzCount != 3 {
t.Errorf("authzCount = %d; want 3", authzCount)
}
if !didAcceptHTTP01 {
t.Error("did not accept http-01 challenge")
}
}
func TestHTTPHandlerDefaultFallback(t *testing.T) {
tt := []struct {
method, url string
wantCode int
wantLocation string
}{
{"GET", "http://example.org", 302, "https://example.org/"},
{"GET", "http://example.org/foo", 302, "https://example.org/foo"},
{"GET", "http://example.org/foo/bar/", 302, "https://example.org/foo/bar/"},
{"GET", "http://example.org/?a=b", 302, "https://example.org/?a=b"},
{"GET", "http://example.org/foo?a=b", 302, "https://example.org/foo?a=b"},
{"GET", "http://example.org:80/foo?a=b", 302, "https://example.org:443/foo?a=b"},
{"GET", "http://example.org:80/foo%20bar", 302, "https://example.org:443/foo%20bar"},
{"GET", "http://[2602:d1:xxxx::c60a]:1234", 302, "https://[2602:d1:xxxx::c60a]:443/"},
{"GET", "http://[2602:d1:xxxx::c60a]", 302, "https://[2602:d1:xxxx::c60a]/"},
{"GET", "http://[2602:d1:xxxx::c60a]/foo?a=b", 302, "https://[2602:d1:xxxx::c60a]/foo?a=b"},
{"HEAD", "http://example.org", 302, "https://example.org/"},
{"HEAD", "http://example.org/foo", 302, "https://example.org/foo"},
{"HEAD", "http://example.org/foo/bar/", 302, "https://example.org/foo/bar/"},
{"HEAD", "http://example.org/?a=b", 302, "https://example.org/?a=b"},
{"HEAD", "http://example.org/foo?a=b", 302, "https://example.org/foo?a=b"},
{"POST", "http://example.org", 400, ""},
{"PUT", "http://example.org", 400, ""},
{"GET", "http://example.org/.well-known/acme-challenge/x", 404, ""},
}
var m Manager
h := m.HTTPHandler(nil)
for i, test := range tt {
r := httptest.NewRequest(test.method, test.url, nil)
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
if w.Code != test.wantCode {
t.Errorf("%d: w.Code = %d; want %d", i, w.Code, test.wantCode)
t.Errorf("%d: body: %s", i, w.Body.Bytes())
}
if v := w.Header().Get("Location"); v != test.wantLocation {
t.Errorf("%d: Location = %q; want %q", i, v, test.wantLocation)
}
}
}
func TestAccountKeyCache(t *testing.T) { func TestAccountKeyCache(t *testing.T) {
m := Manager{Cache: newMemCache()} m := Manager{Cache: newMemCache()}
ctx := context.Background() ctx := context.Background()

View File

@@ -22,11 +22,12 @@ func ExampleNewListener() {
} }
func ExampleManager() { func ExampleManager() {
m := autocert.Manager{ m := &autocert.Manager{
Cache: autocert.DirCache("secret-dir"), Cache: autocert.DirCache("secret-dir"),
Prompt: autocert.AcceptTOS, Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist("example.org"), HostPolicy: autocert.HostWhitelist("example.org"),
} }
go http.ListenAndServe(":http", m.HTTPHandler(nil))
s := &http.Server{ s := &http.Server{
Addr: ":https", Addr: ":https",
TLSConfig: &tls.Config{GetCertificate: m.GetCertificate}, TLSConfig: &tls.Config{GetCertificate: m.GetCertificate},

228
vendor/golang.org/x/crypto/argon2/argon2.go generated vendored Normal file
View File

@@ -0,0 +1,228 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package argon2 implements the key derivation function Argon2.
// Argon2 was selected as the winner of the Password Hashing Competition and can
// be used to derive cryptographic keys from passwords.
// Argon2 is specfifed at https://github.com/P-H-C/phc-winner-argon2/blob/master/argon2-specs.pdf
package argon2
import (
"encoding/binary"
"sync"
"golang.org/x/crypto/blake2b"
)
// The Argon2 version implemented by this package.
const Version = 0x13
const (
argon2d = iota
argon2i
argon2id
)
// Key derives a key from the password, salt, and cost parameters using Argon2i
// returning a byte slice of length keyLen that can be used as cryptographic key.
// The CPU cost and parallism degree must be greater than zero.
//
// For example, you can get a derived key for e.g. AES-256 (which needs a 32-byte key) by doing:
// `key := argon2.Key([]byte("some password"), salt, 4, 32*1024, 4, 32)`
//
// The recommended parameters for interactive logins as of 2017 are time=4, memory=32*1024.
// The number of threads can be adjusted to the numbers of available CPUs.
// The time parameter specifies the number of passes over the memory and the memory
// parameter specifies the size of the memory in KiB. For example memory=32*1024 sets the
// memory cost to ~32 MB.
// The cost parameters should be increased as memory latency and CPU parallelism increases.
// Remember to get a good random salt.
func Key(password, salt []byte, time, memory uint32, threads uint8, keyLen uint32) []byte {
return deriveKey(argon2i, password, salt, nil, nil, time, memory, threads, keyLen)
}
func deriveKey(mode int, password, salt, secret, data []byte, time, memory uint32, threads uint8, keyLen uint32) []byte {
if time < 1 {
panic("argon2: number of rounds too small")
}
if threads < 1 {
panic("argon2: parallelism degree too low")
}
h0 := initHash(password, salt, secret, data, time, memory, uint32(threads), keyLen, mode)
memory = memory / (syncPoints * uint32(threads)) * (syncPoints * uint32(threads))
if memory < 2*syncPoints*uint32(threads) {
memory = 2 * syncPoints * uint32(threads)
}
B := initBlocks(&h0, memory, uint32(threads))
processBlocks(B, time, memory, uint32(threads), mode)
return extractKey(B, memory, uint32(threads), keyLen)
}
const (
blockLength = 128
syncPoints = 4
)
type block [blockLength]uint64
func initHash(password, salt, key, data []byte, time, memory, threads, keyLen uint32, mode int) [blake2b.Size + 8]byte {
var (
h0 [blake2b.Size + 8]byte
params [24]byte
tmp [4]byte
)
b2, _ := blake2b.New512(nil)
binary.LittleEndian.PutUint32(params[0:4], threads)
binary.LittleEndian.PutUint32(params[4:8], keyLen)
binary.LittleEndian.PutUint32(params[8:12], memory)
binary.LittleEndian.PutUint32(params[12:16], time)
binary.LittleEndian.PutUint32(params[16:20], uint32(Version))
binary.LittleEndian.PutUint32(params[20:24], uint32(mode))
b2.Write(params[:])
binary.LittleEndian.PutUint32(tmp[:], uint32(len(password)))
b2.Write(tmp[:])
b2.Write(password)
binary.LittleEndian.PutUint32(tmp[:], uint32(len(salt)))
b2.Write(tmp[:])
b2.Write(salt)
binary.LittleEndian.PutUint32(tmp[:], uint32(len(key)))
b2.Write(tmp[:])
b2.Write(key)
binary.LittleEndian.PutUint32(tmp[:], uint32(len(data)))
b2.Write(tmp[:])
b2.Write(data)
b2.Sum(h0[:0])
return h0
}
func initBlocks(h0 *[blake2b.Size + 8]byte, memory, threads uint32) []block {
var block0 [1024]byte
B := make([]block, memory)
for lane := uint32(0); lane < threads; lane++ {
j := lane * (memory / threads)
binary.LittleEndian.PutUint32(h0[blake2b.Size+4:], lane)
binary.LittleEndian.PutUint32(h0[blake2b.Size:], 0)
blake2bHash(block0[:], h0[:])
for i := range B[j+0] {
B[j+0][i] = binary.LittleEndian.Uint64(block0[i*8:])
}
binary.LittleEndian.PutUint32(h0[blake2b.Size:], 1)
blake2bHash(block0[:], h0[:])
for i := range B[j+1] {
B[j+1][i] = binary.LittleEndian.Uint64(block0[i*8:])
}
}
return B
}
func processBlocks(B []block, time, memory, threads uint32, mode int) {
lanes := memory / threads
segments := lanes / syncPoints
processSegment := func(n, slice, lane uint32, wg *sync.WaitGroup) {
var addresses, in, zero block
if mode == argon2i || (mode == argon2id && n == 0 && slice < syncPoints/2) {
in[0] = uint64(n)
in[1] = uint64(lane)
in[2] = uint64(slice)
in[3] = uint64(memory)
in[4] = uint64(time)
in[5] = uint64(mode)
}
index := uint32(0)
if n == 0 && slice == 0 {
index = 2 // we have already generated the first two blocks
if mode == argon2i || mode == argon2id {
in[6]++
processBlock(&addresses, &in, &zero)
processBlock(&addresses, &addresses, &zero)
}
}
offset := lane*lanes + slice*segments + index
var random uint64
for index < segments {
prev := offset - 1
if index == 0 && slice == 0 {
prev += lanes // last block in lane
}
if mode == argon2i || (mode == argon2id && n == 0 && slice < syncPoints/2) {
if index%blockLength == 0 {
in[6]++
processBlock(&addresses, &in, &zero)
processBlock(&addresses, &addresses, &zero)
}
random = addresses[index%blockLength]
} else {
random = B[prev][0]
}
newOffset := indexAlpha(random, lanes, segments, threads, n, slice, lane, index)
processBlockXOR(&B[offset], &B[prev], &B[newOffset])
index, offset = index+1, offset+1
}
wg.Done()
}
for n := uint32(0); n < time; n++ {
for slice := uint32(0); slice < syncPoints; slice++ {
var wg sync.WaitGroup
for lane := uint32(0); lane < threads; lane++ {
wg.Add(1)
go processSegment(n, slice, lane, &wg)
}
wg.Wait()
}
}
}
func extractKey(B []block, memory, threads, keyLen uint32) []byte {
lanes := memory / threads
for lane := uint32(0); lane < threads-1; lane++ {
for i, v := range B[(lane*lanes)+lanes-1] {
B[memory-1][i] ^= v
}
}
var block [1024]byte
for i, v := range B[memory-1] {
binary.LittleEndian.PutUint64(block[i*8:], v)
}
key := make([]byte, keyLen)
blake2bHash(key, block[:])
return key
}
func indexAlpha(rand uint64, lanes, segments, threads, n, slice, lane, index uint32) uint32 {
refLane := uint32(rand>>32) % threads
if n == 0 && slice == 0 {
refLane = lane
}
m, s := 3*segments, ((slice+1)%syncPoints)*segments
if lane == refLane {
m += index
}
if n == 0 {
m, s = slice*segments, 0
if slice == 0 || lane == refLane {
m += index
}
}
if index == 0 || lane == refLane {
m--
}
return phi(rand, uint64(m), uint64(s), refLane, lanes)
}
func phi(rand, m, s uint64, lane, lanes uint32) uint32 {
p := rand & 0xFFFFFFFF
p = (p * p) >> 32
p = (p * m) >> 32
return lane*lanes + uint32((s+m-(p+1))%uint64(lanes))
}

233
vendor/golang.org/x/crypto/argon2/argon2_test.go generated vendored Normal file
View File

@@ -0,0 +1,233 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package argon2
import (
"bytes"
"encoding/hex"
"testing"
)
var (
genKatPassword = []byte{
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
}
genKatSalt = []byte{0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02}
genKatSecret = []byte{0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03}
genKatAAD = []byte{0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04}
)
func TestArgon2(t *testing.T) {
defer func(sse4 bool) { useSSE4 = sse4 }(useSSE4)
if useSSE4 {
t.Log("SSE4.1 version")
testArgon2i(t)
testArgon2d(t)
testArgon2id(t)
useSSE4 = false
}
t.Log("generic version")
testArgon2i(t)
testArgon2d(t)
testArgon2id(t)
}
func testArgon2d(t *testing.T) {
want := []byte{
0x51, 0x2b, 0x39, 0x1b, 0x6f, 0x11, 0x62, 0x97,
0x53, 0x71, 0xd3, 0x09, 0x19, 0x73, 0x42, 0x94,
0xf8, 0x68, 0xe3, 0xbe, 0x39, 0x84, 0xf3, 0xc1,
0xa1, 0x3a, 0x4d, 0xb9, 0xfa, 0xbe, 0x4a, 0xcb,
}
hash := deriveKey(argon2d, genKatPassword, genKatSalt, genKatSecret, genKatAAD, 3, 32, 4, 32)
if !bytes.Equal(hash, want) {
t.Errorf("derived key does not match - got: %s , want: %s", hex.EncodeToString(hash), hex.EncodeToString(want))
}
}
func testArgon2i(t *testing.T) {
want := []byte{
0xc8, 0x14, 0xd9, 0xd1, 0xdc, 0x7f, 0x37, 0xaa,
0x13, 0xf0, 0xd7, 0x7f, 0x24, 0x94, 0xbd, 0xa1,
0xc8, 0xde, 0x6b, 0x01, 0x6d, 0xd3, 0x88, 0xd2,
0x99, 0x52, 0xa4, 0xc4, 0x67, 0x2b, 0x6c, 0xe8,
}
hash := deriveKey(argon2i, genKatPassword, genKatSalt, genKatSecret, genKatAAD, 3, 32, 4, 32)
if !bytes.Equal(hash, want) {
t.Errorf("derived key does not match - got: %s , want: %s", hex.EncodeToString(hash), hex.EncodeToString(want))
}
}
func testArgon2id(t *testing.T) {
want := []byte{
0x0d, 0x64, 0x0d, 0xf5, 0x8d, 0x78, 0x76, 0x6c,
0x08, 0xc0, 0x37, 0xa3, 0x4a, 0x8b, 0x53, 0xc9,
0xd0, 0x1e, 0xf0, 0x45, 0x2d, 0x75, 0xb6, 0x5e,
0xb5, 0x25, 0x20, 0xe9, 0x6b, 0x01, 0xe6, 0x59,
}
hash := deriveKey(argon2id, genKatPassword, genKatSalt, genKatSecret, genKatAAD, 3, 32, 4, 32)
if !bytes.Equal(hash, want) {
t.Errorf("derived key does not match - got: %s , want: %s", hex.EncodeToString(hash), hex.EncodeToString(want))
}
}
func TestVectors(t *testing.T) {
password, salt := []byte("password"), []byte("somesalt")
for i, v := range testVectors {
want, err := hex.DecodeString(v.hash)
if err != nil {
t.Fatalf("Test %d: failed to decode hash: %v", i, err)
}
hash := deriveKey(v.mode, password, salt, nil, nil, v.time, v.memory, v.threads, uint32(len(want)))
if !bytes.Equal(hash, want) {
t.Errorf("Test %d - got: %s want: %s", i, hex.EncodeToString(hash), hex.EncodeToString(want))
}
}
}
func benchmarkArgon2(mode int, time, memory uint32, threads uint8, keyLen uint32, b *testing.B) {
password := []byte("password")
salt := []byte("choosing random salts is hard")
b.ReportAllocs()
for i := 0; i < b.N; i++ {
deriveKey(mode, password, salt, nil, nil, time, memory, threads, keyLen)
}
}
func BenchmarkArgon2i(b *testing.B) {
b.Run(" Time: 3 Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2i, 3, 32*1024, 1, 32, b) })
b.Run(" Time: 4 Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2i, 4, 32*1024, 1, 32, b) })
b.Run(" Time: 5 Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2i, 5, 32*1024, 1, 32, b) })
b.Run(" Time: 3 Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2i, 3, 64*1024, 4, 32, b) })
b.Run(" Time: 4 Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2i, 4, 64*1024, 4, 32, b) })
b.Run(" Time: 5 Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2i, 5, 64*1024, 4, 32, b) })
}
func BenchmarkArgon2d(b *testing.B) {
b.Run(" Time: 3, Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2d, 3, 32*1024, 1, 32, b) })
b.Run(" Time: 4, Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2d, 4, 32*1024, 1, 32, b) })
b.Run(" Time: 5, Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2d, 5, 32*1024, 1, 32, b) })
b.Run(" Time: 3, Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2d, 3, 64*1024, 4, 32, b) })
b.Run(" Time: 4, Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2d, 4, 64*1024, 4, 32, b) })
b.Run(" Time: 5, Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2d, 5, 64*1024, 4, 32, b) })
}
func BenchmarkArgon2id(b *testing.B) {
b.Run(" Time: 3, Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2id, 3, 32*1024, 1, 32, b) })
b.Run(" Time: 4, Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2id, 4, 32*1024, 1, 32, b) })
b.Run(" Time: 5, Memory: 32 MB, Threads: 1", func(b *testing.B) { benchmarkArgon2(argon2id, 5, 32*1024, 1, 32, b) })
b.Run(" Time: 3, Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2id, 3, 64*1024, 4, 32, b) })
b.Run(" Time: 4, Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2id, 4, 64*1024, 4, 32, b) })
b.Run(" Time: 5, Memory: 64 MB, Threads: 4", func(b *testing.B) { benchmarkArgon2(argon2id, 5, 64*1024, 4, 32, b) })
}
// Generated with the CLI of https://github.com/P-H-C/phc-winner-argon2/blob/master/argon2-specs.pdf
var testVectors = []struct {
mode int
time, memory uint32
threads uint8
hash string
}{
{
mode: argon2i, time: 1, memory: 64, threads: 1,
hash: "b9c401d1844a67d50eae3967dc28870b22e508092e861a37",
},
{
mode: argon2d, time: 1, memory: 64, threads: 1,
hash: "8727405fd07c32c78d64f547f24150d3f2e703a89f981a19",
},
{
mode: argon2id, time: 1, memory: 64, threads: 1,
hash: "655ad15eac652dc59f7170a7332bf49b8469be1fdb9c28bb",
},
{
mode: argon2i, time: 2, memory: 64, threads: 1,
hash: "8cf3d8f76a6617afe35fac48eb0b7433a9a670ca4a07ed64",
},
{
mode: argon2d, time: 2, memory: 64, threads: 1,
hash: "3be9ec79a69b75d3752acb59a1fbb8b295a46529c48fbb75",
},
{
mode: argon2id, time: 2, memory: 64, threads: 1,
hash: "068d62b26455936aa6ebe60060b0a65870dbfa3ddf8d41f7",
},
{
mode: argon2i, time: 2, memory: 64, threads: 2,
hash: "2089f3e78a799720f80af806553128f29b132cafe40d059f",
},
{
mode: argon2d, time: 2, memory: 64, threads: 2,
hash: "68e2462c98b8bc6bb60ec68db418ae2c9ed24fc6748a40e9",
},
{
mode: argon2id, time: 2, memory: 64, threads: 2,
hash: "350ac37222f436ccb5c0972f1ebd3bf6b958bf2071841362",
},
{
mode: argon2i, time: 3, memory: 256, threads: 2,
hash: "f5bbf5d4c3836af13193053155b73ec7476a6a2eb93fd5e6",
},
{
mode: argon2d, time: 3, memory: 256, threads: 2,
hash: "f4f0669218eaf3641f39cc97efb915721102f4b128211ef2",
},
{
mode: argon2id, time: 3, memory: 256, threads: 2,
hash: "4668d30ac4187e6878eedeacf0fd83c5a0a30db2cc16ef0b",
},
{
mode: argon2i, time: 4, memory: 4096, threads: 4,
hash: "a11f7b7f3f93f02ad4bddb59ab62d121e278369288a0d0e7",
},
{
mode: argon2d, time: 4, memory: 4096, threads: 4,
hash: "935598181aa8dc2b720914aa6435ac8d3e3a4210c5b0fb2d",
},
{
mode: argon2id, time: 4, memory: 4096, threads: 4,
hash: "145db9733a9f4ee43edf33c509be96b934d505a4efb33c5a",
},
{
mode: argon2i, time: 4, memory: 1024, threads: 8,
hash: "0cdd3956aa35e6b475a7b0c63488822f774f15b43f6e6e17",
},
{
mode: argon2d, time: 4, memory: 1024, threads: 8,
hash: "83604fc2ad0589b9d055578f4d3cc55bc616df3578a896e9",
},
{
mode: argon2id, time: 4, memory: 1024, threads: 8,
hash: "8dafa8e004f8ea96bf7c0f93eecf67a6047476143d15577f",
},
{
mode: argon2i, time: 2, memory: 64, threads: 3,
hash: "5cab452fe6b8479c8661def8cd703b611a3905a6d5477fe6",
},
{
mode: argon2d, time: 2, memory: 64, threads: 3,
hash: "22474a423bda2ccd36ec9afd5119e5c8949798cadf659f51",
},
{
mode: argon2id, time: 2, memory: 64, threads: 3,
hash: "4a15b31aec7c2590b87d1f520be7d96f56658172deaa3079",
},
{
mode: argon2i, time: 3, memory: 1024, threads: 6,
hash: "d236b29c2b2a09babee842b0dec6aa1e83ccbdea8023dced",
},
{
mode: argon2d, time: 3, memory: 1024, threads: 6,
hash: "a3351b0319a53229152023d9206902f4ef59661cdca89481",
},
{
mode: argon2id, time: 3, memory: 1024, threads: 6,
hash: "1640b932f4b60e272f5d2207b9a9c626ffa1bd88d2349016",
},
}

53
vendor/golang.org/x/crypto/argon2/blake2b.go generated vendored Normal file
View File

@@ -0,0 +1,53 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package argon2
import (
"encoding/binary"
"hash"
"golang.org/x/crypto/blake2b"
)
// blake2bHash computes an arbitrary long hash value of in
// and writes the hash to out.
func blake2bHash(out []byte, in []byte) {
var b2 hash.Hash
if n := len(out); n < blake2b.Size {
b2, _ = blake2b.New(n, nil)
} else {
b2, _ = blake2b.New512(nil)
}
var buffer [blake2b.Size]byte
binary.LittleEndian.PutUint32(buffer[:4], uint32(len(out)))
b2.Write(buffer[:4])
b2.Write(in)
if len(out) <= blake2b.Size {
b2.Sum(out[:0])
return
}
outLen := len(out)
b2.Sum(buffer[:0])
b2.Reset()
copy(out, buffer[:32])
out = out[32:]
for len(out) > blake2b.Size {
b2.Write(buffer[:])
b2.Sum(buffer[:0])
copy(out, buffer[:32])
out = out[32:]
b2.Reset()
}
if outLen%blake2b.Size > 0 { // outLen > 64
r := ((outLen + 31) / 32) - 2 // ⌈τ /32⌉-2
b2, _ = blake2b.New(outLen-32*r, nil)
}
b2.Write(buffer[:])
b2.Sum(out[:0])
}

59
vendor/golang.org/x/crypto/argon2/blamka_amd64.go generated vendored Normal file
View File

@@ -0,0 +1,59 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package argon2
func init() {
useSSE4 = supportsSSE4()
}
//go:noescape
func supportsSSE4() bool
//go:noescape
func mixBlocksSSE2(out, a, b, c *block)
//go:noescape
func xorBlocksSSE2(out, a, b, c *block)
//go:noescape
func blamkaSSE4(b *block)
func processBlockSSE(out, in1, in2 *block, xor bool) {
var t block
mixBlocksSSE2(&t, in1, in2, &t)
if useSSE4 {
blamkaSSE4(&t)
} else {
for i := 0; i < blockLength; i += 16 {
blamkaGeneric(
&t[i+0], &t[i+1], &t[i+2], &t[i+3],
&t[i+4], &t[i+5], &t[i+6], &t[i+7],
&t[i+8], &t[i+9], &t[i+10], &t[i+11],
&t[i+12], &t[i+13], &t[i+14], &t[i+15],
)
}
for i := 0; i < blockLength/8; i += 2 {
blamkaGeneric(
&t[i], &t[i+1], &t[16+i], &t[16+i+1],
&t[32+i], &t[32+i+1], &t[48+i], &t[48+i+1],
&t[64+i], &t[64+i+1], &t[80+i], &t[80+i+1],
&t[96+i], &t[96+i+1], &t[112+i], &t[112+i+1],
)
}
}
if xor {
xorBlocksSSE2(out, in1, in2, &t)
} else {
mixBlocksSSE2(out, in1, in2, &t)
}
}
func processBlock(out, in1, in2 *block) {
processBlockSSE(out, in1, in2, false)
}
func processBlockXOR(out, in1, in2 *block) {
processBlockSSE(out, in1, in2, true)
}

252
vendor/golang.org/x/crypto/argon2/blamka_amd64.s generated vendored Normal file
View File

@@ -0,0 +1,252 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build amd64,!gccgo,!appengine
#include "textflag.h"
DATA ·c40<>+0x00(SB)/8, $0x0201000706050403
DATA ·c40<>+0x08(SB)/8, $0x0a09080f0e0d0c0b
GLOBL ·c40<>(SB), (NOPTR+RODATA), $16
DATA ·c48<>+0x00(SB)/8, $0x0100070605040302
DATA ·c48<>+0x08(SB)/8, $0x09080f0e0d0c0b0a
GLOBL ·c48<>(SB), (NOPTR+RODATA), $16
#define SHUFFLE(v2, v3, v4, v5, v6, v7, t1, t2) \
MOVO v4, t1; \
MOVO v5, v4; \
MOVO t1, v5; \
MOVO v6, t1; \
PUNPCKLQDQ v6, t2; \
PUNPCKHQDQ v7, v6; \
PUNPCKHQDQ t2, v6; \
PUNPCKLQDQ v7, t2; \
MOVO t1, v7; \
MOVO v2, t1; \
PUNPCKHQDQ t2, v7; \
PUNPCKLQDQ v3, t2; \
PUNPCKHQDQ t2, v2; \
PUNPCKLQDQ t1, t2; \
PUNPCKHQDQ t2, v3
#define SHUFFLE_INV(v2, v3, v4, v5, v6, v7, t1, t2) \
MOVO v4, t1; \
MOVO v5, v4; \
MOVO t1, v5; \
MOVO v2, t1; \
PUNPCKLQDQ v2, t2; \
PUNPCKHQDQ v3, v2; \
PUNPCKHQDQ t2, v2; \
PUNPCKLQDQ v3, t2; \
MOVO t1, v3; \
MOVO v6, t1; \
PUNPCKHQDQ t2, v3; \
PUNPCKLQDQ v7, t2; \
PUNPCKHQDQ t2, v6; \
PUNPCKLQDQ t1, t2; \
PUNPCKHQDQ t2, v7
#define HALF_ROUND(v0, v1, v2, v3, v4, v5, v6, v7, t0, c40, c48) \
MOVO v0, t0; \
PMULULQ v2, t0; \
PADDQ v2, v0; \
PADDQ t0, v0; \
PADDQ t0, v0; \
PXOR v0, v6; \
PSHUFD $0xB1, v6, v6; \
MOVO v4, t0; \
PMULULQ v6, t0; \
PADDQ v6, v4; \
PADDQ t0, v4; \
PADDQ t0, v4; \
PXOR v4, v2; \
PSHUFB c40, v2; \
MOVO v0, t0; \
PMULULQ v2, t0; \
PADDQ v2, v0; \
PADDQ t0, v0; \
PADDQ t0, v0; \
PXOR v0, v6; \
PSHUFB c48, v6; \
MOVO v4, t0; \
PMULULQ v6, t0; \
PADDQ v6, v4; \
PADDQ t0, v4; \
PADDQ t0, v4; \
PXOR v4, v2; \
MOVO v2, t0; \
PADDQ v2, t0; \
PSRLQ $63, v2; \
PXOR t0, v2; \
MOVO v1, t0; \
PMULULQ v3, t0; \
PADDQ v3, v1; \
PADDQ t0, v1; \
PADDQ t0, v1; \
PXOR v1, v7; \
PSHUFD $0xB1, v7, v7; \
MOVO v5, t0; \
PMULULQ v7, t0; \
PADDQ v7, v5; \
PADDQ t0, v5; \
PADDQ t0, v5; \
PXOR v5, v3; \
PSHUFB c40, v3; \
MOVO v1, t0; \
PMULULQ v3, t0; \
PADDQ v3, v1; \
PADDQ t0, v1; \
PADDQ t0, v1; \
PXOR v1, v7; \
PSHUFB c48, v7; \
MOVO v5, t0; \
PMULULQ v7, t0; \
PADDQ v7, v5; \
PADDQ t0, v5; \
PADDQ t0, v5; \
PXOR v5, v3; \
MOVO v3, t0; \
PADDQ v3, t0; \
PSRLQ $63, v3; \
PXOR t0, v3
#define LOAD_MSG_0(block, off) \
MOVOU 8*(off+0)(block), X0; \
MOVOU 8*(off+2)(block), X1; \
MOVOU 8*(off+4)(block), X2; \
MOVOU 8*(off+6)(block), X3; \
MOVOU 8*(off+8)(block), X4; \
MOVOU 8*(off+10)(block), X5; \
MOVOU 8*(off+12)(block), X6; \
MOVOU 8*(off+14)(block), X7
#define STORE_MSG_0(block, off) \
MOVOU X0, 8*(off+0)(block); \
MOVOU X1, 8*(off+2)(block); \
MOVOU X2, 8*(off+4)(block); \
MOVOU X3, 8*(off+6)(block); \
MOVOU X4, 8*(off+8)(block); \
MOVOU X5, 8*(off+10)(block); \
MOVOU X6, 8*(off+12)(block); \
MOVOU X7, 8*(off+14)(block)
#define LOAD_MSG_1(block, off) \
MOVOU 8*off+0*8(block), X0; \
MOVOU 8*off+16*8(block), X1; \
MOVOU 8*off+32*8(block), X2; \
MOVOU 8*off+48*8(block), X3; \
MOVOU 8*off+64*8(block), X4; \
MOVOU 8*off+80*8(block), X5; \
MOVOU 8*off+96*8(block), X6; \
MOVOU 8*off+112*8(block), X7
#define STORE_MSG_1(block, off) \
MOVOU X0, 8*off+0*8(block); \
MOVOU X1, 8*off+16*8(block); \
MOVOU X2, 8*off+32*8(block); \
MOVOU X3, 8*off+48*8(block); \
MOVOU X4, 8*off+64*8(block); \
MOVOU X5, 8*off+80*8(block); \
MOVOU X6, 8*off+96*8(block); \
MOVOU X7, 8*off+112*8(block)
#define BLAMKA_ROUND_0(block, off, t0, t1, c40, c48) \
LOAD_MSG_0(block, off); \
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, t0, c40, c48); \
SHUFFLE(X2, X3, X4, X5, X6, X7, t0, t1); \
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, t0, c40, c48); \
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, t0, t1); \
STORE_MSG_0(block, off)
#define BLAMKA_ROUND_1(block, off, t0, t1, c40, c48) \
LOAD_MSG_1(block, off); \
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, t0, c40, c48); \
SHUFFLE(X2, X3, X4, X5, X6, X7, t0, t1); \
HALF_ROUND(X0, X1, X2, X3, X4, X5, X6, X7, t0, c40, c48); \
SHUFFLE_INV(X2, X3, X4, X5, X6, X7, t0, t1); \
STORE_MSG_1(block, off)
// func blamkaSSE4(b *block)
TEXT ·blamkaSSE4(SB), 4, $0-8
MOVQ b+0(FP), AX
MOVOU ·c40<>(SB), X10
MOVOU ·c48<>(SB), X11
BLAMKA_ROUND_0(AX, 0, X8, X9, X10, X11)
BLAMKA_ROUND_0(AX, 16, X8, X9, X10, X11)
BLAMKA_ROUND_0(AX, 32, X8, X9, X10, X11)
BLAMKA_ROUND_0(AX, 48, X8, X9, X10, X11)
BLAMKA_ROUND_0(AX, 64, X8, X9, X10, X11)
BLAMKA_ROUND_0(AX, 80, X8, X9, X10, X11)
BLAMKA_ROUND_0(AX, 96, X8, X9, X10, X11)
BLAMKA_ROUND_0(AX, 112, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 0, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 2, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 4, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 6, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 8, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 10, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 12, X8, X9, X10, X11)
BLAMKA_ROUND_1(AX, 14, X8, X9, X10, X11)
RET
// func mixBlocksSSE2(out, a, b, c *block)
TEXT ·mixBlocksSSE2(SB), 4, $0-32
MOVQ out+0(FP), DX
MOVQ a+8(FP), AX
MOVQ b+16(FP), BX
MOVQ a+24(FP), CX
MOVQ $128, BP
loop:
MOVOU 0(AX), X0
MOVOU 0(BX), X1
MOVOU 0(CX), X2
PXOR X1, X0
PXOR X2, X0
MOVOU X0, 0(DX)
ADDQ $16, AX
ADDQ $16, BX
ADDQ $16, CX
ADDQ $16, DX
SUBQ $2, BP
JA loop
RET
// func xorBlocksSSE2(out, a, b, c *block)
TEXT ·xorBlocksSSE2(SB), 4, $0-32
MOVQ out+0(FP), DX
MOVQ a+8(FP), AX
MOVQ b+16(FP), BX
MOVQ a+24(FP), CX
MOVQ $128, BP
loop:
MOVOU 0(AX), X0
MOVOU 0(BX), X1
MOVOU 0(CX), X2
MOVOU 0(DX), X3
PXOR X1, X0
PXOR X2, X0
PXOR X3, X0
MOVOU X0, 0(DX)
ADDQ $16, AX
ADDQ $16, BX
ADDQ $16, CX
ADDQ $16, DX
SUBQ $2, BP
JA loop
RET
// func supportsSSE4() bool
TEXT ·supportsSSE4(SB), 4, $0-1
MOVL $1, AX
CPUID
SHRL $19, CX // Bit 19 indicates SSE4 support
ANDL $1, CX // CX != 0 if support SSE4
MOVB CX, ret+0(FP)
RET

163
vendor/golang.org/x/crypto/argon2/blamka_generic.go generated vendored Normal file
View File

@@ -0,0 +1,163 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package argon2
var useSSE4 bool
func processBlockGeneric(out, in1, in2 *block, xor bool) {
var t block
for i := range t {
t[i] = in1[i] ^ in2[i]
}
for i := 0; i < blockLength; i += 16 {
blamkaGeneric(
&t[i+0], &t[i+1], &t[i+2], &t[i+3],
&t[i+4], &t[i+5], &t[i+6], &t[i+7],
&t[i+8], &t[i+9], &t[i+10], &t[i+11],
&t[i+12], &t[i+13], &t[i+14], &t[i+15],
)
}
for i := 0; i < blockLength/8; i += 2 {
blamkaGeneric(
&t[i], &t[i+1], &t[16+i], &t[16+i+1],
&t[32+i], &t[32+i+1], &t[48+i], &t[48+i+1],
&t[64+i], &t[64+i+1], &t[80+i], &t[80+i+1],
&t[96+i], &t[96+i+1], &t[112+i], &t[112+i+1],
)
}
if xor {
for i := range t {
out[i] ^= in1[i] ^ in2[i] ^ t[i]
}
} else {
for i := range t {
out[i] = in1[i] ^ in2[i] ^ t[i]
}
}
}
func blamkaGeneric(t00, t01, t02, t03, t04, t05, t06, t07, t08, t09, t10, t11, t12, t13, t14, t15 *uint64) {
v00, v01, v02, v03 := *t00, *t01, *t02, *t03
v04, v05, v06, v07 := *t04, *t05, *t06, *t07
v08, v09, v10, v11 := *t08, *t09, *t10, *t11
v12, v13, v14, v15 := *t12, *t13, *t14, *t15
v00 += v04 + 2*uint64(uint32(v00))*uint64(uint32(v04))
v12 ^= v00
v12 = v12>>32 | v12<<32
v08 += v12 + 2*uint64(uint32(v08))*uint64(uint32(v12))
v04 ^= v08
v04 = v04>>24 | v04<<40
v00 += v04 + 2*uint64(uint32(v00))*uint64(uint32(v04))
v12 ^= v00
v12 = v12>>16 | v12<<48
v08 += v12 + 2*uint64(uint32(v08))*uint64(uint32(v12))
v04 ^= v08
v04 = v04>>63 | v04<<1
v01 += v05 + 2*uint64(uint32(v01))*uint64(uint32(v05))
v13 ^= v01
v13 = v13>>32 | v13<<32
v09 += v13 + 2*uint64(uint32(v09))*uint64(uint32(v13))
v05 ^= v09
v05 = v05>>24 | v05<<40
v01 += v05 + 2*uint64(uint32(v01))*uint64(uint32(v05))
v13 ^= v01
v13 = v13>>16 | v13<<48
v09 += v13 + 2*uint64(uint32(v09))*uint64(uint32(v13))
v05 ^= v09
v05 = v05>>63 | v05<<1
v02 += v06 + 2*uint64(uint32(v02))*uint64(uint32(v06))
v14 ^= v02
v14 = v14>>32 | v14<<32
v10 += v14 + 2*uint64(uint32(v10))*uint64(uint32(v14))
v06 ^= v10
v06 = v06>>24 | v06<<40
v02 += v06 + 2*uint64(uint32(v02))*uint64(uint32(v06))
v14 ^= v02
v14 = v14>>16 | v14<<48
v10 += v14 + 2*uint64(uint32(v10))*uint64(uint32(v14))
v06 ^= v10
v06 = v06>>63 | v06<<1
v03 += v07 + 2*uint64(uint32(v03))*uint64(uint32(v07))
v15 ^= v03
v15 = v15>>32 | v15<<32
v11 += v15 + 2*uint64(uint32(v11))*uint64(uint32(v15))
v07 ^= v11
v07 = v07>>24 | v07<<40
v03 += v07 + 2*uint64(uint32(v03))*uint64(uint32(v07))
v15 ^= v03
v15 = v15>>16 | v15<<48
v11 += v15 + 2*uint64(uint32(v11))*uint64(uint32(v15))
v07 ^= v11
v07 = v07>>63 | v07<<1
v00 += v05 + 2*uint64(uint32(v00))*uint64(uint32(v05))
v15 ^= v00
v15 = v15>>32 | v15<<32
v10 += v15 + 2*uint64(uint32(v10))*uint64(uint32(v15))
v05 ^= v10
v05 = v05>>24 | v05<<40
v00 += v05 + 2*uint64(uint32(v00))*uint64(uint32(v05))
v15 ^= v00
v15 = v15>>16 | v15<<48
v10 += v15 + 2*uint64(uint32(v10))*uint64(uint32(v15))
v05 ^= v10
v05 = v05>>63 | v05<<1
v01 += v06 + 2*uint64(uint32(v01))*uint64(uint32(v06))
v12 ^= v01
v12 = v12>>32 | v12<<32
v11 += v12 + 2*uint64(uint32(v11))*uint64(uint32(v12))
v06 ^= v11
v06 = v06>>24 | v06<<40
v01 += v06 + 2*uint64(uint32(v01))*uint64(uint32(v06))
v12 ^= v01
v12 = v12>>16 | v12<<48
v11 += v12 + 2*uint64(uint32(v11))*uint64(uint32(v12))
v06 ^= v11
v06 = v06>>63 | v06<<1
v02 += v07 + 2*uint64(uint32(v02))*uint64(uint32(v07))
v13 ^= v02
v13 = v13>>32 | v13<<32
v08 += v13 + 2*uint64(uint32(v08))*uint64(uint32(v13))
v07 ^= v08
v07 = v07>>24 | v07<<40
v02 += v07 + 2*uint64(uint32(v02))*uint64(uint32(v07))
v13 ^= v02
v13 = v13>>16 | v13<<48
v08 += v13 + 2*uint64(uint32(v08))*uint64(uint32(v13))
v07 ^= v08
v07 = v07>>63 | v07<<1
v03 += v04 + 2*uint64(uint32(v03))*uint64(uint32(v04))
v14 ^= v03
v14 = v14>>32 | v14<<32
v09 += v14 + 2*uint64(uint32(v09))*uint64(uint32(v14))
v04 ^= v09
v04 = v04>>24 | v04<<40
v03 += v04 + 2*uint64(uint32(v03))*uint64(uint32(v04))
v14 ^= v03
v14 = v14>>16 | v14<<48
v09 += v14 + 2*uint64(uint32(v09))*uint64(uint32(v14))
v04 ^= v09
v04 = v04>>63 | v04<<1
*t00, *t01, *t02, *t03 = v00, v01, v02, v03
*t04, *t05, *t06, *t07 = v04, v05, v06, v07
*t08, *t09, *t10, *t11 = v08, v09, v10, v11
*t12, *t13, *t14, *t15 = v12, v13, v14, v15
}

15
vendor/golang.org/x/crypto/argon2/blamka_ref.go generated vendored Normal file
View File

@@ -0,0 +1,15 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !amd64 appengine gccgo
package argon2
func processBlock(out, in1, in2 *block) {
processBlockGeneric(out, in1, in2, false)
}
func processBlockXOR(out, in1, in2 *block) {
processBlockGeneric(out, in1, in2, true)
}

View File

@@ -241,11 +241,11 @@ func (p *hashed) Hash() []byte {
n = 3 n = 3
} }
arr[n] = '$' arr[n] = '$'
n += 1 n++
copy(arr[n:], []byte(fmt.Sprintf("%02d", p.cost))) copy(arr[n:], []byte(fmt.Sprintf("%02d", p.cost)))
n += 2 n += 2
arr[n] = '$' arr[n] = '$'
n += 1 n++
copy(arr[n:], p.salt) copy(arr[n:], p.salt)
n += encodedSaltSize n += encodedSaltSize
copy(arr[n:], p.hash) copy(arr[n:], p.hash)

View File

@@ -39,7 +39,10 @@ var (
useSSE4 bool useSSE4 bool
) )
var errKeySize = errors.New("blake2b: invalid key size") var (
errKeySize = errors.New("blake2b: invalid key size")
errHashSize = errors.New("blake2b: invalid hash size")
)
var iv = [8]uint64{ var iv = [8]uint64{
0x6a09e667f3bcc908, 0xbb67ae8584caa73b, 0x3c6ef372fe94f82b, 0xa54ff53a5f1d36f1, 0x6a09e667f3bcc908, 0xbb67ae8584caa73b, 0x3c6ef372fe94f82b, 0xa54ff53a5f1d36f1,
@@ -83,7 +86,18 @@ func New384(key []byte) (hash.Hash, error) { return newDigest(Size384, key) }
// key turns the hash into a MAC. The key must between zero and 64 bytes long. // key turns the hash into a MAC. The key must between zero and 64 bytes long.
func New256(key []byte) (hash.Hash, error) { return newDigest(Size256, key) } func New256(key []byte) (hash.Hash, error) { return newDigest(Size256, key) }
// New returns a new hash.Hash computing the BLAKE2b checksum with a custom length.
// A non-nil key turns the hash into a MAC. The key must between zero and 64 bytes long.
// The hash size can be a value between 1 and 64 but it is highly recommended to use
// values equal or greater than:
// - 32 if BLAKE2b is used as a hash function (The key is zero bytes long).
// - 16 if BLAKE2b is used as a MAC function (The key is at least 16 bytes long).
func New(size int, key []byte) (hash.Hash, error) { return newDigest(size, key) }
func newDigest(hashSize int, key []byte) (*digest, error) { func newDigest(hashSize int, key []byte) (*digest, error) {
if hashSize < 1 || hashSize > Size {
return nil, errHashSize
}
if len(key) > Size { if len(key) > Size {
return nil, errKeySize return nil, errKeySize
} }

View File

@@ -185,7 +185,7 @@ func testHashes2X(t *testing.T) {
if n, err := h.Read(result[:]); err != nil { if n, err := h.Read(result[:]); err != nil {
t.Fatalf("#unknown length: error from Read: %v", err) t.Fatalf("#unknown length: error from Read: %v", err)
} else if n != len(result) { } else if n != len(result) {
t.Fatalf("#unknown length: Read returned %d bytes, want %d: %v", n, len(result)) t.Fatalf("#unknown length: Read returned %d bytes, want %d", n, len(result))
} }
const expected = "2a9a6977d915a2c4dd07dbcafe1918bf1682e56d9c8e567ecd19bfd7cd93528833c764d12b34a5e2a219c9fd463dab45e972c5574d73f45de5b2e23af72530d8" const expected = "2a9a6977d915a2c4dd07dbcafe1918bf1682e56d9c8e567ecd19bfd7cd93528833c764d12b34a5e2a219c9fd463dab45e972c5574d73f45de5b2e23af72530d8"

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package bn256 implements a particular bilinear group at the 128-bit security level. // Package bn256 implements a particular bilinear group.
// //
// Bilinear groups are the basis of many of the new cryptographic protocols // Bilinear groups are the basis of many of the new cryptographic protocols
// that have been proposed over the past decade. They consist of a triplet of // that have been proposed over the past decade. They consist of a triplet of
@@ -14,6 +14,10 @@
// Barreto-Naehrig curve as described in // Barreto-Naehrig curve as described in
// http://cryptojedi.org/papers/dclxvi-20100714.pdf. Its output is compatible // http://cryptojedi.org/papers/dclxvi-20100714.pdf. Its output is compatible
// with the implementation described in that paper. // with the implementation described in that paper.
//
// (This package previously claimed to operate at a 128-bit security level.
// However, recent improvements in attacks mean that is no longer true. See
// https://moderncrypto.org/mail-archive/curves/2016/000740.html.)
package bn256 // import "golang.org/x/crypto/bn256" package bn256 // import "golang.org/x/crypto/bn256"
import ( import (
@@ -49,8 +53,8 @@ func RandomG1(r io.Reader) (*big.Int, *G1, error) {
return k, new(G1).ScalarBaseMult(k), nil return k, new(G1).ScalarBaseMult(k), nil
} }
func (g *G1) String() string { func (e *G1) String() string {
return "bn256.G1" + g.p.String() return "bn256.G1" + e.p.String()
} }
// ScalarBaseMult sets e to g*k where g is the generator of the group and // ScalarBaseMult sets e to g*k where g is the generator of the group and
@@ -92,11 +96,11 @@ func (e *G1) Neg(a *G1) *G1 {
} }
// Marshal converts n to a byte slice. // Marshal converts n to a byte slice.
func (n *G1) Marshal() []byte { func (e *G1) Marshal() []byte {
n.p.MakeAffine(nil) e.p.MakeAffine(nil)
xBytes := new(big.Int).Mod(n.p.x, p).Bytes() xBytes := new(big.Int).Mod(e.p.x, p).Bytes()
yBytes := new(big.Int).Mod(n.p.y, p).Bytes() yBytes := new(big.Int).Mod(e.p.y, p).Bytes()
// Each value is a 256-bit number. // Each value is a 256-bit number.
const numBytes = 256 / 8 const numBytes = 256 / 8
@@ -166,8 +170,8 @@ func RandomG2(r io.Reader) (*big.Int, *G2, error) {
return k, new(G2).ScalarBaseMult(k), nil return k, new(G2).ScalarBaseMult(k), nil
} }
func (g *G2) String() string { func (e *G2) String() string {
return "bn256.G2" + g.p.String() return "bn256.G2" + e.p.String()
} }
// ScalarBaseMult sets e to g*k where g is the generator of the group and // ScalarBaseMult sets e to g*k where g is the generator of the group and

View File

@@ -7,7 +7,7 @@ package chacha20poly1305
import ( import (
"encoding/binary" "encoding/binary"
"golang.org/x/crypto/chacha20poly1305/internal/chacha20" "golang.org/x/crypto/internal/chacha20"
"golang.org/x/crypto/poly1305" "golang.org/x/crypto/poly1305"
) )

View File

@@ -47,7 +47,7 @@ func Sum(m []byte, key *[KeySize]byte) *[Size]byte {
// Verify checks that digest is a valid authenticator of message m under the // Verify checks that digest is a valid authenticator of message m under the
// given secret key. Verify does not leak timing information. // given secret key. Verify does not leak timing information.
func Verify(digest []byte, m []byte, key *[32]byte) bool { func Verify(digest []byte, m []byte, key *[KeySize]byte) bool {
if len(digest) != Size { if len(digest) != Size {
return false return false
} }

View File

@@ -760,7 +760,7 @@ func CreateResponse(issuer, responderCert *x509.Certificate, template Response,
} }
if template.Certificate != nil { if template.Certificate != nil {
response.Certificates = []asn1.RawValue{ response.Certificates = []asn1.RawValue{
asn1.RawValue{FullBytes: template.Certificate.Raw}, {FullBytes: template.Certificate.Raw},
} }
} }
responseDER, err := asn1.Marshal(response) responseDER, err := asn1.Marshal(response)

View File

@@ -43,11 +43,11 @@ func TestOCSPDecode(t *testing.T) {
} }
if !reflect.DeepEqual(resp.ThisUpdate, expected.ThisUpdate) { if !reflect.DeepEqual(resp.ThisUpdate, expected.ThisUpdate) {
t.Errorf("resp.ThisUpdate: got %d, want %d", resp.ThisUpdate, expected.ThisUpdate) t.Errorf("resp.ThisUpdate: got %v, want %v", resp.ThisUpdate, expected.ThisUpdate)
} }
if !reflect.DeepEqual(resp.NextUpdate, expected.NextUpdate) { if !reflect.DeepEqual(resp.NextUpdate, expected.NextUpdate) {
t.Errorf("resp.NextUpdate: got %d, want %d", resp.NextUpdate, expected.NextUpdate) t.Errorf("resp.NextUpdate: got %v, want %v", resp.NextUpdate, expected.NextUpdate)
} }
if resp.Status != expected.Status { if resp.Status != expected.Status {
@@ -218,7 +218,7 @@ func TestOCSPResponse(t *testing.T) {
extensionBytes, _ := hex.DecodeString(ocspExtensionValueHex) extensionBytes, _ := hex.DecodeString(ocspExtensionValueHex)
extensions := []pkix.Extension{ extensions := []pkix.Extension{
pkix.Extension{ {
Id: ocspExtensionOID, Id: ocspExtensionOID,
Critical: false, Critical: false,
Value: extensionBytes, Value: extensionBytes,
@@ -268,15 +268,15 @@ func TestOCSPResponse(t *testing.T) {
} }
if !reflect.DeepEqual(resp.ThisUpdate, template.ThisUpdate) { if !reflect.DeepEqual(resp.ThisUpdate, template.ThisUpdate) {
t.Errorf("resp.ThisUpdate: got %d, want %d", resp.ThisUpdate, template.ThisUpdate) t.Errorf("resp.ThisUpdate: got %v, want %v", resp.ThisUpdate, template.ThisUpdate)
} }
if !reflect.DeepEqual(resp.NextUpdate, template.NextUpdate) { if !reflect.DeepEqual(resp.NextUpdate, template.NextUpdate) {
t.Errorf("resp.NextUpdate: got %d, want %d", resp.NextUpdate, template.NextUpdate) t.Errorf("resp.NextUpdate: got %v, want %v", resp.NextUpdate, template.NextUpdate)
} }
if !reflect.DeepEqual(resp.RevokedAt, template.RevokedAt) { if !reflect.DeepEqual(resp.RevokedAt, template.RevokedAt) {
t.Errorf("resp.RevokedAt: got %d, want %d", resp.RevokedAt, template.RevokedAt) t.Errorf("resp.RevokedAt: got %v, want %v", resp.RevokedAt, template.RevokedAt)
} }
if !reflect.DeepEqual(resp.Extensions, template.ExtraExtensions) { if !reflect.DeepEqual(resp.Extensions, template.ExtraExtensions) {

View File

@@ -325,9 +325,8 @@ func ReadEntity(packets *packet.Reader) (*Entity, error) {
if e.PrivateKey, ok = p.(*packet.PrivateKey); !ok { if e.PrivateKey, ok = p.(*packet.PrivateKey); !ok {
packets.Unread(p) packets.Unread(p)
return nil, errors.StructuralError("first packet was not a public/private key") return nil, errors.StructuralError("first packet was not a public/private key")
} else {
e.PrimaryKey = &e.PrivateKey.PublicKey
} }
e.PrimaryKey = &e.PrivateKey.PublicKey
} }
if !e.PrimaryKey.PubKeyAlgo.CanSign() { if !e.PrimaryKey.PubKeyAlgo.CanSign() {

View File

@@ -155,3 +155,22 @@ func TestWithHMACSHA1(t *testing.T) {
func TestWithHMACSHA256(t *testing.T) { func TestWithHMACSHA256(t *testing.T) {
testHash(t, sha256.New, "SHA256", sha256TestVectors) testHash(t, sha256.New, "SHA256", sha256TestVectors)
} }
var sink uint8
func benchmark(b *testing.B, h func() hash.Hash) {
password := make([]byte, h().Size())
salt := make([]byte, 8)
for i := 0; i < b.N; i++ {
password = Key(password, salt, 4096, len(password), h)
}
sink += password[0]
}
func BenchmarkHMACSHA1(b *testing.B) {
benchmark(b, sha1.New)
}
func BenchmarkHMACSHA256(b *testing.B) {
benchmark(b, sha256.New)
}

View File

@@ -122,7 +122,6 @@ func (c *rc2Cipher) Encrypt(dst, src []byte) {
r3 = r3 + c.k[r2&63] r3 = r3 + c.k[r2&63]
for j <= 40 { for j <= 40 {
// mix r0 // mix r0
r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1) r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1)
r0 = rotl16(r0, 1) r0 = rotl16(r0, 1)
@@ -151,7 +150,6 @@ func (c *rc2Cipher) Encrypt(dst, src []byte) {
r3 = r3 + c.k[r2&63] r3 = r3 + c.k[r2&63]
for j <= 60 { for j <= 60 {
// mix r0 // mix r0
r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1) r0 = r0 + c.k[j] + (r3 & r2) + ((^r3) & r1)
r0 = rotl16(r0, 1) r0 = rotl16(r0, 1)
@@ -244,7 +242,6 @@ func (c *rc2Cipher) Decrypt(dst, src []byte) {
r0 = r0 - c.k[r3&63] r0 = r0 - c.k[r3&63]
for j >= 0 { for j >= 0 {
// unmix r3 // unmix r3
r3 = rotl16(r3, 16-5) r3 = rotl16(r3, 16-5)
r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0) r3 = r3 - c.k[j] - (r2 & r1) - ((^r2) & r0)

View File

@@ -11,7 +11,6 @@ import (
) )
func TestEncryptDecrypt(t *testing.T) { func TestEncryptDecrypt(t *testing.T) {
// TODO(dgryski): add the rest of the test vectors from the RFC // TODO(dgryski): add the rest of the test vectors from the RFC
var tests = []struct { var tests = []struct {
key string key string

View File

@@ -202,7 +202,7 @@ func TestSqueezing(t *testing.T) {
d1 := newShakeHash() d1 := newShakeHash()
d1.Write([]byte(testString)) d1.Write([]byte(testString))
var multiple []byte var multiple []byte
for _ = range ref { for range ref {
one := make([]byte, 1) one := make([]byte, 1)
d1.Read(one) d1.Read(one)
multiple = append(multiple, one...) multiple = append(multiple, one...)

View File

@@ -98,7 +98,7 @@ const (
agentAddIdentity = 17 agentAddIdentity = 17
agentRemoveIdentity = 18 agentRemoveIdentity = 18
agentRemoveAllIdentities = 19 agentRemoveAllIdentities = 19
agentAddIdConstrained = 25 agentAddIDConstrained = 25
// 3.3 Key-type independent requests from client to agent // 3.3 Key-type independent requests from client to agent
agentAddSmartcardKey = 20 agentAddSmartcardKey = 20
@@ -515,7 +515,7 @@ func (c *client) insertKey(s interface{}, comment string, constraints []byte) er
// if constraints are present then the message type needs to be changed. // if constraints are present then the message type needs to be changed.
if len(constraints) != 0 { if len(constraints) != 0 {
req[0] = agentAddIdConstrained req[0] = agentAddIDConstrained
} }
resp, err := c.call(req) resp, err := c.call(req)
@@ -577,11 +577,11 @@ func (c *client) Add(key AddedKey) error {
constraints = append(constraints, agentConstrainConfirm) constraints = append(constraints, agentConstrainConfirm)
} }
if cert := key.Certificate; cert == nil { cert := key.Certificate
if cert == nil {
return c.insertKey(key.PrivateKey, key.Comment, constraints) return c.insertKey(key.PrivateKey, key.Comment, constraints)
} else {
return c.insertCert(key.PrivateKey, cert, key.Comment, constraints)
} }
return c.insertCert(key.PrivateKey, cert, key.Comment, constraints)
} }
func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error { func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error {
@@ -633,7 +633,7 @@ func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string
// if constraints are present then the message type needs to be changed. // if constraints are present then the message type needs to be changed.
if len(constraints) != 0 { if len(constraints) != 0 {
req[0] = agentAddIdConstrained req[0] = agentAddIDConstrained
} }
signer, err := ssh.NewSignerFromKey(s) signer, err := ssh.NewSignerFromKey(s)

View File

@@ -148,7 +148,7 @@ func (s *server) processRequest(data []byte) (interface{}, error) {
} }
return rep, nil return rep, nil
case agentAddIdConstrained, agentAddIdentity: case agentAddIDConstrained, agentAddIdentity:
return nil, s.insertIdentity(data) return nil, s.insertIdentity(data)
} }

View File

@@ -40,7 +40,8 @@ func sshPipe() (Conn, *server, error) {
} }
clientConf := ClientConfig{ clientConf := ClientConfig{
User: "user", User: "user",
HostKeyCallback: InsecureIgnoreHostKey(),
} }
serverConf := ServerConfig{ serverConf := ServerConfig{
NoClientAuth: true, NoClientAuth: true,

View File

@@ -340,10 +340,10 @@ func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permis
// the signature of the certificate. // the signature of the certificate.
func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
if c.IsRevoked != nil && c.IsRevoked(cert) { if c.IsRevoked != nil && c.IsRevoked(cert) {
return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial) return fmt.Errorf("ssh: certificate serial %d revoked", cert.Serial)
} }
for opt, _ := range cert.CriticalOptions { for opt := range cert.CriticalOptions {
// sourceAddressCriticalOption will be enforced by // sourceAddressCriticalOption will be enforced by
// serverAuthenticate // serverAuthenticate
if opt == sourceAddressCriticalOption { if opt == sourceAddressCriticalOption {

View File

@@ -6,10 +6,15 @@ package ssh
import ( import (
"bytes" "bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand" "crypto/rand"
"net"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"golang.org/x/crypto/ssh/testdata"
) )
// Cert generated by ssh-keygen 6.0p1 Debian-4. // Cert generated by ssh-keygen 6.0p1 Debian-4.
@@ -220,3 +225,111 @@ func TestHostKeyCert(t *testing.T) {
} }
} }
} }
func TestCertTypes(t *testing.T) {
var testVars = []struct {
name string
keys func() Signer
}{
{
name: CertAlgoECDSA256v01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["ecdsap256"])
return s
},
},
{
name: CertAlgoECDSA384v01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["ecdsap384"])
return s
},
},
{
name: CertAlgoECDSA521v01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["ecdsap521"])
return s
},
},
{
name: CertAlgoED25519v01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["ed25519"])
return s
},
},
{
name: CertAlgoRSAv01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["rsa"])
return s
},
},
{
name: CertAlgoDSAv01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["dsa"])
return s
},
},
}
k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("error generating host key: %v", err)
}
signer, err := NewSignerFromKey(k)
if err != nil {
t.Fatalf("error generating signer for ssh listener: %v", err)
}
conf := &ServerConfig{
PublicKeyCallback: func(c ConnMetadata, k PublicKey) (*Permissions, error) {
return new(Permissions), nil
},
}
conf.AddHostKey(signer)
for _, m := range testVars {
t.Run(m.name, func(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewServerConn(c1, conf)
priv := m.keys()
if err != nil {
t.Fatalf("error generating ssh pubkey: %v", err)
}
cert := &Certificate{
CertType: UserCert,
Key: priv.PublicKey(),
}
cert.SignCert(rand.Reader, priv)
certSigner, err := NewCertSigner(cert, priv)
if err != nil {
t.Fatalf("error generating cert signer: %v", err)
}
config := &ClientConfig{
User: "user",
HostKeyCallback: func(h string, r net.Addr, k PublicKey) error { return nil },
Auth: []AuthMethod{PublicKeys(certSigner)},
}
_, _, _, err = NewClientConn(c2, "", config)
if err != nil {
t.Fatalf("error connecting: %v", err)
}
})
}
}

View File

@@ -205,32 +205,32 @@ type channel struct {
// writePacket sends a packet. If the packet is a channel close, it updates // writePacket sends a packet. If the packet is a channel close, it updates
// sentClose. This method takes the lock c.writeMu. // sentClose. This method takes the lock c.writeMu.
func (c *channel) writePacket(packet []byte) error { func (ch *channel) writePacket(packet []byte) error {
c.writeMu.Lock() ch.writeMu.Lock()
if c.sentClose { if ch.sentClose {
c.writeMu.Unlock() ch.writeMu.Unlock()
return io.EOF return io.EOF
} }
c.sentClose = (packet[0] == msgChannelClose) ch.sentClose = (packet[0] == msgChannelClose)
err := c.mux.conn.writePacket(packet) err := ch.mux.conn.writePacket(packet)
c.writeMu.Unlock() ch.writeMu.Unlock()
return err return err
} }
func (c *channel) sendMessage(msg interface{}) error { func (ch *channel) sendMessage(msg interface{}) error {
if debugMux { if debugMux {
log.Printf("send(%d): %#v", c.mux.chanList.offset, msg) log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
} }
p := Marshal(msg) p := Marshal(msg)
binary.BigEndian.PutUint32(p[1:], c.remoteId) binary.BigEndian.PutUint32(p[1:], ch.remoteId)
return c.writePacket(p) return ch.writePacket(p)
} }
// WriteExtended writes data to a specific extended stream. These streams are // WriteExtended writes data to a specific extended stream. These streams are
// used, for example, for stderr. // used, for example, for stderr.
func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
if c.sentEOF { if ch.sentEOF {
return 0, io.EOF return 0, io.EOF
} }
// 1 byte message type, 4 bytes remoteId, 4 bytes data length // 1 byte message type, 4 bytes remoteId, 4 bytes data length
@@ -241,16 +241,16 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
opCode = msgChannelExtendedData opCode = msgChannelExtendedData
} }
c.writeMu.Lock() ch.writeMu.Lock()
packet := c.packetPool[extendedCode] packet := ch.packetPool[extendedCode]
// We don't remove the buffer from packetPool, so // We don't remove the buffer from packetPool, so
// WriteExtended calls from different goroutines will be // WriteExtended calls from different goroutines will be
// flagged as errors by the race detector. // flagged as errors by the race detector.
c.writeMu.Unlock() ch.writeMu.Unlock()
for len(data) > 0 { for len(data) > 0 {
space := min(c.maxRemotePayload, len(data)) space := min(ch.maxRemotePayload, len(data))
if space, err = c.remoteWin.reserve(space); err != nil { if space, err = ch.remoteWin.reserve(space); err != nil {
return n, err return n, err
} }
if want := headerLength + space; uint32(cap(packet)) < want { if want := headerLength + space; uint32(cap(packet)) < want {
@@ -262,13 +262,13 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
todo := data[:space] todo := data[:space]
packet[0] = opCode packet[0] = opCode
binary.BigEndian.PutUint32(packet[1:], c.remoteId) binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
if extendedCode > 0 { if extendedCode > 0 {
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
} }
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
copy(packet[headerLength:], todo) copy(packet[headerLength:], todo)
if err = c.writePacket(packet); err != nil { if err = ch.writePacket(packet); err != nil {
return n, err return n, err
} }
@@ -276,14 +276,14 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
data = data[len(todo):] data = data[len(todo):]
} }
c.writeMu.Lock() ch.writeMu.Lock()
c.packetPool[extendedCode] = packet ch.packetPool[extendedCode] = packet
c.writeMu.Unlock() ch.writeMu.Unlock()
return n, err return n, err
} }
func (c *channel) handleData(packet []byte) error { func (ch *channel) handleData(packet []byte) error {
headerLen := 9 headerLen := 9
isExtendedData := packet[0] == msgChannelExtendedData isExtendedData := packet[0] == msgChannelExtendedData
if isExtendedData { if isExtendedData {
@@ -303,7 +303,7 @@ func (c *channel) handleData(packet []byte) error {
if length == 0 { if length == 0 {
return nil return nil
} }
if length > c.maxIncomingPayload { if length > ch.maxIncomingPayload {
// TODO(hanwen): should send Disconnect? // TODO(hanwen): should send Disconnect?
return errors.New("ssh: incoming packet exceeds maximum payload size") return errors.New("ssh: incoming packet exceeds maximum payload size")
} }
@@ -313,21 +313,21 @@ func (c *channel) handleData(packet []byte) error {
return errors.New("ssh: wrong packet length") return errors.New("ssh: wrong packet length")
} }
c.windowMu.Lock() ch.windowMu.Lock()
if c.myWindow < length { if ch.myWindow < length {
c.windowMu.Unlock() ch.windowMu.Unlock()
// TODO(hanwen): should send Disconnect with reason? // TODO(hanwen): should send Disconnect with reason?
return errors.New("ssh: remote side wrote too much") return errors.New("ssh: remote side wrote too much")
} }
c.myWindow -= length ch.myWindow -= length
c.windowMu.Unlock() ch.windowMu.Unlock()
if extended == 1 { if extended == 1 {
c.extPending.write(data) ch.extPending.write(data)
} else if extended > 0 { } else if extended > 0 {
// discard other extended data. // discard other extended data.
} else { } else {
c.pending.write(data) ch.pending.write(data)
} }
return nil return nil
} }
@@ -384,31 +384,31 @@ func (c *channel) close() {
// responseMessageReceived is called when a success or failure message is // responseMessageReceived is called when a success or failure message is
// received on a channel to check that such a message is reasonable for the // received on a channel to check that such a message is reasonable for the
// given channel. // given channel.
func (c *channel) responseMessageReceived() error { func (ch *channel) responseMessageReceived() error {
if c.direction == channelInbound { if ch.direction == channelInbound {
return errors.New("ssh: channel response message received on inbound channel") return errors.New("ssh: channel response message received on inbound channel")
} }
if c.decided { if ch.decided {
return errors.New("ssh: duplicate response received for channel") return errors.New("ssh: duplicate response received for channel")
} }
c.decided = true ch.decided = true
return nil return nil
} }
func (c *channel) handlePacket(packet []byte) error { func (ch *channel) handlePacket(packet []byte) error {
switch packet[0] { switch packet[0] {
case msgChannelData, msgChannelExtendedData: case msgChannelData, msgChannelExtendedData:
return c.handleData(packet) return ch.handleData(packet)
case msgChannelClose: case msgChannelClose:
c.sendMessage(channelCloseMsg{PeersId: c.remoteId}) ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
c.mux.chanList.remove(c.localId) ch.mux.chanList.remove(ch.localId)
c.close() ch.close()
return nil return nil
case msgChannelEOF: case msgChannelEOF:
// RFC 4254 is mute on how EOF affects dataExt messages but // RFC 4254 is mute on how EOF affects dataExt messages but
// it is logical to signal EOF at the same time. // it is logical to signal EOF at the same time.
c.extPending.eof() ch.extPending.eof()
c.pending.eof() ch.pending.eof()
return nil return nil
} }
@@ -419,24 +419,24 @@ func (c *channel) handlePacket(packet []byte) error {
switch msg := decoded.(type) { switch msg := decoded.(type) {
case *channelOpenFailureMsg: case *channelOpenFailureMsg:
if err := c.responseMessageReceived(); err != nil { if err := ch.responseMessageReceived(); err != nil {
return err return err
} }
c.mux.chanList.remove(msg.PeersId) ch.mux.chanList.remove(msg.PeersID)
c.msg <- msg ch.msg <- msg
case *channelOpenConfirmMsg: case *channelOpenConfirmMsg:
if err := c.responseMessageReceived(); err != nil { if err := ch.responseMessageReceived(); err != nil {
return err return err
} }
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
} }
c.remoteId = msg.MyId ch.remoteId = msg.MyID
c.maxRemotePayload = msg.MaxPacketSize ch.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.MyWindow) ch.remoteWin.add(msg.MyWindow)
c.msg <- msg ch.msg <- msg
case *windowAdjustMsg: case *windowAdjustMsg:
if !c.remoteWin.add(msg.AdditionalBytes) { if !ch.remoteWin.add(msg.AdditionalBytes) {
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
} }
case *channelRequestMsg: case *channelRequestMsg:
@@ -444,12 +444,12 @@ func (c *channel) handlePacket(packet []byte) error {
Type: msg.Request, Type: msg.Request,
WantReply: msg.WantReply, WantReply: msg.WantReply,
Payload: msg.RequestSpecificData, Payload: msg.RequestSpecificData,
ch: c, ch: ch,
} }
c.incomingRequests <- &req ch.incomingRequests <- &req
default: default:
c.msg <- msg ch.msg <- msg
} }
return nil return nil
} }
@@ -488,23 +488,23 @@ func (e *extChannel) Read(data []byte) (n int, err error) {
return e.ch.ReadExtended(data, e.code) return e.ch.ReadExtended(data, e.code)
} }
func (c *channel) Accept() (Channel, <-chan *Request, error) { func (ch *channel) Accept() (Channel, <-chan *Request, error) {
if c.decided { if ch.decided {
return nil, nil, errDecidedAlready return nil, nil, errDecidedAlready
} }
c.maxIncomingPayload = channelMaxPacket ch.maxIncomingPayload = channelMaxPacket
confirm := channelOpenConfirmMsg{ confirm := channelOpenConfirmMsg{
PeersId: c.remoteId, PeersID: ch.remoteId,
MyId: c.localId, MyID: ch.localId,
MyWindow: c.myWindow, MyWindow: ch.myWindow,
MaxPacketSize: c.maxIncomingPayload, MaxPacketSize: ch.maxIncomingPayload,
} }
c.decided = true ch.decided = true
if err := c.sendMessage(confirm); err != nil { if err := ch.sendMessage(confirm); err != nil {
return nil, nil, err return nil, nil, err
} }
return c, c.incomingRequests, nil return ch, ch.incomingRequests, nil
} }
func (ch *channel) Reject(reason RejectionReason, message string) error { func (ch *channel) Reject(reason RejectionReason, message string) error {
@@ -512,7 +512,7 @@ func (ch *channel) Reject(reason RejectionReason, message string) error {
return errDecidedAlready return errDecidedAlready
} }
reject := channelOpenFailureMsg{ reject := channelOpenFailureMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
Reason: reason, Reason: reason,
Message: message, Message: message,
Language: "en", Language: "en",
@@ -541,7 +541,7 @@ func (ch *channel) CloseWrite() error {
} }
ch.sentEOF = true ch.sentEOF = true
return ch.sendMessage(channelEOFMsg{ return ch.sendMessage(channelEOFMsg{
PeersId: ch.remoteId}) PeersID: ch.remoteId})
} }
func (ch *channel) Close() error { func (ch *channel) Close() error {
@@ -550,7 +550,7 @@ func (ch *channel) Close() error {
} }
return ch.sendMessage(channelCloseMsg{ return ch.sendMessage(channelCloseMsg{
PeersId: ch.remoteId}) PeersID: ch.remoteId})
} }
// Extended returns an io.ReadWriter that sends and receives data on the given, // Extended returns an io.ReadWriter that sends and receives data on the given,
@@ -577,7 +577,7 @@ func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (boo
} }
msg := channelRequestMsg{ msg := channelRequestMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
Request: name, Request: name,
WantReply: wantReply, WantReply: wantReply,
RequestSpecificData: payload, RequestSpecificData: payload,
@@ -614,11 +614,11 @@ func (ch *channel) ackRequest(ok bool) error {
var msg interface{} var msg interface{}
if !ok { if !ok {
msg = channelRequestFailureMsg{ msg = channelRequestFailureMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
} }
} else { } else {
msg = channelRequestSuccessMsg{ msg = channelRequestSuccessMsg{
PeersId: ch.remoteId, PeersID: ch.remoteId,
} }
} }
return ch.sendMessage(msg) return ch.sendMessage(msg)

View File

@@ -304,7 +304,7 @@ type gcmCipher struct {
buf []byte buf []byte
} }
func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) { func newGCMCipher(iv, key []byte) (packetCipher, error) {
c, err := aes.NewCipher(key) c, err := aes.NewCipher(key)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -372,7 +372,7 @@ func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
} }
length := binary.BigEndian.Uint32(c.prefix[:]) length := binary.BigEndian.Uint32(c.prefix[:])
if length > maxPacket { if length > maxPacket {
return nil, errors.New("ssh: max packet length exceeded.") return nil, errors.New("ssh: max packet length exceeded")
} }
if cap(c.buf) < int(length+gcmTagSize) { if cap(c.buf) < int(length+gcmTagSize) {
@@ -548,11 +548,11 @@ func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error)
c.packetData = c.packetData[:entirePacketSize] c.packetData = c.packetData[:entirePacketSize]
} }
if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil { n, err := io.ReadFull(r, c.packetData[firstBlockLength:])
if err != nil {
return nil, err return nil, err
} else {
c.oracleCamouflage -= uint32(n)
} }
c.oracleCamouflage -= uint32(n)
remainingCrypted := c.packetData[firstBlockLength:macStart] remainingCrypted := c.packetData[firstBlockLength:macStart]
c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted)

View File

@@ -21,47 +21,48 @@ func TestDefaultCiphersExist(t *testing.T) {
} }
func TestPacketCiphers(t *testing.T) { func TestPacketCiphers(t *testing.T) {
// Still test aes128cbc cipher although it's commented out. defaultMac := "hmac-sha2-256"
cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} defaultCipher := "aes128-ctr"
defer delete(cipherModes, aes128cbcID)
for cipher := range cipherModes { for cipher := range cipherModes {
for mac := range macModes { t.Run("cipher="+cipher,
kr := &kexResult{Hash: crypto.SHA1} func(t *testing.T) { testPacketCipher(t, cipher, defaultMac) })
algs := directionAlgorithms{ }
Cipher: cipher, for mac := range macModes {
MAC: mac, t.Run("mac="+mac,
Compression: "none", func(t *testing.T) { testPacketCipher(t, defaultCipher, mac) })
} }
client, err := newPacketCipher(clientKeys, algs, kr) }
if err != nil {
t.Errorf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
continue
}
server, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Errorf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
continue
}
want := "bla bla" func testPacketCipher(t *testing.T, cipher, mac string) {
input := []byte(want) kr := &kexResult{Hash: crypto.SHA1}
buf := &bytes.Buffer{} algs := directionAlgorithms{
if err := client.writePacket(0, buf, rand.Reader, input); err != nil { Cipher: cipher,
t.Errorf("writePacket(%q, %q): %v", cipher, mac, err) MAC: mac,
continue Compression: "none",
} }
client, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
}
server, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
}
packet, err := server.readPacket(0, buf) want := "bla bla"
if err != nil { input := []byte(want)
t.Errorf("readPacket(%q, %q): %v", cipher, mac, err) buf := &bytes.Buffer{}
continue if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
} t.Fatalf("writePacket(%q, %q): %v", cipher, mac, err)
}
if string(packet) != want { packet, err := server.readPacket(0, buf)
t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want) if err != nil {
} t.Fatalf("readPacket(%q, %q): %v", cipher, mac, err)
} }
if string(packet) != want {
t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want)
} }
} }

View File

@@ -9,6 +9,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"os"
"sync" "sync"
"time" "time"
) )
@@ -187,6 +188,10 @@ func Dial(network, addr string, config *ClientConfig) (*Client, error) {
// net.Conn underlying the the SSH connection. // net.Conn underlying the the SSH connection.
type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
// BannerCallback is the function type used for treat the banner sent by
// the server. A BannerCallback receives the message sent by the remote server.
type BannerCallback func(message string) error
// A ClientConfig structure is used to configure a Client. It must not be // A ClientConfig structure is used to configure a Client. It must not be
// modified after having been passed to an SSH function. // modified after having been passed to an SSH function.
type ClientConfig struct { type ClientConfig struct {
@@ -209,6 +214,12 @@ type ClientConfig struct {
// FixedHostKey can be used for simplistic host key checks. // FixedHostKey can be used for simplistic host key checks.
HostKeyCallback HostKeyCallback HostKeyCallback HostKeyCallback
// BannerCallback is called during the SSH dance to display a custom
// server's message. The client configuration can supply this callback to
// handle it as wished. The function BannerDisplayStderr can be used for
// simplistic display on Stderr.
BannerCallback BannerCallback
// ClientVersion contains the version identification string that will // ClientVersion contains the version identification string that will
// be used for the connection. If empty, a reasonable default is used. // be used for the connection. If empty, a reasonable default is used.
ClientVersion string ClientVersion string
@@ -255,3 +266,13 @@ func FixedHostKey(key PublicKey) HostKeyCallback {
hk := &fixedHostKey{key} hk := &fixedHostKey{key}
return hk.check return hk.check
} }
// BannerDisplayStderr returns a function that can be used for
// ClientConfig.BannerCallback to display banners on os.Stderr.
func BannerDisplayStderr() BannerCallback {
return func(banner string) error {
_, err := os.Stderr.WriteString(banner)
return err
}
}

View File

@@ -283,7 +283,9 @@ func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
} }
switch packet[0] { switch packet[0] {
case msgUserAuthBanner: case msgUserAuthBanner:
// TODO(gpaul): add callback to present the banner to the user if err := handleBannerResponse(c, packet); err != nil {
return false, err
}
case msgUserAuthPubKeyOk: case msgUserAuthPubKeyOk:
var msg userAuthPubKeyOkMsg var msg userAuthPubKeyOkMsg
if err := Unmarshal(packet, &msg); err != nil { if err := Unmarshal(packet, &msg); err != nil {
@@ -325,7 +327,9 @@ func handleAuthResponse(c packetConn) (bool, []string, error) {
switch packet[0] { switch packet[0] {
case msgUserAuthBanner: case msgUserAuthBanner:
// TODO: add callback to present the banner to the user if err := handleBannerResponse(c, packet); err != nil {
return false, nil, err
}
case msgUserAuthFailure: case msgUserAuthFailure:
var msg userAuthFailureMsg var msg userAuthFailureMsg
if err := Unmarshal(packet, &msg); err != nil { if err := Unmarshal(packet, &msg); err != nil {
@@ -340,6 +344,24 @@ func handleAuthResponse(c packetConn) (bool, []string, error) {
} }
} }
func handleBannerResponse(c packetConn, packet []byte) error {
var msg userAuthBannerMsg
if err := Unmarshal(packet, &msg); err != nil {
return err
}
transport, ok := c.(*handshakeTransport)
if !ok {
return nil
}
if transport.bannerCallback != nil {
return transport.bannerCallback(msg.Message)
}
return nil
}
// KeyboardInteractiveChallenge should print questions, optionally // KeyboardInteractiveChallenge should print questions, optionally
// disabling echoing (e.g. for passwords), and return all the answers. // disabling echoing (e.g. for passwords), and return all the answers.
// Challenge may be called multiple times in a single session. After // Challenge may be called multiple times in a single session. After
@@ -385,7 +407,9 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
// like handleAuthResponse, but with less options. // like handleAuthResponse, but with less options.
switch packet[0] { switch packet[0] {
case msgUserAuthBanner: case msgUserAuthBanner:
// TODO: Print banners during userauth. if err := handleBannerResponse(c, packet); err != nil {
return false, nil, err
}
continue continue
case msgUserAuthInfoRequest: case msgUserAuthInfoRequest:
// OK // OK

View File

@@ -5,41 +5,77 @@
package ssh package ssh
import ( import (
"net"
"strings" "strings"
"testing" "testing"
) )
func testClientVersion(t *testing.T, config *ClientConfig, expected string) { func TestClientVersion(t *testing.T) {
clientConn, serverConn := net.Pipe() for _, tt := range []struct {
defer clientConn.Close() name string
receivedVersion := make(chan string, 1) version string
config.HostKeyCallback = InsecureIgnoreHostKey() multiLine string
go func() { wantErr bool
version, err := readVersion(serverConn) }{
if err != nil { {
receivedVersion <- "" name: "default version",
} else { version: packageVersion,
receivedVersion <- string(version) },
} {
serverConn.Close() name: "custom version",
}() version: "SSH-2.0-CustomClientVersionString",
NewClientConn(clientConn, "", config) },
actual := <-receivedVersion {
if actual != expected { name: "good multi line version",
t.Fatalf("got %s; want %s", actual, expected) version: packageVersion,
multiLine: strings.Repeat("ignored\r\n", 20),
},
{
name: "bad multi line version",
version: packageVersion,
multiLine: "bad multi line version",
wantErr: true,
},
{
name: "long multi line version",
version: packageVersion,
multiLine: strings.Repeat("long multi line version\r\n", 50)[:256],
wantErr: true,
},
} {
t.Run(tt.name, func(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go func() {
if tt.multiLine != "" {
c1.Write([]byte(tt.multiLine))
}
NewClientConn(c1, "", &ClientConfig{
ClientVersion: tt.version,
HostKeyCallback: InsecureIgnoreHostKey(),
})
c1.Close()
}()
conf := &ServerConfig{NoClientAuth: true}
conf.AddHostKey(testSigners["rsa"])
conn, _, _, err := NewServerConn(c2, conf)
if err == nil == tt.wantErr {
t.Fatalf("got err %v; wantErr %t", err, tt.wantErr)
}
if tt.wantErr {
// Don't verify the version on an expected error.
return
}
if got := string(conn.ClientVersion()); got != tt.version {
t.Fatalf("got %q; want %q", got, tt.version)
}
})
} }
} }
func TestCustomClientVersion(t *testing.T) {
version := "Test-Client-Version-0.0"
testClientVersion(t, &ClientConfig{ClientVersion: version}, version)
}
func TestDefaultClientVersion(t *testing.T) {
testClientVersion(t, &ClientConfig{}, packageVersion)
}
func TestHostKeyCheck(t *testing.T) { func TestHostKeyCheck(t *testing.T) {
for _, tt := range []struct { for _, tt := range []struct {
name string name string
@@ -79,3 +115,52 @@ func TestHostKeyCheck(t *testing.T) {
} }
} }
} }
func TestBannerCallback(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
serverConf := &ServerConfig{
PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
return &Permissions{}, nil
},
BannerCallback: func(conn ConnMetadata) string {
return "Hello World"
},
}
serverConf.AddHostKey(testSigners["rsa"])
go NewServerConn(c1, serverConf)
var receivedBanner string
var bannerCount int
clientConf := ClientConfig{
Auth: []AuthMethod{
Password("123"),
},
User: "user",
HostKeyCallback: InsecureIgnoreHostKey(),
BannerCallback: func(message string) error {
bannerCount++
receivedBanner = message
return nil
},
}
_, _, _, err = NewClientConn(c2, "", &clientConf)
if err != nil {
t.Fatal(err)
}
if bannerCount != 1 {
t.Errorf("got %d banners; want 1", bannerCount)
}
expected := "Hello World"
if receivedBanner != expected {
t.Fatalf("got %s; want %s", receivedBanner, expected)
}
}

View File

@@ -242,7 +242,7 @@ func (c *Config) SetDefaults() {
// buildDataSignedForAuth returns the data that is signed in order to prove // buildDataSignedForAuth returns the data that is signed in order to prove
// possession of a private key. See RFC 4252, section 7. // possession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
data := struct { data := struct {
Session []byte Session []byte
Type byte Type byte
@@ -253,7 +253,7 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK
Algo []byte Algo []byte
PubKey []byte PubKey []byte
}{ }{
sessionId, sessionID,
msgUserAuthRequest, msgUserAuthRequest,
req.User, req.User,
req.Service, req.Service,

View File

@@ -78,6 +78,11 @@ type handshakeTransport struct {
dialAddress string dialAddress string
remoteAddr net.Addr remoteAddr net.Addr
// bannerCallback is non-empty if we are the client and it has been set in
// ClientConfig. In that case it is called during the user authentication
// dance to handle a custom server's message.
bannerCallback BannerCallback
// Algorithms agreed in the last key exchange. // Algorithms agreed in the last key exchange.
algorithms *algorithms algorithms *algorithms
@@ -120,6 +125,7 @@ func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byt
t.dialAddress = dialAddr t.dialAddress = dialAddr
t.remoteAddr = addr t.remoteAddr = addr
t.hostKeyCallback = config.HostKeyCallback t.hostKeyCallback = config.HostKeyCallback
t.bannerCallback = config.BannerCallback
if config.HostKeyAlgorithms != nil { if config.HostKeyAlgorithms != nil {
t.hostKeyAlgorithms = config.HostKeyAlgorithms t.hostKeyAlgorithms = config.HostKeyAlgorithms
} else { } else {

View File

@@ -119,7 +119,7 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
return nil, err return nil, err
} }
kInt, err := group.diffieHellman(kexDHReply.Y, x) ki, err := group.diffieHellman(kexDHReply.Y, x)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -129,8 +129,8 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
writeString(h, kexDHReply.HostKey) writeString(h, kexDHReply.HostKey)
writeInt(h, X) writeInt(h, X)
writeInt(h, kexDHReply.Y) writeInt(h, kexDHReply.Y)
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
return &kexResult{ return &kexResult{
@@ -164,7 +164,7 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
} }
Y := new(big.Int).Exp(group.g, y, group.p) Y := new(big.Int).Exp(group.g, y, group.p)
kInt, err := group.diffieHellman(kexDHInit.X, y) ki, err := group.diffieHellman(kexDHInit.X, y)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -177,8 +177,8 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
writeInt(h, kexDHInit.X) writeInt(h, kexDHInit.X)
writeInt(h, Y) writeInt(h, Y)
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
H := h.Sum(nil) H := h.Sum(nil)
@@ -462,9 +462,9 @@ func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handsh
writeString(h, kp.pub[:]) writeString(h, kp.pub[:])
writeString(h, reply.EphemeralPubKey) writeString(h, reply.EphemeralPubKey)
kInt := new(big.Int).SetBytes(secret[:]) ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
return &kexResult{ return &kexResult{
@@ -510,9 +510,9 @@ func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handsh
writeString(h, kexInit.ClientPubKey) writeString(h, kexInit.ClientPubKey)
writeString(h, kp.pub[:]) writeString(h, kp.pub[:])
kInt := new(big.Int).SetBytes(secret[:]) ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(kInt)) K := make([]byte, intLength(ki))
marshalInt(K, kInt) marshalInt(K, ki)
h.Write(K) h.Write(K)
H := h.Sum(nil) H := h.Sum(nil)

View File

@@ -363,7 +363,7 @@ func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
type dsaPublicKey dsa.PublicKey type dsaPublicKey dsa.PublicKey
func (r *dsaPublicKey) Type() string { func (k *dsaPublicKey) Type() string {
return "ssh-dss" return "ssh-dss"
} }
@@ -481,12 +481,12 @@ func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
type ecdsaPublicKey ecdsa.PublicKey type ecdsaPublicKey ecdsa.PublicKey
func (key *ecdsaPublicKey) Type() string { func (k *ecdsaPublicKey) Type() string {
return "ecdsa-sha2-" + key.nistID() return "ecdsa-sha2-" + k.nistID()
} }
func (key *ecdsaPublicKey) nistID() string { func (k *ecdsaPublicKey) nistID() string {
switch key.Params().BitSize { switch k.Params().BitSize {
case 256: case 256:
return "nistp256" return "nistp256"
case 384: case 384:
@@ -499,7 +499,7 @@ func (key *ecdsaPublicKey) nistID() string {
type ed25519PublicKey ed25519.PublicKey type ed25519PublicKey ed25519.PublicKey
func (key ed25519PublicKey) Type() string { func (k ed25519PublicKey) Type() string {
return KeyAlgoED25519 return KeyAlgoED25519
} }
@@ -518,23 +518,23 @@ func parseED25519(in []byte) (out PublicKey, rest []byte, err error) {
return (ed25519PublicKey)(key), w.Rest, nil return (ed25519PublicKey)(key), w.Rest, nil
} }
func (key ed25519PublicKey) Marshal() []byte { func (k ed25519PublicKey) Marshal() []byte {
w := struct { w := struct {
Name string Name string
KeyBytes []byte KeyBytes []byte
}{ }{
KeyAlgoED25519, KeyAlgoED25519,
[]byte(key), []byte(k),
} }
return Marshal(&w) return Marshal(&w)
} }
func (key ed25519PublicKey) Verify(b []byte, sig *Signature) error { func (k ed25519PublicKey) Verify(b []byte, sig *Signature) error {
if sig.Format != key.Type() { if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
} }
edKey := (ed25519.PublicKey)(key) edKey := (ed25519.PublicKey)(k)
if ok := ed25519.Verify(edKey, b, sig.Blob); !ok { if ok := ed25519.Verify(edKey, b, sig.Blob); !ok {
return errors.New("ssh: signature did not verify") return errors.New("ssh: signature did not verify")
} }
@@ -595,9 +595,9 @@ func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
return (*ecdsaPublicKey)(key), w.Rest, nil return (*ecdsaPublicKey)(key), w.Rest, nil
} }
func (key *ecdsaPublicKey) Marshal() []byte { func (k *ecdsaPublicKey) Marshal() []byte {
// See RFC 5656, section 3.1. // See RFC 5656, section 3.1.
keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y) keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y)
// ECDSA publickey struct layout should match the struct used by // ECDSA publickey struct layout should match the struct used by
// parseECDSACert in the x/crypto/ssh/agent package. // parseECDSACert in the x/crypto/ssh/agent package.
w := struct { w := struct {
@@ -605,20 +605,20 @@ func (key *ecdsaPublicKey) Marshal() []byte {
ID string ID string
Key []byte Key []byte
}{ }{
key.Type(), k.Type(),
key.nistID(), k.nistID(),
keyBytes, keyBytes,
} }
return Marshal(&w) return Marshal(&w)
} }
func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { func (k *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != key.Type() { if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
} }
h := ecHash(key.Curve).New() h := ecHash(k.Curve).New()
h.Write(data) h.Write(data)
digest := h.Sum(nil) digest := h.Sum(nil)
@@ -635,7 +635,7 @@ func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
return err return err
} }
if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) { if ecdsa.Verify((*ecdsa.PublicKey)(k), digest, ecSig.R, ecSig.S) {
return nil return nil
} }
return errors.New("ssh: signature did not verify") return errors.New("ssh: signature did not verify")
@@ -758,7 +758,7 @@ func NewPublicKey(key interface{}) (PublicKey, error) {
return (*rsaPublicKey)(key), nil return (*rsaPublicKey)(key), nil
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
if !supportedEllipticCurve(key.Curve) { if !supportedEllipticCurve(key.Curve) {
return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported.") return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported")
} }
return (*ecdsaPublicKey)(key), nil return (*ecdsaPublicKey)(key), nil
case *dsa.PublicKey: case *dsa.PublicKey:

View File

@@ -108,8 +108,8 @@ func wildcardMatch(pat []byte, str []byte) bool {
} }
} }
func (l *hostPattern) match(a addr) bool { func (p *hostPattern) match(a addr) bool {
return wildcardMatch([]byte(l.addr.host), []byte(a.host)) && l.addr.port == a.port return wildcardMatch([]byte(p.addr.host), []byte(a.host)) && p.addr.port == a.port
} }
type keyDBLine struct { type keyDBLine struct {

View File

@@ -23,10 +23,6 @@ const (
msgUnimplemented = 3 msgUnimplemented = 3
msgDebug = 4 msgDebug = 4
msgNewKeys = 21 msgNewKeys = 21
// Standard authentication messages
msgUserAuthSuccess = 52
msgUserAuthBanner = 53
) )
// SSH messages: // SSH messages:
@@ -137,6 +133,18 @@ type userAuthFailureMsg struct {
PartialSuccess bool PartialSuccess bool
} }
// See RFC 4252, section 5.1
const msgUserAuthSuccess = 52
// See RFC 4252, section 5.4
const msgUserAuthBanner = 53
type userAuthBannerMsg struct {
Message string `sshtype:"53"`
// unused, but required to allow message parsing
Language string
}
// See RFC 4256, section 3.2 // See RFC 4256, section 3.2
const msgUserAuthInfoRequest = 60 const msgUserAuthInfoRequest = 60
const msgUserAuthInfoResponse = 61 const msgUserAuthInfoResponse = 61
@@ -154,7 +162,7 @@ const msgChannelOpen = 90
type channelOpenMsg struct { type channelOpenMsg struct {
ChanType string `sshtype:"90"` ChanType string `sshtype:"90"`
PeersId uint32 PeersID uint32
PeersWindow uint32 PeersWindow uint32
MaxPacketSize uint32 MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"` TypeSpecificData []byte `ssh:"rest"`
@@ -165,7 +173,7 @@ const msgChannelData = 94
// Used for debug print outs of packets. // Used for debug print outs of packets.
type channelDataMsg struct { type channelDataMsg struct {
PeersId uint32 `sshtype:"94"` PeersID uint32 `sshtype:"94"`
Length uint32 Length uint32
Rest []byte `ssh:"rest"` Rest []byte `ssh:"rest"`
} }
@@ -174,8 +182,8 @@ type channelDataMsg struct {
const msgChannelOpenConfirm = 91 const msgChannelOpenConfirm = 91
type channelOpenConfirmMsg struct { type channelOpenConfirmMsg struct {
PeersId uint32 `sshtype:"91"` PeersID uint32 `sshtype:"91"`
MyId uint32 MyID uint32
MyWindow uint32 MyWindow uint32
MaxPacketSize uint32 MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"` TypeSpecificData []byte `ssh:"rest"`
@@ -185,7 +193,7 @@ type channelOpenConfirmMsg struct {
const msgChannelOpenFailure = 92 const msgChannelOpenFailure = 92
type channelOpenFailureMsg struct { type channelOpenFailureMsg struct {
PeersId uint32 `sshtype:"92"` PeersID uint32 `sshtype:"92"`
Reason RejectionReason Reason RejectionReason
Message string Message string
Language string Language string
@@ -194,7 +202,7 @@ type channelOpenFailureMsg struct {
const msgChannelRequest = 98 const msgChannelRequest = 98
type channelRequestMsg struct { type channelRequestMsg struct {
PeersId uint32 `sshtype:"98"` PeersID uint32 `sshtype:"98"`
Request string Request string
WantReply bool WantReply bool
RequestSpecificData []byte `ssh:"rest"` RequestSpecificData []byte `ssh:"rest"`
@@ -204,28 +212,28 @@ type channelRequestMsg struct {
const msgChannelSuccess = 99 const msgChannelSuccess = 99
type channelRequestSuccessMsg struct { type channelRequestSuccessMsg struct {
PeersId uint32 `sshtype:"99"` PeersID uint32 `sshtype:"99"`
} }
// See RFC 4254, section 5.4. // See RFC 4254, section 5.4.
const msgChannelFailure = 100 const msgChannelFailure = 100
type channelRequestFailureMsg struct { type channelRequestFailureMsg struct {
PeersId uint32 `sshtype:"100"` PeersID uint32 `sshtype:"100"`
} }
// See RFC 4254, section 5.3 // See RFC 4254, section 5.3
const msgChannelClose = 97 const msgChannelClose = 97
type channelCloseMsg struct { type channelCloseMsg struct {
PeersId uint32 `sshtype:"97"` PeersID uint32 `sshtype:"97"`
} }
// See RFC 4254, section 5.3 // See RFC 4254, section 5.3
const msgChannelEOF = 96 const msgChannelEOF = 96
type channelEOFMsg struct { type channelEOFMsg struct {
PeersId uint32 `sshtype:"96"` PeersID uint32 `sshtype:"96"`
} }
// See RFC 4254, section 4 // See RFC 4254, section 4
@@ -255,7 +263,7 @@ type globalRequestFailureMsg struct {
const msgChannelWindowAdjust = 93 const msgChannelWindowAdjust = 93
type windowAdjustMsg struct { type windowAdjustMsg struct {
PeersId uint32 `sshtype:"93"` PeersID uint32 `sshtype:"93"`
AdditionalBytes uint32 AdditionalBytes uint32
} }

View File

@@ -278,7 +278,7 @@ func (m *mux) handleChannelOpen(packet []byte) error {
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
failMsg := channelOpenFailureMsg{ failMsg := channelOpenFailureMsg{
PeersId: msg.PeersId, PeersID: msg.PeersID,
Reason: ConnectionFailed, Reason: ConnectionFailed,
Message: "invalid request", Message: "invalid request",
Language: "en_US.UTF-8", Language: "en_US.UTF-8",
@@ -287,7 +287,7 @@ func (m *mux) handleChannelOpen(packet []byte) error {
} }
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
c.remoteId = msg.PeersId c.remoteId = msg.PeersID
c.maxRemotePayload = msg.MaxPacketSize c.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.PeersWindow) c.remoteWin.add(msg.PeersWindow)
m.incomingChannels <- c m.incomingChannels <- c
@@ -313,7 +313,7 @@ func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
PeersWindow: ch.myWindow, PeersWindow: ch.myWindow,
MaxPacketSize: ch.maxIncomingPayload, MaxPacketSize: ch.maxIncomingPayload,
TypeSpecificData: extra, TypeSpecificData: extra,
PeersId: ch.localId, PeersID: ch.localId,
} }
if err := m.sendMessage(open); err != nil { if err := m.sendMessage(open); err != nil {
return nil, err return nil, err

View File

@@ -95,6 +95,10 @@ type ServerConfig struct {
// Note that RFC 4253 section 4.2 requires that this string start with // Note that RFC 4253 section 4.2 requires that this string start with
// "SSH-2.0-". // "SSH-2.0-".
ServerVersion string ServerVersion string
// BannerCallback, if present, is called and the return string is sent to
// the client after key exchange completed but before authentication.
BannerCallback func(conn ConnMetadata) string
} }
// AddHostKey adds a private key as a host key. If an existing host // AddHostKey adds a private key as a host key. If an existing host
@@ -252,7 +256,7 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
func isAcceptableAlgo(algo string) bool { func isAcceptableAlgo(algo string) bool {
switch algo { switch algo {
case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoED25519, case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoED25519,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01:
return true return true
} }
return false return false
@@ -312,6 +316,7 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err
authFailures := 0 authFailures := 0
var authErrs []error var authErrs []error
var displayedBanner bool
userAuthLoop: userAuthLoop:
for { for {
@@ -343,6 +348,20 @@ userAuthLoop:
} }
s.user = userAuthReq.User s.user = userAuthReq.User
if !displayedBanner && config.BannerCallback != nil {
displayedBanner = true
msg := config.BannerCallback(s)
if msg != "" {
bannerMsg := &userAuthBannerMsg{
Message: msg,
}
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
return nil, err
}
}
}
perms = nil perms = nil
authErr := errors.New("no auth passed yet") authErr := errors.New("no auth passed yet")

View File

@@ -406,7 +406,7 @@ func (s *Session) Wait() error {
s.stdinPipeWriter.Close() s.stdinPipeWriter.Close()
} }
var copyError error var copyError error
for _ = range s.copyFuncs { for range s.copyFuncs {
if err := <-s.errors; err != nil && copyError == nil { if err := <-s.errors; err != nil && copyError == nil {
copyError = err copyError = err
} }

View File

@@ -617,7 +617,7 @@ func writeWithCRLF(w io.Writer, buf []byte) (n int, err error) {
if _, err = w.Write(crlf); err != nil { if _, err = w.Write(crlf); err != nil {
return n, err return n, err
} }
n += 1 n++
buf = buf[1:] buf = buf[1:]
} }
} }

View File

@@ -17,44 +17,41 @@
package terminal // import "golang.org/x/crypto/ssh/terminal" package terminal // import "golang.org/x/crypto/ssh/terminal"
import ( import (
"syscall"
"unsafe"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// State contains the state of a terminal. // State contains the state of a terminal.
type State struct { type State struct {
termios syscall.Termios termios unix.Termios
} }
// IsTerminal returns true if the given file descriptor is a terminal. // IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool { func IsTerminal(fd int) bool {
var termios syscall.Termios _, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) return err == nil
return err == 0
} }
// MakeRaw put the terminal connected to the given file descriptor into raw // MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be // mode and returns the previous state of the terminal so that it can be
// restored. // restored.
func MakeRaw(fd int) (*State, error) { func MakeRaw(fd int) (*State, error) {
var oldState State termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { if err != nil {
return nil, err return nil, err
} }
newState := oldState.termios oldState := State{termios: *termios}
// This attempts to replicate the behaviour documented for cfmakeraw in // This attempts to replicate the behaviour documented for cfmakeraw in
// the termios(3) manpage. // the termios(3) manpage.
newState.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON termios.Iflag &^= unix.IGNBRK | unix.BRKINT | unix.PARMRK | unix.ISTRIP | unix.INLCR | unix.IGNCR | unix.ICRNL | unix.IXON
newState.Oflag &^= syscall.OPOST termios.Oflag &^= unix.OPOST
newState.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN termios.Lflag &^= unix.ECHO | unix.ECHONL | unix.ICANON | unix.ISIG | unix.IEXTEN
newState.Cflag &^= syscall.CSIZE | syscall.PARENB termios.Cflag &^= unix.CSIZE | unix.PARENB
newState.Cflag |= syscall.CS8 termios.Cflag |= unix.CS8
newState.Cc[unix.VMIN] = 1 termios.Cc[unix.VMIN] = 1
newState.Cc[unix.VTIME] = 0 termios.Cc[unix.VTIME] = 0
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { if err := unix.IoctlSetTermios(fd, ioctlWriteTermios, termios); err != nil {
return nil, err return nil, err
} }
@@ -64,59 +61,55 @@ func MakeRaw(fd int) (*State, error) {
// GetState returns the current state of a terminal which may be useful to // GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal. // restore the terminal after a signal.
func GetState(fd int) (*State, error) { func GetState(fd int) (*State, error) {
var oldState State termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { if err != nil {
return nil, err return nil, err
} }
return &oldState, nil return &State{termios: *termios}, nil
} }
// Restore restores the terminal connected to the given file descriptor to a // Restore restores the terminal connected to the given file descriptor to a
// previous state. // previous state.
func Restore(fd int, state *State) error { func Restore(fd int, state *State) error {
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0); err != 0 { return unix.IoctlSetTermios(fd, ioctlWriteTermios, &state.termios)
return err
}
return nil
} }
// GetSize returns the dimensions of the given terminal. // GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) { func GetSize(fd int) (width, height int, err error) {
var dimensions [4]uint16 ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ)
if err != nil {
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), uintptr(syscall.TIOCGWINSZ), uintptr(unsafe.Pointer(&dimensions)), 0, 0, 0); err != 0 {
return -1, -1, err return -1, -1, err
} }
return int(dimensions[1]), int(dimensions[0]), nil return int(ws.Col), int(ws.Row), nil
} }
// passwordReader is an io.Reader that reads from a specific file descriptor. // passwordReader is an io.Reader that reads from a specific file descriptor.
type passwordReader int type passwordReader int
func (r passwordReader) Read(buf []byte) (int, error) { func (r passwordReader) Read(buf []byte) (int, error) {
return syscall.Read(int(r), buf) return unix.Read(int(r), buf)
} }
// ReadPassword reads a line of input from a terminal without local echo. This // ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice // is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n. // returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) { func ReadPassword(fd int) ([]byte, error) {
var oldState syscall.Termios termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0); err != 0 { if err != nil {
return nil, err return nil, err
} }
newState := oldState newState := *termios
newState.Lflag &^= syscall.ECHO newState.Lflag &^= unix.ECHO
newState.Lflag |= syscall.ICANON | syscall.ISIG newState.Lflag |= unix.ICANON | unix.ISIG
newState.Iflag |= syscall.ICRNL newState.Iflag |= unix.ICRNL
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { if err := unix.IoctlSetTermios(fd, ioctlWriteTermios, &newState); err != nil {
return nil, err return nil, err
} }
defer func() { defer func() {
syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0) unix.IoctlSetTermios(fd, ioctlWriteTermios, termios)
}() }()
return readPasswordLine(passwordReader(fd)) return readPasswordLine(passwordReader(fd))

View File

@@ -17,6 +17,8 @@
package terminal package terminal
import ( import (
"os"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
@@ -71,13 +73,6 @@ func GetSize(fd int) (width, height int, err error) {
return int(info.Size.X), int(info.Size.Y), nil return int(info.Size.X), int(info.Size.Y), nil
} }
// passwordReader is an io.Reader that reads from a specific Windows HANDLE.
type passwordReader int
func (r passwordReader) Read(buf []byte) (int, error) {
return windows.Read(windows.Handle(r), buf)
}
// ReadPassword reads a line of input from a terminal without local echo. This // ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice // is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n. // returned does not include the \n.
@@ -98,5 +93,5 @@ func ReadPassword(fd int) ([]byte, error) {
windows.SetConsoleMode(windows.Handle(fd), old) windows.SetConsoleMode(windows.Handle(fd), old)
}() }()
return readPasswordLine(passwordReader(fd)) return readPasswordLine(os.NewFile(uintptr(fd), "stdin"))
} }

32
vendor/golang.org/x/crypto/ssh/test/banner_test.go generated vendored Normal file
View File

@@ -0,0 +1,32 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd
package test
import (
"testing"
)
func TestBannerCallbackAgainstOpenSSH(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
clientConf := clientConfig()
var receivedBanner string
clientConf.BannerCallback = func(message string) error {
receivedBanner = message
return nil
}
conn := server.Dial(clientConf)
defer conn.Close()
expected := "Server Banner"
if receivedBanner != expected {
t.Fatalf("got %v; want %v", receivedBanner, expected)
}
}

View File

@@ -2,6 +2,6 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This package contains integration tests for the // Package test contains integration tests for the
// golang.org/x/crypto/ssh package. // golang.org/x/crypto/ssh package.
package test // import "golang.org/x/crypto/ssh/test" package test // import "golang.org/x/crypto/ssh/test"

View File

@@ -333,21 +333,22 @@ func TestCiphers(t *testing.T) {
cipherOrder = append(cipherOrder, "aes128-cbc", "3des-cbc") cipherOrder = append(cipherOrder, "aes128-cbc", "3des-cbc")
for _, ciph := range cipherOrder { for _, ciph := range cipherOrder {
server := newServer(t) t.Run(ciph, func(t *testing.T) {
defer server.Shutdown() server := newServer(t)
conf := clientConfig() defer server.Shutdown()
conf.Ciphers = []string{ciph} conf := clientConfig()
// Don't fail if sshd doesn't have the cipher. conf.Ciphers = []string{ciph}
conf.Ciphers = append(conf.Ciphers, cipherOrder...) // Don't fail if sshd doesn't have the cipher.
conn, err := server.TryDial(conf) conf.Ciphers = append(conf.Ciphers, cipherOrder...)
if err == nil { conn, err := server.TryDial(conf)
conn.Close() if err == nil {
} else { conn.Close()
t.Fatalf("failed for cipher %q", ciph) } else {
} t.Fatalf("failed for cipher %q", ciph)
}
})
} }
} }
func TestMACs(t *testing.T) { func TestMACs(t *testing.T) {
var config ssh.Config var config ssh.Config
config.SetDefaults() config.SetDefaults()

View File

@@ -25,8 +25,9 @@ import (
"golang.org/x/crypto/ssh/testdata" "golang.org/x/crypto/ssh/testdata"
) )
const sshd_config = ` const sshdConfig = `
Protocol 2 Protocol 2
Banner {{.Dir}}/banner
HostKey {{.Dir}}/id_rsa HostKey {{.Dir}}/id_rsa
HostKey {{.Dir}}/id_dsa HostKey {{.Dir}}/id_dsa
HostKey {{.Dir}}/id_ecdsa HostKey {{.Dir}}/id_ecdsa
@@ -50,7 +51,7 @@ HostbasedAuthentication no
PubkeyAcceptedKeyTypes=* PubkeyAcceptedKeyTypes=*
` `
var configTmpl = template.Must(template.New("").Parse(sshd_config)) var configTmpl = template.Must(template.New("").Parse(sshdConfig))
type server struct { type server struct {
t *testing.T t *testing.T
@@ -256,6 +257,8 @@ func newServer(t *testing.T) *server {
} }
f.Close() f.Close()
writeFile(filepath.Join(dir, "banner"), []byte("Server Banner"))
for k, v := range testdata.PEMBytes { for k, v := range testdata.PEMBytes {
filename := "id_" + k filename := "id_" + k
writeFile(filepath.Join(dir, filename), v) writeFile(filepath.Join(dir, filename), v)
@@ -268,7 +271,7 @@ func newServer(t *testing.T) *server {
} }
var authkeys bytes.Buffer var authkeys bytes.Buffer
for k, _ := range testdata.PEMBytes { for k := range testdata.PEMBytes {
authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k])) authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
} }
writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes()) writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())

View File

@@ -23,6 +23,27 @@ MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49
AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+ AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+
6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA== 6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA==
-----END EC PRIVATE KEY----- -----END EC PRIVATE KEY-----
`),
"ecdsap256": []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIAPCE25zK0PQSnsgVcEbM1mbKTASH4pqb5QJajplDwDZoAoGCCqGSM49
AwEHoUQDQgAEWy8TxGcIHRh5XGpO4dFVfDjeNY+VkgubQrf/eyFJZHxAn1SKraXU
qJUjTKj1z622OxYtJ5P7s9CfAEVsTzLCzg==
-----END EC PRIVATE KEY-----
`),
"ecdsap384": []byte(`-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDBWfSnMuNKq8J9rQLzzEkx3KAoEohSXqhE/4CdjEYtoU2i22HW80DDS
qQhYNHRAduygBwYFK4EEACKhZANiAAQWaDMAd0HUd8ZiXCX7mYDDnC54gwH/nG43
VhCUEYmF7HMZm/B9Yn3GjFk3qYEDEvuF/52+NvUKBKKaLbh32AWxMv0ibcoba4cz
hL9+hWYhUD9XIUlzMWiZ2y6eBE9PdRI=
-----END EC PRIVATE KEY-----
`),
"ecdsap521": []byte(`-----BEGIN EC PRIVATE KEY-----
MIHcAgEBBEIBrkYpQcy8KTVHNiAkjlFZwee90224Bu6wz94R4OBo+Ts0eoAQG7SF
iaygEDMUbx6kTgXTBcKZ0jrWPKakayNZ/kigBwYFK4EEACOhgYkDgYYABADFuvLV
UoaCDGHcw5uNfdRIsvaLKuWSpLsl48eWGZAwdNG432GDVKduO+pceuE+8XzcyJb+
uMv+D2b11Q/LQUcHJwE6fqbm8m3EtDKPsoKs0u/XUJb0JsH4J8lkZzbUTjvGYamn
FFlRjzoB3Oxu8UQgb+MWPedtH9XYBbg9biz4jJLkXQ==
-----END EC PRIVATE KEY-----
`), `),
"rsa": []byte(`-----BEGIN RSA PRIVATE KEY----- "rsa": []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2 MIICXAIBAAKBgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2

View File

@@ -6,6 +6,7 @@ package ssh
import ( import (
"bufio" "bufio"
"bytes"
"errors" "errors"
"io" "io"
"log" "log"
@@ -76,17 +77,17 @@ type connectionState struct {
// both directions are triggered by reading and writing a msgNewKey packet // both directions are triggered by reading and writing a msgNewKey packet
// respectively. // respectively.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil { ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
if err != nil {
return err return err
} else {
t.reader.pendingKeyChange <- ciph
} }
t.reader.pendingKeyChange <- ciph
if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil { ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult)
if err != nil {
return err return err
} else {
t.writer.pendingKeyChange <- ciph
} }
t.writer.pendingKeyChange <- ciph
return nil return nil
} }
@@ -139,7 +140,7 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
case cipher := <-s.pendingKeyChange: case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher s.packetCipher = cipher
default: default:
return nil, errors.New("ssh: got bogus newkeys message.") return nil, errors.New("ssh: got bogus newkeys message")
} }
case msgDisconnect: case msgDisconnect:
@@ -254,7 +255,7 @@ func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (pac
iv, key, macKey := generateKeys(d, algs, kex) iv, key, macKey := generateKeys(d, algs, kex)
if algs.Cipher == gcmCipherID { if algs.Cipher == gcmCipherID {
return newGCMCipher(iv, key, macKey) return newGCMCipher(iv, key)
} }
if algs.Cipher == aes128cbcID { if algs.Cipher == aes128cbcID {
@@ -342,7 +343,7 @@ func readVersion(r io.Reader) ([]byte, error) {
var ok bool var ok bool
var buf [1]byte var buf [1]byte
for len(versionString) < maxVersionStringBytes { for length := 0; length < maxVersionStringBytes; length++ {
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { if err != nil {
return nil, err return nil, err
@@ -350,6 +351,13 @@ func readVersion(r io.Reader) ([]byte, error) {
// The RFC says that the version should be terminated with \r\n // The RFC says that the version should be terminated with \r\n
// but several SSH servers actually only send a \n. // but several SSH servers actually only send a \n.
if buf[0] == '\n' { if buf[0] == '\n' {
if !bytes.HasPrefix(versionString, []byte("SSH-")) {
// RFC 4253 says we need to ignore all version string lines
// except the one containing the SSH version (provided that
// all the lines do not exceed 255 bytes in total).
versionString = versionString[:0]
continue
}
ok = true ok = true
break break
} }

View File

@@ -13,11 +13,13 @@ import (
) )
func TestReadVersion(t *testing.T) { func TestReadVersion(t *testing.T) {
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253] longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253]
multiLineVersion := strings.Repeat("ignored\r\n", 20) + "SSH-2.0-bla\r\n"
cases := map[string]string{ cases := map[string]string{
"SSH-2.0-bla\r\n": "SSH-2.0-bla", "SSH-2.0-bla\r\n": "SSH-2.0-bla",
"SSH-2.0-bla\n": "SSH-2.0-bla", "SSH-2.0-bla\n": "SSH-2.0-bla",
longversion + "\r\n": longversion, multiLineVersion: "SSH-2.0-bla",
longVersion + "\r\n": longVersion,
} }
for in, want := range cases { for in, want := range cases {
@@ -33,9 +35,11 @@ func TestReadVersion(t *testing.T) {
} }
func TestReadVersionError(t *testing.T) { func TestReadVersionError(t *testing.T) {
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253] longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253]
multiLineVersion := strings.Repeat("ignored\r\n", 50) + "SSH-2.0-bla\r\n"
cases := []string{ cases := []string{
longversion + "too-long\r\n", longVersion + "too-long\r\n",
multiLineVersion,
} }
for _, in := range cases { for _, in := range cases {
if _, err := readVersion(bytes.NewBufferString(in)); err == nil { if _, err := readVersion(bytes.NewBufferString(in)); err == nil {
@@ -60,7 +64,7 @@ func TestExchangeVersionsBasic(t *testing.T) {
func TestExchangeVersions(t *testing.T) { func TestExchangeVersions(t *testing.T) {
cases := []string{ cases := []string{
"not\x000allowed", "not\x000allowed",
"not allowed\n", "not allowed\x01\r\n",
} }
for _, c := range cases { for _, c := range cases {
buf := bytes.NewBufferString("SSH-2.0-bla\r\n") buf := bytes.NewBufferString("SSH-2.0-bla\r\n")

View File

@@ -5,7 +5,6 @@
// Package tea implements the TEA algorithm, as defined in Needham and // Package tea implements the TEA algorithm, as defined in Needham and
// Wheeler's 1994 technical report, “TEA, a Tiny Encryption Algorithm”. See // Wheeler's 1994 technical report, “TEA, a Tiny Encryption Algorithm”. See
// http://www.cix.co.uk/~klockstone/tea.pdf for details. // http://www.cix.co.uk/~klockstone/tea.pdf for details.
package tea package tea
import ( import (

View File

@@ -69,7 +69,7 @@ func initCipher(c *Cipher, key []byte) {
// Precalculate the table // Precalculate the table
const delta = 0x9E3779B9 const delta = 0x9E3779B9
var sum uint32 = 0 var sum uint32
// Two rounds of XTEA applied per loop // Two rounds of XTEA applied per loop
for i := 0; i < numRounds; { for i := 0; i < numRounds; {

View File

@@ -25,3 +25,7 @@ func Clearenv() {
func Environ() []string { func Environ() []string {
return syscall.Environ() return syscall.Environ()
} }
func Unsetenv(key string) error {
return syscall.Unsetenv(key)
}

View File

@@ -1,14 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.4
package plan9
import "syscall"
func Unsetenv(key string) error {
// This was added in Go 1.4.
return syscall.Unsetenv(key)
}

Some files were not shown because too many files have changed in this diff Show More