Merge branch 'release/container'
/ docker (push) Successful in 1m19s
Details
/ docker (push) Successful in 1m19s
Details
This commit is contained in:
commit
1bef820f7a
|
@ -0,0 +1,44 @@
|
||||||
|
---
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
- develop
|
||||||
|
|
||||||
|
env:
|
||||||
|
DOCKER_REGISTRY: source.toby3d.me
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
docker:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container:
|
||||||
|
image: catthehacker/ubuntu:act-latest
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: https://gitea.com/actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Set up QEMU
|
||||||
|
uses: https://gitea.com/docker/setup-qemu-action@v2
|
||||||
|
|
||||||
|
- name: Set up Docker BuildX
|
||||||
|
uses: https://gitea.com/docker/setup-buildx-action@v2
|
||||||
|
|
||||||
|
- name: Login to registry
|
||||||
|
uses: https://gitea.com/docker/login-action@v2
|
||||||
|
with:
|
||||||
|
registry: ${{ env.DOCKER_REGISTRY }}
|
||||||
|
username: ${{ gitea.repository_owner }}
|
||||||
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
|
- name: Build and push
|
||||||
|
uses: https://gitea.com/docker/build-push-action@v4
|
||||||
|
env:
|
||||||
|
ACTIONS_RUNTIME_TOKEN: "" # See https://gitea.com/gitea/act_runner/issues/119
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: ./build/Dockerfile
|
||||||
|
push: true
|
||||||
|
tags: ${{ env.DOCKER_REGISTRY }}/${{ gitea.repository }}:${{ gitea.ref_name }}
|
|
@ -0,0 +1,28 @@
|
||||||
|
#!/usr/bin/make -f
|
||||||
|
SHELL = /bin/sh
|
||||||
|
|
||||||
|
#### Start of system configuration section. ####
|
||||||
|
|
||||||
|
srcdir = .
|
||||||
|
|
||||||
|
GO ?= go
|
||||||
|
GOFLAGS ?= -buildvcs=true
|
||||||
|
EXECUTABLE ?= hub
|
||||||
|
|
||||||
|
#### End of system configuration section. ####
|
||||||
|
|
||||||
|
.PHONY: all
|
||||||
|
all: main.go
|
||||||
|
$(GO) build -v $(GOFLAGS) -o $(EXECUTABLE)
|
||||||
|
|
||||||
|
.PHONY: clean
|
||||||
|
clean: ## Delete all files in the current directory that are normally created by building the program
|
||||||
|
$(GO) clean
|
||||||
|
|
||||||
|
.PHONY: check
|
||||||
|
check: ## Perform self-tests
|
||||||
|
$(GO) test -v -cover -failfast -short -shuffle=on $(GOFLAGS) $(srcdir)/...
|
||||||
|
|
||||||
|
.PHONY: help
|
||||||
|
help: ## Display this help screen
|
||||||
|
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
|
@ -0,0 +1,28 @@
|
||||||
|
# syntax=docker/dockerfile:1
|
||||||
|
# docker build --rm -f build/Dockerfile -t source.toby3d.me/toby3d/hub .
|
||||||
|
|
||||||
|
# Build
|
||||||
|
FROM golang:alpine AS builder
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
ENV CGO_ENABLED=0
|
||||||
|
ENV GOFLAGS=-mod=vendor
|
||||||
|
|
||||||
|
COPY go.mod go.sum *.go ./
|
||||||
|
COPY internal ./internal/
|
||||||
|
COPY vendor ./vendor/
|
||||||
|
COPY web ./web/
|
||||||
|
|
||||||
|
RUN go build -o ./hub
|
||||||
|
|
||||||
|
# Run
|
||||||
|
FROM scratch
|
||||||
|
|
||||||
|
WORKDIR /
|
||||||
|
|
||||||
|
COPY --from=builder /app/hub /hub
|
||||||
|
|
||||||
|
EXPOSE 3000
|
||||||
|
|
||||||
|
ENTRYPOINT ["/hub"]
|
|
@ -0,0 +1,49 @@
|
||||||
|
---
|
||||||
|
kind: "pipeline"
|
||||||
|
type: "docker"
|
||||||
|
name: "default"
|
||||||
|
|
||||||
|
environment:
|
||||||
|
CGO_ENABLED: 0
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: "test"
|
||||||
|
image: "golang:1.20"
|
||||||
|
volumes:
|
||||||
|
- name: "modules"
|
||||||
|
path: "/go/pkg/mod"
|
||||||
|
commands:
|
||||||
|
- "make check"
|
||||||
|
|
||||||
|
- name: "build"
|
||||||
|
image: "golang:1.20"
|
||||||
|
volumes:
|
||||||
|
- name: "modules"
|
||||||
|
path: "/go/pkg/mod"
|
||||||
|
commands:
|
||||||
|
- "make"
|
||||||
|
depends_on:
|
||||||
|
- "test"
|
||||||
|
|
||||||
|
- name: "delivery"
|
||||||
|
image: "drillster/drone-rsync"
|
||||||
|
settings:
|
||||||
|
hosts:
|
||||||
|
from_secret: "SSH_HOST_IP"
|
||||||
|
key:
|
||||||
|
from_secret: "SSH_PRIVATE_KEY"
|
||||||
|
source: "./hub"
|
||||||
|
target: "/etc/hub/hub"
|
||||||
|
prescript:
|
||||||
|
- "systemctl stop hub"
|
||||||
|
script:
|
||||||
|
- "systemctl start hub"
|
||||||
|
depends_on:
|
||||||
|
- build
|
||||||
|
when:
|
||||||
|
branch:
|
||||||
|
- master
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
- name: modules
|
||||||
|
temp: {}
|
|
@ -0,0 +1,68 @@
|
||||||
|
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/text/language"
|
||||||
|
"golang.org/x/text/message"
|
||||||
|
"golang.org/x/text/message/catalog"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dictionary struct {
|
||||||
|
index []uint32
|
||||||
|
data string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dictionary) Lookup(key string) (data string, ok bool) {
|
||||||
|
p, ok := messageKeyToIndex[key]
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
start, end := d.index[p], d.index[p+1]
|
||||||
|
if start == end {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return d.data[start:end], true
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
dict := map[string]catalog.Dictionary{
|
||||||
|
"en": &dictionary{index: enIndex, data: enData},
|
||||||
|
"ru": &dictionary{index: ruIndex, data: ruData},
|
||||||
|
}
|
||||||
|
fallback := language.MustParse("en")
|
||||||
|
cat, err := catalog.NewFromMap(dict, catalog.Fallback(fallback))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
message.DefaultCatalog = cat
|
||||||
|
}
|
||||||
|
|
||||||
|
var messageKeyToIndex = map[string]int{
|
||||||
|
"%d subscribers": 5,
|
||||||
|
"%s logo": 1,
|
||||||
|
"Dead simple WebSub hub": 2,
|
||||||
|
"How to publish and consume?": 3,
|
||||||
|
"What the spec?": 4,
|
||||||
|
"version": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
var enIndex = []uint32{ // 7 elements
|
||||||
|
0x00000000, 0x00000008, 0x00000013, 0x0000002a,
|
||||||
|
0x00000046, 0x00000055, 0x00000067,
|
||||||
|
} // Size: 52 bytes
|
||||||
|
|
||||||
|
const enData string = "" + // Size: 103 bytes
|
||||||
|
"\x02version\x02%[1]s logo\x02Dead simple WebSub hub\x02How to publish an" +
|
||||||
|
"d consume?\x02What the spec?\x02%[1]d subscribers"
|
||||||
|
|
||||||
|
var ruIndex = []uint32{ // 7 elements
|
||||||
|
0x00000000, 0x0000000d, 0x00000022, 0x00000045,
|
||||||
|
0x0000007a, 0x00000090, 0x000000ad,
|
||||||
|
} // Size: 52 bytes
|
||||||
|
|
||||||
|
const ruData string = "" + // Size: 173 bytes
|
||||||
|
"\x02версия\x02логотип %[1]s\x02Простейший хаб WebSub\x02Как публиковать " +
|
||||||
|
"и принимать?\x02В чём спека?\x02%[1]d подписчиков"
|
||||||
|
|
||||||
|
// Total table size 380 bytes (0KiB); checksum: 7D8C2E8B
|
|
@ -0,0 +1,2 @@
|
||||||
|
# WebSub [![Build Status](https://drone.toby3d.me/api/badges/toby3d/hub/status.svg)](https://drone.toby3d.me/toby3d/hub)
|
||||||
|
> Personal WebSub hub
|
35
go.mod
35
go.mod
|
@ -1,3 +1,36 @@
|
||||||
module source.toby3d.me/toby3d/hub
|
module source.toby3d.me/toby3d/hub
|
||||||
|
|
||||||
go 1.18
|
go 1.20
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.0
|
||||||
|
github.com/caarlos0/env/v7 v7.1.0
|
||||||
|
github.com/go-logfmt/logfmt v0.6.0
|
||||||
|
github.com/jmoiron/sqlx v1.3.5
|
||||||
|
github.com/valyala/quicktemplate v1.7.0
|
||||||
|
golang.org/x/text v0.8.0
|
||||||
|
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2
|
||||||
|
modernc.org/sqlite v1.21.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
|
github.com/google/uuid v1.3.0 // indirect
|
||||||
|
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
|
||||||
|
github.com/lib/pq v1.10.6 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.17 // indirect
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||||
|
golang.org/x/mod v0.9.0 // indirect
|
||||||
|
golang.org/x/sys v0.6.0 // indirect
|
||||||
|
golang.org/x/tools v0.7.0 // indirect
|
||||||
|
lukechampine.com/uint128 v1.3.0 // indirect
|
||||||
|
modernc.org/cc/v3 v3.40.0 // indirect
|
||||||
|
modernc.org/ccgo/v3 v3.16.13 // indirect
|
||||||
|
modernc.org/libc v1.22.3 // indirect
|
||||||
|
modernc.org/mathutil v1.5.0 // indirect
|
||||||
|
modernc.org/memory v1.5.0 // indirect
|
||||||
|
modernc.org/opt v0.1.3 // indirect
|
||||||
|
modernc.org/strutil v1.1.3 // indirect
|
||||||
|
modernc.org/token v1.1.0 // indirect
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,86 @@
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
|
||||||
|
github.com/andybalholm/brotli v1.0.2/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
|
||||||
|
github.com/andybalholm/brotli v1.0.3/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||||
|
github.com/caarlos0/env/v7 v7.1.0 h1:9lzTF5amyQeWHZzuZeKlCb5FWSUxpG1js43mhbY8ozg=
|
||||||
|
github.com/caarlos0/env/v7 v7.1.0/go.mod h1:LPPWniDUq4JaO6Q41vtlyikhMknqymCLBw0eX4dcH1E=
|
||||||
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
|
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
|
||||||
|
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
|
||||||
|
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
|
||||||
|
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||||
|
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
|
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||||
|
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||||
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
|
||||||
|
github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
|
||||||
|
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
|
||||||
|
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
|
||||||
|
github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
|
||||||
|
github.com/klauspost/compress v1.13.5/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
|
||||||
|
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||||
|
github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs=
|
||||||
|
github.com/lib/pq v1.10.6/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
|
github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng=
|
||||||
|
github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||||
|
github.com/valyala/fasthttp v1.30.0/go.mod h1:2rsYD01CKFrjjsvFxx75KlEUNpWNBY9JWD3K/7o2Cus=
|
||||||
|
github.com/valyala/quicktemplate v1.7.0 h1:LUPTJmlVcb46OOUY3IeD9DojFpAVbsG+5WFTcjMJzCM=
|
||||||
|
github.com/valyala/quicktemplate v1.7.0/go.mod h1:sqKJnoaOF88V07vkO+9FL8fb9uZg/VPSJnLYn+LmLk8=
|
||||||
|
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
|
||||||
|
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
|
||||||
|
golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs=
|
||||||
|
golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||||
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
|
golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||||
|
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||||
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
|
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
|
golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68=
|
||||||
|
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||||
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4=
|
||||||
|
golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s=
|
||||||
|
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk=
|
||||||
|
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
|
||||||
|
lukechampine.com/uint128 v1.3.0 h1:cDdUVfRwDUDovz610ABgFD17nXD4/uDgVHl2sC3+sbo=
|
||||||
|
lukechampine.com/uint128 v1.3.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk=
|
||||||
|
modernc.org/cc/v3 v3.40.0 h1:P3g79IUS/93SYhtoeaHW+kRCIrYaxJ27MFPv+7kaTOw=
|
||||||
|
modernc.org/cc/v3 v3.40.0/go.mod h1:/bTg4dnWkSXowUO6ssQKnOV0yMVxDYNIsIrzqTFDGH0=
|
||||||
|
modernc.org/ccgo/v3 v3.16.13 h1:Mkgdzl46i5F/CNR/Kj80Ri59hC8TKAhZrYSaqvkwzUw=
|
||||||
|
modernc.org/ccgo/v3 v3.16.13/go.mod h1:2Quk+5YgpImhPjv2Qsob1DnZ/4som1lJTodubIcoUkY=
|
||||||
|
modernc.org/ccorpus v1.11.6 h1:J16RXiiqiCgua6+ZvQot4yUuUy8zxgqbqEEUuGPlISk=
|
||||||
|
modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM=
|
||||||
|
modernc.org/libc v1.22.3 h1:D/g6O5ftAfavceqlLOFwaZuA5KYafKwmr30A6iSqoyY=
|
||||||
|
modernc.org/libc v1.22.3/go.mod h1:MQrloYP209xa2zHome2a8HLiLm6k0UT8CoHpV74tOFw=
|
||||||
|
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
|
||||||
|
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
|
||||||
|
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
|
||||||
|
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
|
||||||
|
modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4=
|
||||||
|
modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0=
|
||||||
|
modernc.org/sqlite v1.21.0 h1:4aP4MdUf15i3R3M2mx6Q90WHKz3nZLoz96zlB6tNdow=
|
||||||
|
modernc.org/sqlite v1.21.0/go.mod h1:XwQ0wZPIh1iKb5mkvCJ3szzbhk+tykC8ZWqTRTgYRwI=
|
||||||
|
modernc.org/strutil v1.1.3 h1:fNMm+oJklMGYfU9Ylcywl0CO5O6nTfaowNsh2wpPjzY=
|
||||||
|
modernc.org/strutil v1.1.3/go.mod h1:MEHNA7PdEnEwLvspRMtWTNnp2nnyvMfkimT1NKNAGbw=
|
||||||
|
modernc.org/tcl v1.15.1 h1:mOQwiEK4p7HruMZcwKTZPw/aqtGM4aY00uzWhlKKYws=
|
||||||
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
|
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||||
|
modernc.org/z v1.7.0 h1:xkDw/KepgEjeizO2sNco+hqYkU12taxQFqPEmgm1GWE=
|
|
@ -0,0 +1,15 @@
|
||||||
|
[Unit]
|
||||||
|
Description=WebSub Hub
|
||||||
|
After=syslog.target
|
||||||
|
After=network.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
EnvironmentFile=/etc/hub/env
|
||||||
|
RestartSec=2s
|
||||||
|
Type=simple
|
||||||
|
WorkingDirectory=/etc/hub/
|
||||||
|
ExecStart=/etc/hub/hub
|
||||||
|
Restart=always
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
|
@ -0,0 +1,34 @@
|
||||||
|
package common
|
||||||
|
|
||||||
|
const (
|
||||||
|
MIMEApplicationForm = "application/x-www-form-urlencoded"
|
||||||
|
MIMEApplicationFormCharsetUTF8 = MIMEApplicationForm + "; " + charsetUTF8
|
||||||
|
MIMETextHTML = "text/html"
|
||||||
|
MIMETextHTMLCharsetUTF8 = MIMETextHTML + "; " + charsetUTF8
|
||||||
|
MIMETextPlain = "text/plain"
|
||||||
|
MIMETextPlainCharsetUTF8 = MIMETextPlain + "; " + charsetUTF8
|
||||||
|
|
||||||
|
charsetUTF8 = "charset=UTF-8"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
HeaderContentType = "Content-Type"
|
||||||
|
HeaderLink = "Link"
|
||||||
|
HeaderXHubSignature = "X-Hub-Signature"
|
||||||
|
HeaderAcceptLanguage = "Accept-Language"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
HubCallback = hub + ".callback"
|
||||||
|
HubChallenge = hub + ".challenge"
|
||||||
|
HubLeaseSeconds = hub + ".lease_seconds"
|
||||||
|
HubMode = hub + ".mode"
|
||||||
|
HubReason = hub + ".reason"
|
||||||
|
HubSecret = hub + ".secret"
|
||||||
|
HubTopic = hub + ".topic"
|
||||||
|
HubURL = hub + ".url"
|
||||||
|
|
||||||
|
hub = "hub"
|
||||||
|
)
|
||||||
|
|
||||||
|
const Und = "und"
|
|
@ -0,0 +1,77 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/sha512"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"hash"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Algorithm struct {
|
||||||
|
algorithm string
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
AlgorithmUnd = Algorithm{algorithm: ""} // "und"
|
||||||
|
AlgorithmSHA1 = Algorithm{algorithm: "sha1"} // "sha1"
|
||||||
|
AlgorithmSHA256 = Algorithm{algorithm: "sha256"} // "sha256"
|
||||||
|
AlgorithmSHA384 = Algorithm{algorithm: "sha384"} // "sha384"
|
||||||
|
AlgorithmSHA512 = Algorithm{algorithm: "sha512"} // "sha512"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrSyntaxAlgorithm = errors.New("bad algorithm syntax")
|
||||||
|
|
||||||
|
var stringsAlgorithms = map[string]Algorithm{
|
||||||
|
AlgorithmSHA1.algorithm: AlgorithmSHA1,
|
||||||
|
AlgorithmSHA256.algorithm: AlgorithmSHA256,
|
||||||
|
AlgorithmSHA384.algorithm: AlgorithmSHA384,
|
||||||
|
AlgorithmSHA512.algorithm: AlgorithmSHA512,
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseAlgorithm(algorithm string) (Algorithm, error) {
|
||||||
|
if alg, ok := stringsAlgorithms[algorithm]; ok {
|
||||||
|
return alg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return AlgorithmUnd, fmt.Errorf("%w: %s", ErrSyntaxAlgorithm, algorithm)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a Algorithm) Hash() hash.Hash {
|
||||||
|
switch a {
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
case AlgorithmSHA1:
|
||||||
|
return sha1.New()
|
||||||
|
case AlgorithmSHA256:
|
||||||
|
return sha256.New()
|
||||||
|
case AlgorithmSHA384:
|
||||||
|
return sha512.New384()
|
||||||
|
case AlgorithmSHA512:
|
||||||
|
return sha512.New()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Algorithm) UnmarshalForm(src []byte) error {
|
||||||
|
var err error
|
||||||
|
if *a, err = ParseAlgorithm(string(src)); err != nil {
|
||||||
|
return fmt.Errorf("Algorithm: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a Algorithm) String() string {
|
||||||
|
if a.algorithm != "" {
|
||||||
|
return a.algorithm
|
||||||
|
}
|
||||||
|
|
||||||
|
return common.Und
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a Algorithm) GoString() string {
|
||||||
|
return "domain.Algorithm(" + a.String() + ")"
|
||||||
|
}
|
|
@ -0,0 +1,35 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Challenge struct {
|
||||||
|
challenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChallenge(length uint8) (*Challenge, error) {
|
||||||
|
src := make([]byte, length)
|
||||||
|
if _, err := rand.Read(src); err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot create a new challenge: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Challenge{challenge: base64.URLEncoding.EncodeToString(src)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Challenge) AddQuery(q url.Values) {
|
||||||
|
q.Add(common.HubChallenge, string(c.challenge))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Challenge) Equal(target string) bool {
|
||||||
|
return c.challenge == target
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Challenge) String() string {
|
||||||
|
return string(c.challenge)
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
BaseURL *url.URL `env:"BASE_URL" envDefault:"http://localhost:3000/"`
|
||||||
|
Bind string `end:"BIND,required" envDefault:":3000"`
|
||||||
|
Name string `env:"NAME" envDefault:"WebSub"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig(tb testing.TB) *Config {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
return &Config{
|
||||||
|
BaseURL: &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "hub.example.com",
|
||||||
|
Path: "/",
|
||||||
|
},
|
||||||
|
Bind: ":3000",
|
||||||
|
Name: "WebSub",
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,55 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Error struct {
|
||||||
|
frame xerrors.Frame
|
||||||
|
topic *url.URL
|
||||||
|
reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewError(reason string, topic ...*url.URL) error {
|
||||||
|
err := &Error{
|
||||||
|
reason: reason,
|
||||||
|
frame: xerrors.Caller(1),
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(topic) > 0 {
|
||||||
|
err.topic = topic[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns a string representation of the error, satisfying the error
|
||||||
|
// interface.
|
||||||
|
func (e Error) Error() string {
|
||||||
|
return fmt.Sprint(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format prints the stack as error detail.
|
||||||
|
func (e Error) Format(state fmt.State, r rune) {
|
||||||
|
xerrors.FormatError(e, state, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatError prints the receiver's error, if any.
|
||||||
|
func (e Error) FormatError(printer xerrors.Printer) error {
|
||||||
|
printer.Print(e.reason)
|
||||||
|
|
||||||
|
if e.topic != nil {
|
||||||
|
printer.Printf(" (%s)", e.topic)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !printer.Detail() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
e.frame.Format(printer)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,63 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Mode struct {
|
||||||
|
mode string
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ModeUnd Mode = Mode{mode: ""} // "und"
|
||||||
|
ModeDenied Mode = Mode{mode: "denied"} // "denied"
|
||||||
|
ModePublish Mode = Mode{mode: "publish"} // "publish"
|
||||||
|
ModeSubscribe Mode = Mode{mode: "subscribe"} // "subscribe"
|
||||||
|
ModeUnsubscribe Mode = Mode{mode: "unsubscribe"} // "unsubscribe"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrModeSyntax = errors.New("bad mode syntax")
|
||||||
|
|
||||||
|
var stringsModes = map[string]Mode{
|
||||||
|
ModeDenied.mode: ModeDenied,
|
||||||
|
ModePublish.mode: ModePublish,
|
||||||
|
ModeSubscribe.mode: ModeSubscribe,
|
||||||
|
ModeUnsubscribe.mode: ModeUnsubscribe,
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseMode(mode string) (Mode, error) {
|
||||||
|
if mode, ok := stringsModes[mode]; ok {
|
||||||
|
return mode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ModeUnd, fmt.Errorf("%w: %s", ErrModeSyntax, mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mode) UnmarshalForm(src []byte) error {
|
||||||
|
var err error
|
||||||
|
if *m, err = ParseMode(string(src)); err != nil {
|
||||||
|
return fmt.Errorf("Mode: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Mode) AddQuery(q url.Values) {
|
||||||
|
q.Add(common.HubMode, m.mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Mode) String() string {
|
||||||
|
if m.mode != "" {
|
||||||
|
return m.mode
|
||||||
|
}
|
||||||
|
|
||||||
|
return common.Und
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Mode) GoString() string {
|
||||||
|
return "domain.Mode(" + m.String() + ")"
|
||||||
|
}
|
|
@ -0,0 +1,7 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import "net/url"
|
||||||
|
|
||||||
|
type QueryAdder interface {
|
||||||
|
AddQuery(q url.Values)
|
||||||
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
cryptorand "crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"math/rand"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Secret describes a subscriber-provided cryptographically random unique secret
|
||||||
|
// string that will be used to compute an HMAC digest for authorized content
|
||||||
|
// distribution. If not supplied, the HMAC digest will not be present for
|
||||||
|
// content distribution requests. This parameter SHOULD only be specified when
|
||||||
|
// the request was made over HTTPS [RFC2818]. This parameter MUST be less than
|
||||||
|
// 200 bytes in length.
|
||||||
|
//
|
||||||
|
// [RFC2818]: https://tools.ietf.org/html/rfc2818
|
||||||
|
type Secret struct {
|
||||||
|
secret string
|
||||||
|
}
|
||||||
|
|
||||||
|
var ErrSyntaxSecret = NewError("secret MUST be less than 200 bytes in length")
|
||||||
|
|
||||||
|
var lengthMax = 200
|
||||||
|
|
||||||
|
func ParseSecret(raw string) (*Secret, error) {
|
||||||
|
if len(raw) >= lengthMax {
|
||||||
|
return nil, ErrSyntaxSecret
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Secret{secret: raw}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Secret) IsSet() bool {
|
||||||
|
return s.secret != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Secret) AddQuery(q url.Values) {
|
||||||
|
if s.secret == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
q.Add(common.HubSecret, s.secret)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Secret) String() string {
|
||||||
|
return s.secret
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSecret returns a valid random generated Secret.
|
||||||
|
func TestSecret(tb testing.TB) *Secret {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
src := make([]byte, rand.Intn(lengthMax/2))
|
||||||
|
if _, err := cryptorand.Read(src); err != nil {
|
||||||
|
tb.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Secret{secret: base64.URLEncoding.EncodeToString(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSecret returns a invalid random generated Secret.
|
||||||
|
func TestSecretInvalid(tb testing.TB) *Secret {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
src := make([]byte, lengthMax*2)
|
||||||
|
if _, err := cryptorand.Read(src); err != nil {
|
||||||
|
tb.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Secret{secret: base64.URLEncoding.EncodeToString(src)}
|
||||||
|
}
|
|
@ -0,0 +1,78 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Subscription is a unique relation to a topic by a subscriber that indicates
|
||||||
|
// it should receive updates for that topic.
|
||||||
|
type Subscription struct {
|
||||||
|
// First creation datetime
|
||||||
|
CreatedAt time.Time
|
||||||
|
|
||||||
|
// Last updating datetime
|
||||||
|
UpdatedAt time.Time
|
||||||
|
|
||||||
|
// Datetime when subscription must be deleted
|
||||||
|
ExpiredAt time.Time
|
||||||
|
|
||||||
|
// Datetime synced with topic updating time
|
||||||
|
SyncedAt time.Time
|
||||||
|
|
||||||
|
Callback *url.URL
|
||||||
|
Topic *url.URL
|
||||||
|
|
||||||
|
Secret Secret
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Subscription) AddQuery(q url.Values) {
|
||||||
|
s.Secret.AddQuery(q)
|
||||||
|
q.Add(common.HubTopic, s.Topic.String())
|
||||||
|
q.Add(common.HubCallback, s.Callback.String())
|
||||||
|
q.Add(common.HubLeaseSeconds, strconv.FormatFloat(s.LeaseSeconds(), 'g', 0, 64))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Subscription) SUID() SUID {
|
||||||
|
return SUID{
|
||||||
|
topic: s.Topic.String(),
|
||||||
|
callback: s.Callback.String(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Subscription) LeaseSeconds() float64 {
|
||||||
|
return s.ExpiredAt.Sub(s.UpdatedAt).Round(time.Second).Seconds()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Subscription) Synced(t Topic) bool {
|
||||||
|
return s.SyncedAt.Equal(t.UpdatedAt) || s.SyncedAt.After(t.UpdatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Subscription) Expired(ts time.Time) bool {
|
||||||
|
return s.ExpiredAt.Before(ts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscription(tb testing.TB, callbackUrl string) *Subscription {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
callback, err := url.Parse(callbackUrl)
|
||||||
|
if err != nil {
|
||||||
|
tb.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := time.Now().UTC().Round(time.Second)
|
||||||
|
secret := TestSecret(tb)
|
||||||
|
|
||||||
|
return &Subscription{
|
||||||
|
CreatedAt: ts,
|
||||||
|
UpdatedAt: ts,
|
||||||
|
ExpiredAt: ts.Add(10 * 24 * time.Hour).Round(time.Second),
|
||||||
|
Callback: callback,
|
||||||
|
Topic: &url.URL{Scheme: "https", Host: "example.com", Path: "/lipsum"},
|
||||||
|
Secret: *secret,
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import "net/url"
|
||||||
|
|
||||||
|
// SUID describes a subscription's unique key is the tuple ([Topic] URL,
|
||||||
|
// Subscriber [Callback] URL).
|
||||||
|
type SUID struct {
|
||||||
|
topic string
|
||||||
|
callback string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSSID(topic Topic, callback *url.URL) SUID {
|
||||||
|
return SUID{
|
||||||
|
topic: topic.Self.String(),
|
||||||
|
callback: callback.String(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suid SUID) Equal(target SUID) bool {
|
||||||
|
return suid.topic == target.topic && suid.callback == target.callback
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suid SUID) GoString() string {
|
||||||
|
return "domain.SUID(" + suid.topic + ":" + suid.callback + ")"
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Topic struct {
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
Self *url.URL
|
||||||
|
ContentType string
|
||||||
|
Content []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTopic(tb testing.TB) *Topic {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
now := time.Now().UTC().Add(-1 * time.Hour)
|
||||||
|
|
||||||
|
return &Topic{
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
Self: &url.URL{Scheme: "https", Host: "example.com", Path: "/"},
|
||||||
|
ContentType: "text/html",
|
||||||
|
Content: []byte("hello, world"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Topic) AddQuery(q url.Values) {
|
||||||
|
q.Add(common.HubTopic, t.Self.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Topic) Equal(target Topic) bool {
|
||||||
|
return t.Self.String() == target.Self.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Topic) String() string {
|
||||||
|
return t.Self.String()
|
||||||
|
}
|
|
@ -0,0 +1,222 @@
|
||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/text/language"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/common"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/hub"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/subscription"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/topic"
|
||||||
|
"source.toby3d.me/toby3d/hub/web/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
Request struct {
|
||||||
|
Callback *url.URL
|
||||||
|
Topic *url.URL
|
||||||
|
Secret domain.Secret
|
||||||
|
Mode domain.Mode
|
||||||
|
LeaseSeconds float64
|
||||||
|
}
|
||||||
|
|
||||||
|
Response struct {
|
||||||
|
Mode domain.Mode
|
||||||
|
Reason string
|
||||||
|
Topic domain.Topic
|
||||||
|
}
|
||||||
|
|
||||||
|
NewHandlerParams struct {
|
||||||
|
Hub hub.UseCase
|
||||||
|
Subscriptions subscription.UseCase
|
||||||
|
Topics topic.UseCase
|
||||||
|
Matcher language.Matcher
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
Handler struct {
|
||||||
|
hub hub.UseCase
|
||||||
|
subscriptions subscription.UseCase
|
||||||
|
topics topic.UseCase
|
||||||
|
matcher language.Matcher
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var DefaultRequestLeaseSeconds = time.Duration(10 * 24 * time.Hour).Seconds() // 10 days
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrHubMode = errors.New(common.HubMode + " MUST be " + domain.ModeSubscribe.String() + " or " +
|
||||||
|
domain.ModeUnsubscribe.String())
|
||||||
|
ErrHubSecret = errors.New(common.HubSecret + " SHOULD be specified when the request was made over HTTPS")
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewHandler(params NewHandlerParams) *Handler {
|
||||||
|
return &Handler{
|
||||||
|
hub: params.Hub,
|
||||||
|
matcher: params.Matcher,
|
||||||
|
name: params.Name,
|
||||||
|
subscriptions: params.Subscriptions,
|
||||||
|
topics: params.Topics,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
now := time.Now().UTC().Round(time.Second)
|
||||||
|
|
||||||
|
switch r.Method {
|
||||||
|
default:
|
||||||
|
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||||
|
case http.MethodPost:
|
||||||
|
req := NewRequest()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if err = req.bind(r); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(toby3d): send denied ping to callback if it's not accepted by hub
|
||||||
|
|
||||||
|
s := new(domain.Subscription)
|
||||||
|
req.populate(s, now)
|
||||||
|
|
||||||
|
switch req.Mode {
|
||||||
|
case domain.ModeSubscribe, domain.ModeUnsubscribe:
|
||||||
|
if _, err = h.hub.Verify(r.Context(), *s, req.Mode); err != nil {
|
||||||
|
r.Clone(context.WithValue(r.Context(), "error", err))
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch req.Mode {
|
||||||
|
case domain.ModeSubscribe:
|
||||||
|
_, err = h.subscriptions.Subscribe(r.Context(), *s)
|
||||||
|
case domain.ModeUnsubscribe:
|
||||||
|
_, err = h.subscriptions.Unsubscribe(r.Context(), *s)
|
||||||
|
}
|
||||||
|
case domain.ModePublish:
|
||||||
|
_, err = h.topics.Publish(r.Context(), req.Topic)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
r.Clone(context.WithValue(r.Context(), "error", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
case "", http.MethodGet:
|
||||||
|
tags, _, _ := language.ParseAcceptLanguage(r.Header.Get(common.HeaderAcceptLanguage))
|
||||||
|
tag, _, _ := h.matcher.Match(tags...)
|
||||||
|
|
||||||
|
w.Header().Set(common.HeaderContentType, common.MIMETextHTMLCharsetUTF8)
|
||||||
|
template.WriteTemplate(w, &template.Home{BaseOf: template.NewBaseOf(tag, h.name)})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRequest() *Request {
|
||||||
|
return &Request{
|
||||||
|
Mode: domain.ModeUnd,
|
||||||
|
Callback: nil,
|
||||||
|
Secret: domain.Secret{},
|
||||||
|
Topic: nil,
|
||||||
|
LeaseSeconds: DefaultRequestLeaseSeconds,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Request) bind(req *http.Request) error {
|
||||||
|
var err error
|
||||||
|
if err = req.ParseForm(); err != nil {
|
||||||
|
return fmt.Errorf("cannot parse request form: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !req.PostForm.Has(common.HubMode) {
|
||||||
|
return fmt.Errorf("%s parameter is required, but not provided", common.HubMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE(toby3d): hub.mode
|
||||||
|
if r.Mode, err = domain.ParseMode(req.PostForm.Get(common.HubMode)); err != nil {
|
||||||
|
return fmt.Errorf("cannot parse %s: %w", common.HubMode, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE(toby3d): hub.topic
|
||||||
|
if !req.PostForm.Has(common.HubTopic) {
|
||||||
|
return fmt.Errorf("%s parameter is required, but not provided", common.HubTopic)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Topic, err = url.Parse(req.PostForm.Get(common.HubTopic)); err != nil {
|
||||||
|
return fmt.Errorf("cannot parse %s: %w", common.HubTopic, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r.Mode {
|
||||||
|
case domain.ModePublish:
|
||||||
|
case domain.ModeSubscribe, domain.ModeUnsubscribe:
|
||||||
|
// NOTE(toby3d): hub.callback
|
||||||
|
if !req.PostForm.Has(common.HubCallback) {
|
||||||
|
return fmt.Errorf("%s parameter is required, but not provided", common.HubCallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Callback, err = url.Parse(req.PostForm.Get(common.HubCallback)); err != nil {
|
||||||
|
return fmt.Errorf("cannot parse %s: %w", common.HubCallback, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE(toby3d): hub.lease_seconds
|
||||||
|
if r.Mode != domain.ModeUnsubscribe && req.PostForm.Has(common.HubLeaseSeconds) {
|
||||||
|
r.LeaseSeconds, err = strconv.ParseFloat(req.PostForm.Get(common.HubLeaseSeconds), 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot parse %s: %w", common.HubLeaseSeconds, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE(toby3d): hub.secret
|
||||||
|
if !req.PostForm.Has(common.HubSecret) {
|
||||||
|
if req.TLS != nil {
|
||||||
|
return ErrHubSecret
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
secret, err := domain.ParseSecret(req.PostForm.Get(common.HubSecret))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot parse %s: %w", common.HubSecret, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Secret = *secret
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Request) populate(s *domain.Subscription, ts time.Time) {
|
||||||
|
s.CreatedAt = ts
|
||||||
|
s.UpdatedAt = ts
|
||||||
|
s.ExpiredAt = ts.Add(time.Duration(r.LeaseSeconds) * time.Second).Round(time.Second)
|
||||||
|
s.Callback = r.Callback
|
||||||
|
s.Topic = r.Topic
|
||||||
|
s.Secret = r.Secret
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewResponse(t domain.Topic, err error) *Response {
|
||||||
|
return &Response{
|
||||||
|
Mode: domain.ModeDenied,
|
||||||
|
Topic: t,
|
||||||
|
Reason: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Response) populate(q url.Values) {
|
||||||
|
r.Mode.AddQuery(q)
|
||||||
|
r.Topic.AddQuery(q)
|
||||||
|
q.Add(common.HubReason, r.Reason)
|
||||||
|
}
|
|
@ -0,0 +1,108 @@
|
||||||
|
package http_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/text/language"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/common"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
delivery "source.toby3d.me/toby3d/hub/internal/hub/delivery/http"
|
||||||
|
hubucase "source.toby3d.me/toby3d/hub/internal/hub/usecase"
|
||||||
|
subscriptionmemoryrepo "source.toby3d.me/toby3d/hub/internal/subscription/repository/memory"
|
||||||
|
subscriptionucase "source.toby3d.me/toby3d/hub/internal/subscription/usecase"
|
||||||
|
topicmemoryrepo "source.toby3d.me/toby3d/hub/internal/topic/repository/memory"
|
||||||
|
topicucase "source.toby3d.me/toby3d/hub/internal/topic/usecase"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHandler_ServeHTTP_Subscribe(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprint(w, r.URL.Query().Get(common.HubChallenge))
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
in := domain.TestSubscription(t, srv.URL+"/lipsum")
|
||||||
|
subscriptions := subscriptionmemoryrepo.NewMemorySubscriptionRepository()
|
||||||
|
topics := topicmemoryrepo.NewMemoryTopicRepository()
|
||||||
|
hub := hubucase.NewHubUseCase(topics, subscriptions, srv.Client(), &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "hub.exmaple.com",
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
|
||||||
|
payload := make(url.Values)
|
||||||
|
domain.ModeSubscribe.AddQuery(payload)
|
||||||
|
in.AddQuery(payload)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "https://hub.example.com/", strings.NewReader(payload.Encode()))
|
||||||
|
req.Header.Set(common.HeaderContentType, common.MIMEApplicationFormCharsetUTF8)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
delivery.NewHandler(delivery.NewHandlerParams{
|
||||||
|
Hub: hub,
|
||||||
|
Subscriptions: subscriptionucase.NewSubscriptionUseCase(subscriptions, topics, srv.Client()),
|
||||||
|
Topics: topicucase.NewTopicUseCase(topics, srv.Client()),
|
||||||
|
Matcher: language.NewMatcher([]language.Tag{language.English}),
|
||||||
|
Name: "WebSub",
|
||||||
|
}).ServeHTTP(w, req)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
|
||||||
|
if expect := http.StatusAccepted; resp.StatusCode != expect {
|
||||||
|
t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, resp.StatusCode, expect)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_ServeHTTP_Unsubscribe(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set(common.HeaderContentType, common.MIMETextPlainCharsetUTF8)
|
||||||
|
fmt.Fprint(w, r.URL.Query().Get(common.HubChallenge))
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
in := domain.TestSubscription(t, srv.URL+"/lipsum")
|
||||||
|
subscriptions := subscriptionmemoryrepo.NewMemorySubscriptionRepository()
|
||||||
|
topics := topicmemoryrepo.NewMemoryTopicRepository()
|
||||||
|
|
||||||
|
if err := subscriptions.Create(context.Background(), in.SUID(), *in); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hub := hubucase.NewHubUseCase(topics, subscriptions, srv.Client(), &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "hub.exmaple.com",
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
|
||||||
|
payload := make(url.Values)
|
||||||
|
domain.ModeUnsubscribe.AddQuery(payload)
|
||||||
|
in.AddQuery(payload)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "https://hub.example.com/", strings.NewReader(payload.Encode()))
|
||||||
|
req.Header.Set(common.HeaderContentType, common.MIMEApplicationFormCharsetUTF8)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
delivery.NewHandler(delivery.NewHandlerParams{
|
||||||
|
Hub: hub,
|
||||||
|
Subscriptions: subscriptionucase.NewSubscriptionUseCase(subscriptions, topics, srv.Client()),
|
||||||
|
Topics: topicucase.NewTopicUseCase(topics, srv.Client()),
|
||||||
|
Matcher: language.NewMatcher([]language.Tag{language.English}),
|
||||||
|
Name: "WebSub",
|
||||||
|
}).ServeHTTP(w, req)
|
||||||
|
|
||||||
|
resp := w.Result()
|
||||||
|
|
||||||
|
if expect := http.StatusAccepted; resp.StatusCode != expect {
|
||||||
|
t.Errorf("%s %s = %d, want %d", req.Method, req.RequestURI, resp.StatusCode, expect)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,19 @@
|
||||||
|
package hub
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UseCase interface {
|
||||||
|
Verify(ctx context.Context, subscription domain.Subscription, mode domain.Mode) (bool, error)
|
||||||
|
ListenAndServe(ctx context.Context) error
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrStatus = errors.New("subscriber replied with a non 2xx status")
|
||||||
|
ErrNotFound = errors.New("subscriber denied verification, responding with a 404 status")
|
||||||
|
ErrChallenge = errors.New("the challenge of the hub and the subscriber do not match")
|
||||||
|
)
|
|
@ -0,0 +1,186 @@
|
||||||
|
package usecase
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/hmac"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/common"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/hub"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/subscription"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/topic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type hubUseCase struct {
|
||||||
|
subscriptions subscription.Repository
|
||||||
|
topics topic.Repository
|
||||||
|
client *http.Client
|
||||||
|
self *url.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
lengthMin = 16
|
||||||
|
lengthMax = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewHubUseCase(t topic.Repository, s subscription.Repository, c *http.Client, u *url.URL) hub.UseCase {
|
||||||
|
return &hubUseCase{
|
||||||
|
client: c,
|
||||||
|
self: u,
|
||||||
|
topics: t,
|
||||||
|
subscriptions: s,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ucase *hubUseCase) Verify(ctx context.Context, s domain.Subscription, mode domain.Mode) (bool, error) {
|
||||||
|
challenge, err := domain.NewChallenge(uint8(lengthMin + rand.Intn(lengthMax-lengthMin)))
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("cannot generate hub.challenge: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
u, _ := url.Parse(s.Callback.String())
|
||||||
|
q := u.Query()
|
||||||
|
|
||||||
|
mode.AddQuery(q)
|
||||||
|
q.Add(common.HubTopic, s.Topic.String())
|
||||||
|
challenge.AddQuery(q)
|
||||||
|
|
||||||
|
if mode == domain.ModeSubscribe {
|
||||||
|
q.Add(common.HubLeaseSeconds, strconv.FormatFloat(s.LeaseSeconds(), 'g', 0, 64))
|
||||||
|
}
|
||||||
|
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("cannot build verification request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ucase.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("cannot send verification request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusNotFound {
|
||||||
|
return false, hub.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
return false, hub.ErrStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("cannot verify subscriber response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !challenge.Equal(string(body)) {
|
||||||
|
return false, fmt.Errorf("%w: got '%s', want '%s'", hub.ErrChallenge, body, *challenge)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ucase *hubUseCase) ListenAndServe(ctx context.Context) error {
|
||||||
|
ticker := time.NewTicker(time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for ts := range ticker.C {
|
||||||
|
ts = ts.Round(time.Second)
|
||||||
|
|
||||||
|
topics, err := ucase.topics.Fetch(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot fetch topics: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range topics {
|
||||||
|
subscriptions, err := ucase.subscriptions.Fetch(ctx, &topics[i])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot fetch subscriptions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for j := range subscriptions {
|
||||||
|
if subscriptions[j].Expired(ts) {
|
||||||
|
if err = ucase.subscriptions.Delete(ctx, subscriptions[j].SUID()); err != nil {
|
||||||
|
return fmt.Errorf("cannot remove expired subcription: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if subscriptions[j].Synced(topics[i]) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
go ucase.push(ctx, subscriptions[j], topics[i], ts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ucase *hubUseCase) push(ctx context.Context, s domain.Subscription, t domain.Topic, ts time.Time) (bool, error) {
|
||||||
|
req, err := http.NewRequest(http.MethodPost, s.Callback.String(), bytes.NewReader(t.Content))
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("cannot build request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set(common.HeaderContentType, t.ContentType)
|
||||||
|
req.Header.Set(common.HeaderLink, `<`+ucase.self.String()+`>; rel="hub", <`+s.Topic.String()+`>; rel="self"`)
|
||||||
|
setXHubSignatureHeader(req, domain.AlgorithmSHA512, s.Secret, t.Content)
|
||||||
|
|
||||||
|
resp, err := ucase.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("cannot push: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
suid := s.SUID()
|
||||||
|
|
||||||
|
// The subscriber's callback URL MAY return an HTTP 410 code to indicate
|
||||||
|
// that the subscription has been deleted, and the hub MAY terminate the
|
||||||
|
// subscription if it receives that code as a response.
|
||||||
|
if resp.StatusCode == http.StatusGone {
|
||||||
|
if err = ucase.subscriptions.Delete(ctx, suid); err != nil {
|
||||||
|
return false, fmt.Errorf("cannot remove deleted subscription: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The subscriber's callback URL MUST return an HTTP 2xx response code
|
||||||
|
// to indicate a success.
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
return false, hub.ErrStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = ucase.subscriptions.Update(ctx, suid, func(tx *domain.Subscription) (*domain.Subscription, error) {
|
||||||
|
tx.SyncedAt = t.UpdatedAt
|
||||||
|
|
||||||
|
return tx, nil
|
||||||
|
}); err != nil {
|
||||||
|
return false, fmt.Errorf("cannot sync sybsciption status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setXHubSignatureHeader(req *http.Request, alg domain.Algorithm, secret domain.Secret, body []byte) {
|
||||||
|
if !secret.IsSet() || alg == domain.AlgorithmUnd {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h := hmac.New(alg.Hash, []byte(secret.String()))
|
||||||
|
h.Write(body)
|
||||||
|
|
||||||
|
req.Header.Set(common.HeaderXHubSignature, alg.String()+"="+hex.EncodeToString(h.Sum(nil)))
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
package usecase_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/common"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
hubucase "source.toby3d.me/toby3d/hub/internal/hub/usecase"
|
||||||
|
subscriptionmemoryrepo "source.toby3d.me/toby3d/hub/internal/subscription/repository/memory"
|
||||||
|
topicmemoryrepo "source.toby3d.me/toby3d/hub/internal/topic/repository/memory"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHubUseCase_Verify(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set(common.HeaderContentType, common.MIMETextPlainCharsetUTF8)
|
||||||
|
fmt.Fprint(w, r.FormValue(common.HubChallenge))
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
subscriptions := subscriptionmemoryrepo.NewMemorySubscriptionRepository()
|
||||||
|
topics := topicmemoryrepo.NewMemoryTopicRepository()
|
||||||
|
subscription := domain.TestSubscription(t, srv.URL)
|
||||||
|
|
||||||
|
ok, err := hubucase.NewHubUseCase(topics, subscriptions, srv.Client(), &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "hub.example.com",
|
||||||
|
Path: "/",
|
||||||
|
}).Verify(context.Background(), *subscription, domain.ModeSubscribe)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("want %t, got %t", true, ok)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,127 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-logfmt/logfmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
LogFmtConfig struct {
|
||||||
|
// Skipper defines a function to skip middleware.
|
||||||
|
Skipper Skipper
|
||||||
|
|
||||||
|
// Output is a writer where logs in JSON format are written.
|
||||||
|
// Optional. Default value os.Stdout.
|
||||||
|
Output io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
logFmtResponse struct {
|
||||||
|
start time.Time
|
||||||
|
http.ResponseWriter
|
||||||
|
error error
|
||||||
|
id uint64
|
||||||
|
statusCode int
|
||||||
|
responseLength int
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals // default configuration
|
||||||
|
var DefaultLogFmtConfig = LogFmtConfig{
|
||||||
|
Skipper: DefaultSkipper,
|
||||||
|
Output: os.Stdout,
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:gochecknoglobals
|
||||||
|
var globalConnID uint64
|
||||||
|
|
||||||
|
func LogFmt() Interceptor {
|
||||||
|
c := DefaultLogFmtConfig
|
||||||
|
|
||||||
|
return LogFmtWithConfig(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LogFmtWithConfig(config LogFmtConfig) Interceptor {
|
||||||
|
if config.Skipper == nil {
|
||||||
|
config.Skipper = DefaultLogFmtConfig.Skipper
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Output == nil {
|
||||||
|
config.Output = DefaultLogFmtConfig.Output
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder := logfmt.NewEncoder(config.Output)
|
||||||
|
|
||||||
|
return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||||
|
rw := &logFmtResponse{
|
||||||
|
id: nextConnID(),
|
||||||
|
responseLength: 0,
|
||||||
|
ResponseWriter: w,
|
||||||
|
start: time.Now().UTC(),
|
||||||
|
statusCode: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
next(rw, r)
|
||||||
|
|
||||||
|
rw.error, _ = r.Context().Value("error").(error)
|
||||||
|
end := time.Now().UTC()
|
||||||
|
|
||||||
|
encoder.EncodeKeyvals(
|
||||||
|
"bytes_in", r.ContentLength,
|
||||||
|
"bytes_out", rw.responseLength,
|
||||||
|
"error", rw.error,
|
||||||
|
"host", r.Host,
|
||||||
|
"id", rw.id,
|
||||||
|
"latency", end.Sub(rw.start).Nanoseconds(),
|
||||||
|
"latency_human", end.Sub(rw.start).String(),
|
||||||
|
"method", r.Method,
|
||||||
|
"path", r.URL.Path,
|
||||||
|
"protocol", r.Proto,
|
||||||
|
"referer", r.Referer(),
|
||||||
|
"remote_ip", r.RemoteAddr,
|
||||||
|
"status", rw.statusCode,
|
||||||
|
"time_rfc3339", rw.start.Format(time.RFC3339),
|
||||||
|
"time_rfc3339_nano", rw.start.Format(time.RFC3339Nano),
|
||||||
|
"time_unix", rw.start.Unix(),
|
||||||
|
"time_unix_nano", rw.start.UnixNano(),
|
||||||
|
"uri", r.RequestURI,
|
||||||
|
"user_agent", r.UserAgent(),
|
||||||
|
)
|
||||||
|
|
||||||
|
for name, src := range map[string]map[string][]string{
|
||||||
|
"form": r.PostForm,
|
||||||
|
"header": r.Header,
|
||||||
|
"query": r.URL.Query(),
|
||||||
|
} {
|
||||||
|
for k, v := range src {
|
||||||
|
encoder.EncodeKeyval(name+"_"+strings.ReplaceAll(strings.ToLower(k), "-", "_"), v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
encoder.EndRecord()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *logFmtResponse) WriteHeader(status int) {
|
||||||
|
r.statusCode = status
|
||||||
|
|
||||||
|
r.ResponseWriter.WriteHeader(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *logFmtResponse) Write(src []byte) (int, error) {
|
||||||
|
var l int
|
||||||
|
|
||||||
|
l, r.error = r.ResponseWriter.Write(src)
|
||||||
|
r.responseLength += l
|
||||||
|
|
||||||
|
return l, r.error
|
||||||
|
}
|
||||||
|
|
||||||
|
func nextConnID() uint64 {
|
||||||
|
return atomic.AddUint64(&globalConnID, 1)
|
||||||
|
}
|
|
@ -0,0 +1,36 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
BeforeFunc = http.HandlerFunc
|
||||||
|
|
||||||
|
Chain []Interceptor
|
||||||
|
|
||||||
|
Interceptor func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc)
|
||||||
|
|
||||||
|
HandlerFunc http.HandlerFunc
|
||||||
|
|
||||||
|
Skipper func(r *http.Request) bool
|
||||||
|
)
|
||||||
|
|
||||||
|
var DefaultSkipper Skipper = func(_ *http.Request) bool { return false }
|
||||||
|
|
||||||
|
func (count HandlerFunc) Intercept(middleware Interceptor) HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
middleware(w, r, http.HandlerFunc(count))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (chain Chain) Handler(handler http.HandlerFunc) http.Handler {
|
||||||
|
current := HandlerFunc(handler)
|
||||||
|
|
||||||
|
for i := len(chain) - 1; i >= 0; i-- {
|
||||||
|
m := chain[i]
|
||||||
|
current = current.Intercept(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.HandlerFunc(current)
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package subscription
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
UpdateFunc func(subscription *domain.Subscription) (*domain.Subscription, error)
|
||||||
|
|
||||||
|
Repository interface {
|
||||||
|
Create(ctx context.Context, suid domain.SUID, subscription domain.Subscription) error
|
||||||
|
Get(ctx context.Context, suid domain.SUID) (*domain.Subscription, error)
|
||||||
|
Fetch(ctx context.Context, topic *domain.Topic) ([]domain.Subscription, error)
|
||||||
|
Update(ctx context.Context, suid domain.SUID, update UpdateFunc) error
|
||||||
|
Delete(ctx context.Context, suid domain.SUID) error
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNotExist = errors.New("subscription does not exist")
|
||||||
|
ErrExist = errors.New("subscription already exists")
|
||||||
|
)
|
|
@ -0,0 +1,105 @@
|
||||||
|
package memory
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/subscription"
|
||||||
|
)
|
||||||
|
|
||||||
|
type memorySubscriptionRepository struct {
|
||||||
|
mutex *sync.RWMutex
|
||||||
|
subscriptions map[domain.SUID]domain.Subscription
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMemorySubscriptionRepository() subscription.Repository {
|
||||||
|
return &memorySubscriptionRepository{
|
||||||
|
mutex: new(sync.RWMutex),
|
||||||
|
subscriptions: make(map[domain.SUID]domain.Subscription),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *memorySubscriptionRepository) Create(ctx context.Context, suid domain.SUID, s domain.Subscription) error {
|
||||||
|
if _, err := repo.Get(ctx, suid); err != nil {
|
||||||
|
if !errors.Is(err, subscription.ErrNotExist) {
|
||||||
|
return fmt.Errorf("cannot create subscription: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("cannot create subscription: %w", subscription.ErrExist)
|
||||||
|
}
|
||||||
|
|
||||||
|
repo.mutex.Lock()
|
||||||
|
defer repo.mutex.Unlock()
|
||||||
|
|
||||||
|
repo.subscriptions[suid] = s
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *memorySubscriptionRepository) Delete(ctx context.Context, suid domain.SUID) error {
|
||||||
|
if _, err := repo.Get(ctx, suid); err != nil {
|
||||||
|
if !errors.Is(err, subscription.ErrNotExist) {
|
||||||
|
return fmt.Errorf("cannot delete subscription: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
repo.mutex.Lock()
|
||||||
|
defer repo.mutex.Unlock()
|
||||||
|
|
||||||
|
delete(repo.subscriptions, suid)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *memorySubscriptionRepository) Get(_ context.Context, suid domain.SUID) (*domain.Subscription, error) {
|
||||||
|
repo.mutex.RLock()
|
||||||
|
defer repo.mutex.RUnlock()
|
||||||
|
|
||||||
|
if out, ok := repo.subscriptions[suid]; ok {
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, subscription.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *memorySubscriptionRepository) Fetch(ctx context.Context, t *domain.Topic) ([]domain.Subscription, error) {
|
||||||
|
repo.mutex.RLock()
|
||||||
|
defer repo.mutex.RUnlock()
|
||||||
|
|
||||||
|
out := make([]domain.Subscription, 0)
|
||||||
|
|
||||||
|
for _, s := range repo.subscriptions {
|
||||||
|
if t != nil && t.Self.String() != s.Topic.String() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
out = append(out, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update implements subscription.Repository
|
||||||
|
func (repo *memorySubscriptionRepository) Update(ctx context.Context, suid domain.SUID, update subscription.UpdateFunc) error {
|
||||||
|
in, err := repo.Get(ctx, suid)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot update subscription: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repo.mutex.Lock()
|
||||||
|
defer repo.mutex.Unlock()
|
||||||
|
|
||||||
|
out, err := update(in)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot update subscription: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repo.subscriptions[suid] = *out
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,12 @@
|
||||||
|
package subscription
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UseCase interface {
|
||||||
|
Subscribe(ctx context.Context, s domain.Subscription) (bool, error)
|
||||||
|
Unsubscribe(ctx context.Context, s domain.Subscription) (bool, error)
|
||||||
|
}
|
|
@ -0,0 +1,95 @@
|
||||||
|
package usecase
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/common"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/subscription"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/topic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type subscriptionUseCase struct {
|
||||||
|
topics topic.Repository
|
||||||
|
subscriptions subscription.Repository
|
||||||
|
client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSubscriptionUseCase(subs subscription.Repository, tops topic.Repository, c *http.Client) subscription.UseCase {
|
||||||
|
return &subscriptionUseCase{
|
||||||
|
subscriptions: subs,
|
||||||
|
topics: tops,
|
||||||
|
client: c,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ucase *subscriptionUseCase) Subscribe(ctx context.Context, s domain.Subscription) (bool, error) {
|
||||||
|
now := time.Now().UTC().Round(time.Second)
|
||||||
|
|
||||||
|
if _, err := ucase.topics.Get(context.Background(), s.Topic); err != nil {
|
||||||
|
if !errors.Is(err, topic.ErrNotExist) {
|
||||||
|
return false, fmt.Errorf("cannot check subscription topic: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := ucase.client.Get(s.Topic.String())
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("cannot fetch a new topic subscription content: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("cannot read a new topic subscription content: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = ucase.topics.Create(ctx, s.Topic, domain.Topic{
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
Self: s.Topic,
|
||||||
|
ContentType: resp.Header.Get(common.HeaderContentType),
|
||||||
|
Content: content,
|
||||||
|
}); err != nil {
|
||||||
|
return false, fmt.Errorf("cannot create topic for subsciption: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ucase.subscriptions.Create(ctx, s.SUID(), domain.Subscription{
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
SyncedAt: now,
|
||||||
|
ExpiredAt: s.ExpiredAt,
|
||||||
|
Callback: s.Callback,
|
||||||
|
Topic: s.Topic,
|
||||||
|
Secret: s.Secret,
|
||||||
|
}); err != nil {
|
||||||
|
if !errors.Is(err, subscription.ErrExist) {
|
||||||
|
return false, fmt.Errorf("cannot create a new subscription: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = ucase.subscriptions.Update(ctx, s.SUID(), func(tx *domain.Subscription) (*domain.Subscription,
|
||||||
|
error,
|
||||||
|
) {
|
||||||
|
tx.UpdatedAt = now
|
||||||
|
tx.ExpiredAt = now.Add(time.Duration(s.LeaseSeconds()) * time.Second)
|
||||||
|
tx.Secret = s.Secret
|
||||||
|
|
||||||
|
return tx, nil
|
||||||
|
}); err != nil {
|
||||||
|
return false, fmt.Errorf("cannot resubscribe existing subscription: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ucase *subscriptionUseCase) Unsubscribe(ctx context.Context, s domain.Subscription) (bool, error) {
|
||||||
|
if err := ucase.subscriptions.Delete(ctx, s.SUID()); err != nil {
|
||||||
|
return false, fmt.Errorf("cannot unsubscribe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
|
@ -0,0 +1,97 @@
|
||||||
|
package usecase_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/common"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
subscriptionmemoryrepo "source.toby3d.me/toby3d/hub/internal/subscription/repository/memory"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/subscription/usecase"
|
||||||
|
topicmemoryrepo "source.toby3d.me/toby3d/hub/internal/topic/repository/memory"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSubscriptionUseCase_Subscribe(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
topic := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Header().Set(common.HeaderContentType, common.MIMETextPlainCharsetUTF8)
|
||||||
|
fmt.Fprint(w, "hello, world")
|
||||||
|
}))
|
||||||
|
t.Cleanup(topic.Close)
|
||||||
|
|
||||||
|
callback := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Header().Set(common.HeaderContentType, common.MIMETextPlainCharsetUTF8)
|
||||||
|
fmt.Fprint(w, "hello, world")
|
||||||
|
}))
|
||||||
|
t.Cleanup(callback.Close)
|
||||||
|
|
||||||
|
subscription := domain.TestSubscription(t, callback.URL)
|
||||||
|
subscription.Topic, _ = url.Parse(topic.URL + "/")
|
||||||
|
topics := topicmemoryrepo.NewMemoryTopicRepository()
|
||||||
|
subscriptions := subscriptionmemoryrepo.NewMemorySubscriptionRepository()
|
||||||
|
|
||||||
|
ucase := usecase.NewSubscriptionUseCase(subscriptions, topics, callback.Client())
|
||||||
|
|
||||||
|
ok, err := ucase.Subscribe(context.Background(), *subscription)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("want %t, got %t", true, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := subscriptions.Get(context.Background(), subscription.SUID()); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("resubscribe", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ok, err := ucase.Subscribe(context.Background(), *subscription)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("want %t, got %t", true, ok)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriptionUseCase_Unsubscribe(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Header().Set(common.HeaderContentType, common.MIMETextPlainCharsetUTF8)
|
||||||
|
fmt.Fprint(w, "hello, world")
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
subscription := domain.TestSubscription(t, "https://example.com/")
|
||||||
|
topics := topicmemoryrepo.NewMemoryTopicRepository()
|
||||||
|
subscriptions := subscriptionmemoryrepo.NewMemorySubscriptionRepository()
|
||||||
|
|
||||||
|
if err := subscriptions.Create(context.Background(), subscription.SUID(), *subscription); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := usecase.NewSubscriptionUseCase(subscriptions, topics, srv.Client()).
|
||||||
|
Unsubscribe(context.Background(), *subscription)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("want %t, got %t", true, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := subscriptions.Get(context.Background(), subscription.SUID()); err == nil {
|
||||||
|
t.Error("want error, got nil")
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,40 @@
|
||||||
|
package sqltest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
_ "modernc.org/sqlite" // used for running tests without same import in "god object"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Time struct{}
|
||||||
|
|
||||||
|
func (Time) Match(v driver.Value) bool {
|
||||||
|
_, ok := v.(time.Time)
|
||||||
|
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open creates a new InMemory sqlite3 database for testing.
|
||||||
|
func Open(tb testing.TB) (*sqlx.DB, sqlmock.Sqlmock, func()) {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
tb.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
xdb := sqlx.NewDb(db, "sqlite")
|
||||||
|
if err = xdb.Ping(); err != nil {
|
||||||
|
_ = db.Close()
|
||||||
|
|
||||||
|
tb.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return xdb, mock, func() {
|
||||||
|
_ = db.Close()
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package topic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
UpdateFunc func(t *domain.Topic) (*domain.Topic, error)
|
||||||
|
|
||||||
|
Repository interface {
|
||||||
|
Create(ctx context.Context, u *url.URL, topic domain.Topic) error
|
||||||
|
Update(ctx context.Context, u *url.URL, update UpdateFunc) error
|
||||||
|
Fetch(ctx context.Context) ([]domain.Topic, error)
|
||||||
|
Get(ctx context.Context, u *url.URL) (*domain.Topic, error)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrExist = errors.New("topic already exists")
|
||||||
|
ErrNotExist = errors.New("topic does not exist")
|
||||||
|
)
|
|
@ -0,0 +1,85 @@
|
||||||
|
package memory
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/topic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type memoryTopicRepository struct {
|
||||||
|
mutex *sync.RWMutex
|
||||||
|
topics map[string]domain.Topic
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMemoryTopicRepository() topic.Repository {
|
||||||
|
return &memoryTopicRepository{
|
||||||
|
mutex: new(sync.RWMutex),
|
||||||
|
topics: make(map[string]domain.Topic),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *memoryTopicRepository) Update(ctx context.Context, u *url.URL, update topic.UpdateFunc) error {
|
||||||
|
tx, err := repo.Get(ctx, u)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot find updating topic: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repo.mutex.Lock()
|
||||||
|
defer repo.mutex.Unlock()
|
||||||
|
|
||||||
|
result, err := update(tx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot update topic: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repo.topics[u.String()] = *result
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *memoryTopicRepository) Create(ctx context.Context, u *url.URL, t domain.Topic) error {
|
||||||
|
_, err := repo.Get(ctx, u)
|
||||||
|
if err != nil && !errors.Is(err, topic.ErrNotExist) {
|
||||||
|
return fmt.Errorf("cannot get topic: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
return topic.ErrExist
|
||||||
|
}
|
||||||
|
|
||||||
|
repo.mutex.Lock()
|
||||||
|
defer repo.mutex.Unlock()
|
||||||
|
|
||||||
|
repo.topics[u.String()] = t
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *memoryTopicRepository) Get(ctx context.Context, u *url.URL) (*domain.Topic, error) {
|
||||||
|
repo.mutex.RLock()
|
||||||
|
defer repo.mutex.RUnlock()
|
||||||
|
|
||||||
|
if out, ok := repo.topics[u.String()]; ok {
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, topic.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *memoryTopicRepository) Fetch(_ context.Context) ([]domain.Topic, error) {
|
||||||
|
repo.mutex.RLock()
|
||||||
|
defer repo.mutex.RUnlock()
|
||||||
|
|
||||||
|
out := make([]domain.Topic, 0)
|
||||||
|
|
||||||
|
for _, t := range repo.topics {
|
||||||
|
out = append(out, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
|
@ -0,0 +1,10 @@
|
||||||
|
package topic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UseCase interface {
|
||||||
|
Publish(ctx context.Context, u *url.URL) (bool, error)
|
||||||
|
}
|
|
@ -0,0 +1,66 @@
|
||||||
|
package usecase
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/common"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/topic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type topicUseCase struct {
|
||||||
|
client *http.Client
|
||||||
|
topics topic.Repository
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTopicUseCase(topics topic.Repository, client *http.Client) topic.UseCase {
|
||||||
|
return &topicUseCase{
|
||||||
|
topics: topics,
|
||||||
|
client: client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ucase *topicUseCase) Publish(ctx context.Context, u *url.URL) (bool, error) {
|
||||||
|
now := time.Now().UTC().Round(time.Second)
|
||||||
|
|
||||||
|
resp, err := ucase.client.Get(u.String())
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("cannot fetch publishing url: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("cannot read topic response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ucase.topics.Update(ctx, u, func(tx *domain.Topic) (*domain.Topic, error) {
|
||||||
|
tx.Self = resp.Request.URL
|
||||||
|
tx.UpdatedAt = now
|
||||||
|
tx.Content = content
|
||||||
|
tx.ContentType = resp.Header.Get(common.HeaderContentType)
|
||||||
|
|
||||||
|
return tx, nil
|
||||||
|
}); err != nil {
|
||||||
|
if !errors.Is(err, topic.ErrNotExist) {
|
||||||
|
return false, fmt.Errorf("cannot publish exists topic: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = ucase.topics.Create(ctx, resp.Request.URL, domain.Topic{
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
Self: resp.Request.URL,
|
||||||
|
ContentType: resp.Header.Get(common.HeaderContentType),
|
||||||
|
Content: content,
|
||||||
|
}); err != nil {
|
||||||
|
return false, fmt.Errorf("cannot publish a new topic: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
package usecase_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/common"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
topicmemoryrepo "source.toby3d.me/toby3d/hub/internal/topic/repository/memory"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/topic/usecase"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTopicUseCase_Publish(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
topic := domain.TestTopic(t)
|
||||||
|
topics := topicmemoryrepo.NewMemoryTopicRepository()
|
||||||
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Header().Set(common.HeaderContentType, topic.ContentType)
|
||||||
|
fmt.Fprint(w, topic.Content)
|
||||||
|
}))
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
topic.Self, _ = url.Parse(srv.URL + "/")
|
||||||
|
|
||||||
|
ok, err := usecase.NewTopicUseCase(topics, srv.Client()).
|
||||||
|
Publish(context.Background(), topic.Self)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("want %t, got %t", true, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := topics.Get(context.Background(), topic.Self); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,22 @@
|
||||||
|
package urlutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ShiftPath splits off the first component of p, which will be cleaned of
|
||||||
|
// relative components before processing. head will never contain a slash and
|
||||||
|
// tail will always be a rooted path without trailing slash.
|
||||||
|
//
|
||||||
|
// See: https://blog.merovius.de/posts/2017-06-18-how-not-to-use-an-http-router/
|
||||||
|
func ShiftPath(p string) (head, tail string) {
|
||||||
|
p = path.Clean("/" + p)
|
||||||
|
|
||||||
|
i := strings.Index(p[1:], "/") + 1
|
||||||
|
if i <= 0 {
|
||||||
|
return p[1:], "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
return p[1:i], p[i:]
|
||||||
|
}
|
|
@ -0,0 +1,37 @@
|
||||||
|
package urlutil_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/urlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestShiftPath(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for name, tc := range map[string]struct {
|
||||||
|
input string
|
||||||
|
expect [2]string
|
||||||
|
}{
|
||||||
|
"empty": {input: "", expect: [2]string{"", "/"}},
|
||||||
|
"root": {input: "/", expect: [2]string{"", "/"}},
|
||||||
|
"page": {input: "/foo", expect: [2]string{"foo", "/"}},
|
||||||
|
"folder": {input: "/foo/bar", expect: [2]string{"foo", "/bar"}},
|
||||||
|
} {
|
||||||
|
name, tc := name, tc
|
||||||
|
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
head, tail := urlutil.ShiftPath(tc.input)
|
||||||
|
|
||||||
|
if head != tc.expect[0] {
|
||||||
|
t.Errorf("want '%s', got '%s'", tc.expect[0], head)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tail != tc.expect[1] {
|
||||||
|
t.Errorf("want '%s', got '%s'", tc.expect[1], tail)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,67 @@
|
||||||
|
{
|
||||||
|
"language": "en",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"id": "version",
|
||||||
|
"message": "version",
|
||||||
|
"translation": "version",
|
||||||
|
"translatorComment": "Copied from source.",
|
||||||
|
"fuzzy": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "{Name} logo",
|
||||||
|
"message": "{Name} logo",
|
||||||
|
"translation": "{Name} logo",
|
||||||
|
"translatorComment": "Copied from source.",
|
||||||
|
"placeholders": [
|
||||||
|
{
|
||||||
|
"id": "Name",
|
||||||
|
"string": "%[1]s",
|
||||||
|
"type": "string",
|
||||||
|
"underlyingType": "string",
|
||||||
|
"argNum": 1,
|
||||||
|
"expr": "p.name"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"fuzzy": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Dead simple WebSub hub",
|
||||||
|
"message": "Dead simple WebSub hub",
|
||||||
|
"translation": "Dead simple WebSub hub",
|
||||||
|
"translatorComment": "Copied from source.",
|
||||||
|
"fuzzy": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "How to publish and consume?",
|
||||||
|
"message": "How to publish and consume?",
|
||||||
|
"translation": "How to publish and consume?",
|
||||||
|
"translatorComment": "Copied from source.",
|
||||||
|
"fuzzy": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "What the spec?",
|
||||||
|
"message": "What the spec?",
|
||||||
|
"translation": "What the spec?",
|
||||||
|
"translatorComment": "Copied from source.",
|
||||||
|
"fuzzy": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "{Subscribers} subscribers",
|
||||||
|
"message": "{Subscribers} subscribers",
|
||||||
|
"translation": "{Subscribers} subscribers",
|
||||||
|
"translatorComment": "Copied from source.",
|
||||||
|
"placeholders": [
|
||||||
|
{
|
||||||
|
"id": "Subscribers",
|
||||||
|
"string": "%[1]d",
|
||||||
|
"type": "int",
|
||||||
|
"underlyingType": "int",
|
||||||
|
"argNum": 1,
|
||||||
|
"expr": "p.Subscribers"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"fuzzy": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
|
@ -0,0 +1,55 @@
|
||||||
|
{
|
||||||
|
"language": "ru",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"id": "version",
|
||||||
|
"message": "version",
|
||||||
|
"translation": "версия"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "{Name} logo",
|
||||||
|
"message": "{Name} logo",
|
||||||
|
"translation": "логотип {Name}",
|
||||||
|
"placeholders": [
|
||||||
|
{
|
||||||
|
"id": "Name",
|
||||||
|
"string": "%[1]s",
|
||||||
|
"type": "string",
|
||||||
|
"underlyingType": "string",
|
||||||
|
"argNum": 1,
|
||||||
|
"expr": "p.name"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Dead simple WebSub hub",
|
||||||
|
"message": "Dead simple WebSub hub",
|
||||||
|
"translation": "Простейший хаб WebSub"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "How to publish and consume?",
|
||||||
|
"message": "How to publish and consume?",
|
||||||
|
"translation": "Как публиковать и принимать?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "What the spec?",
|
||||||
|
"message": "What the spec?",
|
||||||
|
"translation": "В чём спека?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "{Subscribers} subscribers",
|
||||||
|
"message": "{Subscribers} subscribers",
|
||||||
|
"translation": "{Subscribers} подписчиков",
|
||||||
|
"placeholders": [
|
||||||
|
{
|
||||||
|
"id": "Subscribers",
|
||||||
|
"string": "%[1]d",
|
||||||
|
"type": "int",
|
||||||
|
"underlyingType": "int",
|
||||||
|
"argNum": 1,
|
||||||
|
"expr": "p.Subscribers"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
|
@ -0,0 +1,55 @@
|
||||||
|
{
|
||||||
|
"language": "ru",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"id": "version",
|
||||||
|
"message": "version",
|
||||||
|
"translation": "версия"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "{Name} logo",
|
||||||
|
"message": "{Name} logo",
|
||||||
|
"translation": "логотип {Name}",
|
||||||
|
"placeholders": [
|
||||||
|
{
|
||||||
|
"id": "Name",
|
||||||
|
"string": "%[1]s",
|
||||||
|
"type": "string",
|
||||||
|
"underlyingType": "string",
|
||||||
|
"argNum": 1,
|
||||||
|
"expr": "p.name"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Dead simple WebSub hub",
|
||||||
|
"message": "Dead simple WebSub hub",
|
||||||
|
"translation": "Простейший хаб WebSub"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "How to publish and consume?",
|
||||||
|
"message": "How to publish and consume?",
|
||||||
|
"translation": "Как публиковать и принимать?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "What the spec?",
|
||||||
|
"message": "What the spec?",
|
||||||
|
"translation": "В чём спека?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "{Subscribers} subscribers",
|
||||||
|
"message": "{Subscribers} subscribers",
|
||||||
|
"translation": "{Subscribers} подписчиков",
|
||||||
|
"placeholders": [
|
||||||
|
{
|
||||||
|
"id": "Subscribers",
|
||||||
|
"string": "%[1]d",
|
||||||
|
"type": "int",
|
||||||
|
"underlyingType": "int",
|
||||||
|
"argNum": 1,
|
||||||
|
"expr": "p.Subscribers"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
|
@ -0,0 +1,111 @@
|
||||||
|
//go:generate go install github.com/valyala/quicktemplate/qtc@latest
|
||||||
|
//go:generate qtc -dir=web
|
||||||
|
//go:generate go install golang.org/x/text/cmd/gotext@latest
|
||||||
|
//go:generate gotext -srclang=en update -out=catalog_gen.go -lang=en,ru
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"embed"
|
||||||
|
"io/fs"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/caarlos0/env/v7"
|
||||||
|
"golang.org/x/text/feature/plural"
|
||||||
|
"golang.org/x/text/language"
|
||||||
|
"golang.org/x/text/message"
|
||||||
|
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/domain"
|
||||||
|
hubhttprelivery "source.toby3d.me/toby3d/hub/internal/hub/delivery/http"
|
||||||
|
hubucase "source.toby3d.me/toby3d/hub/internal/hub/usecase"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/middleware"
|
||||||
|
subscriptionmemoryrepo "source.toby3d.me/toby3d/hub/internal/subscription/repository/memory"
|
||||||
|
subscriptionucase "source.toby3d.me/toby3d/hub/internal/subscription/usecase"
|
||||||
|
topicmemoryrepo "source.toby3d.me/toby3d/hub/internal/topic/repository/memory"
|
||||||
|
topicucase "source.toby3d.me/toby3d/hub/internal/topic/usecase"
|
||||||
|
"source.toby3d.me/toby3d/hub/internal/urlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
var logger = log.New(os.Stdout, "hub", log.LstdFlags)
|
||||||
|
|
||||||
|
//go:embed web/static/*
|
||||||
|
var static embed.FS
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
message.Set(language.English, "%d subscribers",
|
||||||
|
plural.Selectf(1, "%d",
|
||||||
|
"one", "%d subscriber",
|
||||||
|
"other", "%d subscribers",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
message.Set(language.Russian, "%d subscribers",
|
||||||
|
plural.Selectf(1, "%d",
|
||||||
|
"one", "%d подписчик",
|
||||||
|
"few", "%d подписчика",
|
||||||
|
"many", "%d подписчиков",
|
||||||
|
"other", "%d подписчика",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
config := new(domain.Config)
|
||||||
|
if err := env.Parse(config, env.Options{
|
||||||
|
Prefix: "HUB_",
|
||||||
|
TagName: "env",
|
||||||
|
UseFieldNameByDefault: true,
|
||||||
|
}); err != nil {
|
||||||
|
logger.Fatalln(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
static, err := fs.Sub(static, filepath.Join("web"))
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalln(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 5 * time.Second}
|
||||||
|
matcher := language.NewMatcher(message.DefaultCatalog.Languages())
|
||||||
|
subscriptions := subscriptionmemoryrepo.NewMemorySubscriptionRepository()
|
||||||
|
topics := topicmemoryrepo.NewMemoryTopicRepository()
|
||||||
|
topicService := topicucase.NewTopicUseCase(topics, client)
|
||||||
|
subscriptionService := subscriptionucase.NewSubscriptionUseCase(subscriptions, topics, client)
|
||||||
|
hubService := hubucase.NewHubUseCase(topics, subscriptions, client, config.BaseURL)
|
||||||
|
|
||||||
|
handler := hubhttprelivery.NewHandler(hubhttprelivery.NewHandlerParams{
|
||||||
|
Hub: hubService,
|
||||||
|
Subscriptions: subscriptionService,
|
||||||
|
Topics: topicService,
|
||||||
|
Matcher: matcher,
|
||||||
|
Name: config.Name,
|
||||||
|
})
|
||||||
|
|
||||||
|
server := &http.Server{
|
||||||
|
Addr: config.Bind,
|
||||||
|
Handler: http.HandlerFunc(middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
head, _ := urlutil.ShiftPath(r.URL.Path)
|
||||||
|
|
||||||
|
switch head {
|
||||||
|
case "":
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
case "static":
|
||||||
|
http.FileServer(http.FS(static)).ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
}).Intercept(middleware.LogFmt())),
|
||||||
|
ReadTimeout: 5 * time.Second,
|
||||||
|
WriteTimeout: 5 * time.Second,
|
||||||
|
ErrorLog: logger,
|
||||||
|
}
|
||||||
|
|
||||||
|
go hubService.ListenAndServe(ctx)
|
||||||
|
|
||||||
|
logger.Printf("started %s on %s: %s", config.Name, config.Bind, config.BaseURL.String())
|
||||||
|
if err = server.ListenAndServe(); err != nil {
|
||||||
|
logger.Fatalln(err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,4 @@
|
||||||
|
/examples/blog/blog
|
||||||
|
/examples/orders/orders
|
||||||
|
/examples/basic/basic
|
||||||
|
.idea/
|
|
@ -0,0 +1,26 @@
|
||||||
|
language: go
|
||||||
|
|
||||||
|
go_import_path: github.com/DATA-DOG/go-sqlmock
|
||||||
|
|
||||||
|
go:
|
||||||
|
- 1.2.x
|
||||||
|
- 1.3.x
|
||||||
|
- 1.4 # has no cover tool for latest releases
|
||||||
|
- 1.5.x
|
||||||
|
- 1.6.x
|
||||||
|
- 1.7.x
|
||||||
|
- 1.8.x
|
||||||
|
- 1.9.x
|
||||||
|
- 1.10.x
|
||||||
|
- 1.11.x
|
||||||
|
- 1.12.x
|
||||||
|
- 1.13.x
|
||||||
|
- 1.14.x
|
||||||
|
|
||||||
|
script:
|
||||||
|
- go vet
|
||||||
|
- test -z "$(go fmt ./...)" # fail if not formatted properly
|
||||||
|
- go test -race -coverprofile=coverage.txt -covermode=atomic
|
||||||
|
|
||||||
|
after_success:
|
||||||
|
- bash <(curl -s https://codecov.io/bash)
|
|
@ -0,0 +1,28 @@
|
||||||
|
The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses)
|
||||||
|
|
||||||
|
Copyright (c) 2013-2019, DATA-DOG team
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
* Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
* The name DataDog.lt may not be used to endorse or promote products
|
||||||
|
derived from this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT,
|
||||||
|
INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||||
|
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||||
|
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||||
|
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
|
||||||
|
EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,265 @@
|
||||||
|
[![Build Status](https://travis-ci.org/DATA-DOG/go-sqlmock.svg)](https://travis-ci.org/DATA-DOG/go-sqlmock)
|
||||||
|
[![GoDoc](https://godoc.org/github.com/DATA-DOG/go-sqlmock?status.svg)](https://godoc.org/github.com/DATA-DOG/go-sqlmock)
|
||||||
|
[![Go Report Card](https://goreportcard.com/badge/github.com/DATA-DOG/go-sqlmock)](https://goreportcard.com/report/github.com/DATA-DOG/go-sqlmock)
|
||||||
|
[![codecov.io](https://codecov.io/github/DATA-DOG/go-sqlmock/branch/master/graph/badge.svg)](https://codecov.io/github/DATA-DOG/go-sqlmock)
|
||||||
|
|
||||||
|
# Sql driver mock for Golang
|
||||||
|
|
||||||
|
**sqlmock** is a mock library implementing [sql/driver](https://godoc.org/database/sql/driver). Which has one and only
|
||||||
|
purpose - to simulate any **sql** driver behavior in tests, without needing a real database connection. It helps to
|
||||||
|
maintain correct **TDD** workflow.
|
||||||
|
|
||||||
|
- this library is now complete and stable. (you may not find new changes for this reason)
|
||||||
|
- supports concurrency and multiple connections.
|
||||||
|
- supports **go1.8** Context related feature mocking and Named sql parameters.
|
||||||
|
- does not require any modifications to your source code.
|
||||||
|
- the driver allows to mock any sql driver method behavior.
|
||||||
|
- has strict by default expectation order matching.
|
||||||
|
- has no third party dependencies.
|
||||||
|
|
||||||
|
**NOTE:** in **v1.2.0** **sqlmock.Rows** has changed to struct from interface, if you were using any type references to that
|
||||||
|
interface, you will need to switch it to a pointer struct type. Also, **sqlmock.Rows** were used to implement **driver.Rows**
|
||||||
|
interface, which was not required or useful for mocking and was removed. Hope it will not cause issues.
|
||||||
|
|
||||||
|
## Looking for maintainers
|
||||||
|
|
||||||
|
I do not have much spare time for this library and willing to transfer the repository ownership
|
||||||
|
to person or an organization motivated to maintain it. Open up a conversation if you are interested. See #230.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
go get github.com/DATA-DOG/go-sqlmock
|
||||||
|
|
||||||
|
## Documentation and Examples
|
||||||
|
|
||||||
|
Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock) for general examples and public api reference.
|
||||||
|
See **.travis.yml** for supported **go** versions.
|
||||||
|
Different use case, is to functionally test with a real database - [go-txdb](https://github.com/DATA-DOG/go-txdb)
|
||||||
|
all database related actions are isolated within a single transaction so the database can remain in the same state.
|
||||||
|
|
||||||
|
See implementation examples:
|
||||||
|
|
||||||
|
- [blog API server](https://github.com/DATA-DOG/go-sqlmock/tree/master/examples/blog)
|
||||||
|
- [the same orders example](https://github.com/DATA-DOG/go-sqlmock/tree/master/examples/orders)
|
||||||
|
|
||||||
|
### Something you may want to test, assuming you use the [go-mysql-driver](https://github.com/go-sql-driver/mysql)
|
||||||
|
|
||||||
|
``` go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
)
|
||||||
|
|
||||||
|
func recordStats(db *sql.DB, userID, productID int64) (err error) {
|
||||||
|
tx, err = db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
err = tx.Commit()
|
||||||
|
default:
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if _, err = tx.Exec("UPDATE products SET views = views + 1"); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err = tx.Exec("INSERT INTO product_viewers (user_id, product_id) VALUES (?, ?)", userID, productID); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// @NOTE: the real connection is not required for tests
|
||||||
|
db, err := sql.Open("mysql", "root@/blog")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err = recordStats(db, 1 /*some user id*/, 5 /*some product id*/); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tests with sqlmock
|
||||||
|
|
||||||
|
``` go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// a successful case
|
||||||
|
func TestShouldUpdateStats(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
|
mock.ExpectExec("INSERT INTO product_viewers").WithArgs(2, 3).WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
|
mock.ExpectCommit()
|
||||||
|
|
||||||
|
// now we execute our method
|
||||||
|
if err = recordStats(db, 2, 3); err != nil {
|
||||||
|
t.Errorf("error was not expected while updating stats: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// we make sure that all expectations were met
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// a failing test case
|
||||||
|
func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
|
mock.ExpectExec("INSERT INTO product_viewers").
|
||||||
|
WithArgs(2, 3).
|
||||||
|
WillReturnError(fmt.Errorf("some error"))
|
||||||
|
mock.ExpectRollback()
|
||||||
|
|
||||||
|
// now we execute our method
|
||||||
|
if err = recordStats(db, 2, 3); err == nil {
|
||||||
|
t.Errorf("was expecting an error, but there was none")
|
||||||
|
}
|
||||||
|
|
||||||
|
// we make sure that all expectations were met
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Customize SQL query matching
|
||||||
|
|
||||||
|
There were plenty of requests from users regarding SQL query string validation or different matching option.
|
||||||
|
We have now implemented the `QueryMatcher` interface, which can be passed through an option when calling
|
||||||
|
`sqlmock.New` or `sqlmock.NewWithDSN`.
|
||||||
|
|
||||||
|
This now allows to include some library, which would allow for example to parse and validate `mysql` SQL AST.
|
||||||
|
And create a custom QueryMatcher in order to validate SQL in sophisticated ways.
|
||||||
|
|
||||||
|
By default, **sqlmock** is preserving backward compatibility and default query matcher is `sqlmock.QueryMatcherRegexp`
|
||||||
|
which uses expected SQL string as a regular expression to match incoming query string. There is an equality matcher:
|
||||||
|
`QueryMatcherEqual` which will do a full case sensitive match.
|
||||||
|
|
||||||
|
In order to customize the QueryMatcher, use the following:
|
||||||
|
|
||||||
|
``` go
|
||||||
|
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||||
|
```
|
||||||
|
|
||||||
|
The query matcher can be fully customized based on user needs. **sqlmock** will not
|
||||||
|
provide a standard sql parsing matchers, since various drivers may not follow the same SQL standard.
|
||||||
|
|
||||||
|
## Matching arguments like time.Time
|
||||||
|
|
||||||
|
There may be arguments which are of `struct` type and cannot be compared easily by value like `time.Time`. In this case
|
||||||
|
**sqlmock** provides an [Argument](https://godoc.org/github.com/DATA-DOG/go-sqlmock#Argument) interface which
|
||||||
|
can be used in more sophisticated matching. Here is a simple example of time argument matching:
|
||||||
|
|
||||||
|
``` go
|
||||||
|
type AnyTime struct{}
|
||||||
|
|
||||||
|
// Match satisfies sqlmock.Argument interface
|
||||||
|
func (a AnyTime) Match(v driver.Value) bool {
|
||||||
|
_, ok := v.(time.Time)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnyTimeArgument(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
db, mock, err := New()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
mock.ExpectExec("INSERT INTO users").
|
||||||
|
WithArgs("john", AnyTime{}).
|
||||||
|
WillReturnResult(NewResult(1, 1))
|
||||||
|
|
||||||
|
_, err = db.Exec("INSERT INTO users(name, created_at) VALUES (?, ?)", "john", time.Now())
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error '%s' was not expected, while inserting a row", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
It only asserts that argument is of `time.Time` type.
|
||||||
|
|
||||||
|
## Run tests
|
||||||
|
|
||||||
|
go test -race
|
||||||
|
|
||||||
|
## Change Log
|
||||||
|
|
||||||
|
- **2019-04-06** - added functionality to mock a sql MetaData request
|
||||||
|
- **2019-02-13** - added `go.mod` removed the references and suggestions using `gopkg.in`.
|
||||||
|
- **2018-12-11** - added expectation of Rows to be closed, while mocking expected query.
|
||||||
|
- **2018-12-11** - introduced an option to provide **QueryMatcher** in order to customize SQL query matching.
|
||||||
|
- **2017-09-01** - it is now possible to expect that prepared statement will be closed,
|
||||||
|
using **ExpectedPrepare.WillBeClosed**.
|
||||||
|
- **2017-02-09** - implemented support for **go1.8** features. **Rows** interface was changed to struct
|
||||||
|
but contains all methods as before and should maintain backwards compatibility. **ExpectedQuery.WillReturnRows** may now
|
||||||
|
accept multiple row sets.
|
||||||
|
- **2016-11-02** - `db.Prepare()` was not validating expected prepare SQL
|
||||||
|
query. It should still be validated even if Exec or Query is not
|
||||||
|
executed on that prepared statement.
|
||||||
|
- **2016-02-23** - added **sqlmock.AnyArg()** function to provide any kind
|
||||||
|
of argument matcher.
|
||||||
|
- **2016-02-23** - convert expected arguments to driver.Value as natural
|
||||||
|
driver does, the change may affect time.Time comparison and will be
|
||||||
|
stricter. See [issue](https://github.com/DATA-DOG/go-sqlmock/issues/31).
|
||||||
|
- **2015-08-27** - **v1** api change, concurrency support, all known issues fixed.
|
||||||
|
- **2014-08-16** instead of **panic** during reflect type mismatch when comparing query arguments - now return error
|
||||||
|
- **2014-08-14** added **sqlmock.NewErrorResult** which gives an option to return driver.Result with errors for
|
||||||
|
interface methods, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/5)
|
||||||
|
- **2014-05-29** allow to match arguments in more sophisticated ways, by providing an **sqlmock.Argument** interface
|
||||||
|
- **2014-04-21** introduce **sqlmock.New()** to open a mock database connection for tests. This method
|
||||||
|
calls sql.DB.Ping to ensure that connection is open, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/4).
|
||||||
|
This way on Close it will surely assert if all expectations are met, even if database was not triggered at all.
|
||||||
|
The old way is still available, but it is advisable to call db.Ping manually before asserting with db.Close.
|
||||||
|
- **2014-02-14** RowsFromCSVString is now a part of Rows interface named as FromCSVString.
|
||||||
|
It has changed to allow more ways to construct rows and to easily extend this API in future.
|
||||||
|
See [issue 1](https://github.com/DATA-DOG/go-sqlmock/issues/1)
|
||||||
|
**RowsFromCSVString** is deprecated and will be removed in future
|
||||||
|
|
||||||
|
## Contributions
|
||||||
|
|
||||||
|
Feel free to open a pull request. Note, if you wish to contribute an extension to public (exported methods or types) -
|
||||||
|
please open an issue before, to discuss whether these changes can be accepted. All backward incompatible changes are
|
||||||
|
and will be treated cautiously
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
The [three clause BSD license](http://en.wikipedia.org/wiki/BSD_licenses)
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import "database/sql/driver"
|
||||||
|
|
||||||
|
// Argument interface allows to match
|
||||||
|
// any argument in specific way when used with
|
||||||
|
// ExpectedQuery and ExpectedExec expectations.
|
||||||
|
type Argument interface {
|
||||||
|
Match(driver.Value) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// AnyArg will return an Argument which can
|
||||||
|
// match any kind of arguments.
|
||||||
|
//
|
||||||
|
// Useful for time.Time or similar kinds of arguments.
|
||||||
|
func AnyArg() Argument {
|
||||||
|
return anyArgument{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type anyArgument struct{}
|
||||||
|
|
||||||
|
func (a anyArgument) Match(_ driver.Value) bool {
|
||||||
|
return true
|
||||||
|
}
|
|
@ -0,0 +1,77 @@
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import "reflect"
|
||||||
|
|
||||||
|
// Column is a mocked column Metadata for rows.ColumnTypes()
|
||||||
|
type Column struct {
|
||||||
|
name string
|
||||||
|
dbType string
|
||||||
|
nullable bool
|
||||||
|
nullableOk bool
|
||||||
|
length int64
|
||||||
|
lengthOk bool
|
||||||
|
precision int64
|
||||||
|
scale int64
|
||||||
|
psOk bool
|
||||||
|
scanType reflect.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Column) Name() string {
|
||||||
|
return c.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Column) DbType() string {
|
||||||
|
return c.dbType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Column) IsNullable() (bool, bool) {
|
||||||
|
return c.nullable, c.nullableOk
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Column) Length() (int64, bool) {
|
||||||
|
return c.length, c.lengthOk
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Column) PrecisionScale() (int64, int64, bool) {
|
||||||
|
return c.precision, c.scale, c.psOk
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Column) ScanType() reflect.Type {
|
||||||
|
return c.scanType
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewColumn returns a Column with specified name
|
||||||
|
func NewColumn(name string) *Column {
|
||||||
|
return &Column{
|
||||||
|
name: name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nullable returns the column with nullable metadata set
|
||||||
|
func (c *Column) Nullable(nullable bool) *Column {
|
||||||
|
c.nullable = nullable
|
||||||
|
c.nullableOk = true
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// OfType returns the column with type metadata set
|
||||||
|
func (c *Column) OfType(dbType string, sampleValue interface{}) *Column {
|
||||||
|
c.dbType = dbType
|
||||||
|
c.scanType = reflect.TypeOf(sampleValue)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLength returns the column with length metadata set.
|
||||||
|
func (c *Column) WithLength(length int64) *Column {
|
||||||
|
c.length = length
|
||||||
|
c.lengthOk = true
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPrecisionAndScale returns the column with precision and scale metadata set.
|
||||||
|
func (c *Column) WithPrecisionAndScale(precision, scale int64) *Column {
|
||||||
|
c.precision = precision
|
||||||
|
c.scale = scale
|
||||||
|
c.psOk = true
|
||||||
|
return c
|
||||||
|
}
|
|
@ -0,0 +1,81 @@
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var pool *mockDriver
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
pool = &mockDriver{
|
||||||
|
conns: make(map[string]*sqlmock),
|
||||||
|
}
|
||||||
|
sql.Register("sqlmock", pool)
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockDriver struct {
|
||||||
|
sync.Mutex
|
||||||
|
counter int
|
||||||
|
conns map[string]*sqlmock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *mockDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
|
|
||||||
|
c, ok := d.conns[dsn]
|
||||||
|
if !ok {
|
||||||
|
return c, fmt.Errorf("expected a connection to be available, but it is not")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.opened++
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates sqlmock database connection and a mock to manage expectations.
|
||||||
|
// Accepts options, like ValueConverterOption, to use a ValueConverter from
|
||||||
|
// a specific driver.
|
||||||
|
// Pings db so that all expectations could be
|
||||||
|
// asserted.
|
||||||
|
func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
|
||||||
|
pool.Lock()
|
||||||
|
dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter)
|
||||||
|
pool.counter++
|
||||||
|
|
||||||
|
smock := &sqlmock{dsn: dsn, drv: pool, ordered: true}
|
||||||
|
pool.conns[dsn] = smock
|
||||||
|
pool.Unlock()
|
||||||
|
|
||||||
|
return smock.open(options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWithDSN creates sqlmock database connection with a specific DSN
|
||||||
|
// and a mock to manage expectations.
|
||||||
|
// Accepts options, like ValueConverterOption, to use a ValueConverter from
|
||||||
|
// a specific driver.
|
||||||
|
// Pings db so that all expectations could be asserted.
|
||||||
|
//
|
||||||
|
// This method is introduced because of sql abstraction
|
||||||
|
// libraries, which do not provide a way to initialize
|
||||||
|
// with sql.DB instance. For example GORM library.
|
||||||
|
//
|
||||||
|
// Note, it will error if attempted to create with an
|
||||||
|
// already used dsn
|
||||||
|
//
|
||||||
|
// It is not recommended to use this method, unless you
|
||||||
|
// really need it and there is no other way around.
|
||||||
|
func NewWithDSN(dsn string, options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
|
||||||
|
pool.Lock()
|
||||||
|
if _, ok := pool.conns[dsn]; ok {
|
||||||
|
pool.Unlock()
|
||||||
|
return nil, nil, fmt.Errorf("cannot create a new mock database with the same dsn: %s", dsn)
|
||||||
|
}
|
||||||
|
smock := &sqlmock{dsn: dsn, drv: pool, ordered: true}
|
||||||
|
pool.conns[dsn] = smock
|
||||||
|
pool.Unlock()
|
||||||
|
|
||||||
|
return smock.open(options)
|
||||||
|
}
|
|
@ -0,0 +1,369 @@
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// an expectation interface
|
||||||
|
type expectation interface {
|
||||||
|
fulfilled() bool
|
||||||
|
Lock()
|
||||||
|
Unlock()
|
||||||
|
String() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// common expectation struct
|
||||||
|
// satisfies the expectation interface
|
||||||
|
type commonExpectation struct {
|
||||||
|
sync.Mutex
|
||||||
|
triggered bool
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *commonExpectation) fulfilled() bool {
|
||||||
|
return e.triggered
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpectedClose is used to manage *sql.DB.Close expectation
|
||||||
|
// returned by *Sqlmock.ExpectClose.
|
||||||
|
type ExpectedClose struct {
|
||||||
|
commonExpectation
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillReturnError allows to set an error for *sql.DB.Close action
|
||||||
|
func (e *ExpectedClose) WillReturnError(err error) *ExpectedClose {
|
||||||
|
e.err = err
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns string representation
|
||||||
|
func (e *ExpectedClose) String() string {
|
||||||
|
msg := "ExpectedClose => expecting database Close"
|
||||||
|
if e.err != nil {
|
||||||
|
msg += fmt.Sprintf(", which should return error: %s", e.err)
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpectedBegin is used to manage *sql.DB.Begin expectation
|
||||||
|
// returned by *Sqlmock.ExpectBegin.
|
||||||
|
type ExpectedBegin struct {
|
||||||
|
commonExpectation
|
||||||
|
delay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillReturnError allows to set an error for *sql.DB.Begin action
|
||||||
|
func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin {
|
||||||
|
e.err = err
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns string representation
|
||||||
|
func (e *ExpectedBegin) String() string {
|
||||||
|
msg := "ExpectedBegin => expecting database transaction Begin"
|
||||||
|
if e.err != nil {
|
||||||
|
msg += fmt.Sprintf(", which should return error: %s", e.err)
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillDelayFor allows to specify duration for which it will delay
|
||||||
|
// result. May be used together with Context
|
||||||
|
func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin {
|
||||||
|
e.delay = duration
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpectedCommit is used to manage *sql.Tx.Commit expectation
|
||||||
|
// returned by *Sqlmock.ExpectCommit.
|
||||||
|
type ExpectedCommit struct {
|
||||||
|
commonExpectation
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillReturnError allows to set an error for *sql.Tx.Close action
|
||||||
|
func (e *ExpectedCommit) WillReturnError(err error) *ExpectedCommit {
|
||||||
|
e.err = err
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns string representation
|
||||||
|
func (e *ExpectedCommit) String() string {
|
||||||
|
msg := "ExpectedCommit => expecting transaction Commit"
|
||||||
|
if e.err != nil {
|
||||||
|
msg += fmt.Sprintf(", which should return error: %s", e.err)
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpectedRollback is used to manage *sql.Tx.Rollback expectation
|
||||||
|
// returned by *Sqlmock.ExpectRollback.
|
||||||
|
type ExpectedRollback struct {
|
||||||
|
commonExpectation
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillReturnError allows to set an error for *sql.Tx.Rollback action
|
||||||
|
func (e *ExpectedRollback) WillReturnError(err error) *ExpectedRollback {
|
||||||
|
e.err = err
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns string representation
|
||||||
|
func (e *ExpectedRollback) String() string {
|
||||||
|
msg := "ExpectedRollback => expecting transaction Rollback"
|
||||||
|
if e.err != nil {
|
||||||
|
msg += fmt.Sprintf(", which should return error: %s", e.err)
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpectedQuery is used to manage *sql.DB.Query, *dql.DB.QueryRow, *sql.Tx.Query,
|
||||||
|
// *sql.Tx.QueryRow, *sql.Stmt.Query or *sql.Stmt.QueryRow expectations.
|
||||||
|
// Returned by *Sqlmock.ExpectQuery.
|
||||||
|
type ExpectedQuery struct {
|
||||||
|
queryBasedExpectation
|
||||||
|
rows driver.Rows
|
||||||
|
delay time.Duration
|
||||||
|
rowsMustBeClosed bool
|
||||||
|
rowsWereClosed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithArgs will match given expected args to actual database query arguments.
|
||||||
|
// if at least one argument does not match, it will return an error. For specific
|
||||||
|
// arguments an sqlmock.Argument interface can be used to match an argument.
|
||||||
|
func (e *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery {
|
||||||
|
e.args = args
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// RowsWillBeClosed expects this query rows to be closed.
|
||||||
|
func (e *ExpectedQuery) RowsWillBeClosed() *ExpectedQuery {
|
||||||
|
e.rowsMustBeClosed = true
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillReturnError allows to set an error for expected database query
|
||||||
|
func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery {
|
||||||
|
e.err = err
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillDelayFor allows to specify duration for which it will delay
|
||||||
|
// result. May be used together with Context
|
||||||
|
func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery {
|
||||||
|
e.delay = duration
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns string representation
|
||||||
|
func (e *ExpectedQuery) String() string {
|
||||||
|
msg := "ExpectedQuery => expecting Query, QueryContext or QueryRow which:"
|
||||||
|
msg += "\n - matches sql: '" + e.expectSQL + "'"
|
||||||
|
|
||||||
|
if len(e.args) == 0 {
|
||||||
|
msg += "\n - is without arguments"
|
||||||
|
} else {
|
||||||
|
msg += "\n - is with arguments:\n"
|
||||||
|
for i, arg := range e.args {
|
||||||
|
msg += fmt.Sprintf(" %d - %+v\n", i, arg)
|
||||||
|
}
|
||||||
|
msg = strings.TrimSpace(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.rows != nil {
|
||||||
|
msg += fmt.Sprintf("\n - %s", e.rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.err != nil {
|
||||||
|
msg += fmt.Sprintf("\n - should return error: %s", e.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpectedExec is used to manage *sql.DB.Exec, *sql.Tx.Exec or *sql.Stmt.Exec expectations.
|
||||||
|
// Returned by *Sqlmock.ExpectExec.
|
||||||
|
type ExpectedExec struct {
|
||||||
|
queryBasedExpectation
|
||||||
|
result driver.Result
|
||||||
|
delay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithArgs will match given expected args to actual database exec operation arguments.
|
||||||
|
// if at least one argument does not match, it will return an error. For specific
|
||||||
|
// arguments an sqlmock.Argument interface can be used to match an argument.
|
||||||
|
func (e *ExpectedExec) WithArgs(args ...driver.Value) *ExpectedExec {
|
||||||
|
e.args = args
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillReturnError allows to set an error for expected database exec action
|
||||||
|
func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec {
|
||||||
|
e.err = err
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillDelayFor allows to specify duration for which it will delay
|
||||||
|
// result. May be used together with Context
|
||||||
|
func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec {
|
||||||
|
e.delay = duration
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns string representation
|
||||||
|
func (e *ExpectedExec) String() string {
|
||||||
|
msg := "ExpectedExec => expecting Exec or ExecContext which:"
|
||||||
|
msg += "\n - matches sql: '" + e.expectSQL + "'"
|
||||||
|
|
||||||
|
if len(e.args) == 0 {
|
||||||
|
msg += "\n - is without arguments"
|
||||||
|
} else {
|
||||||
|
msg += "\n - is with arguments:\n"
|
||||||
|
var margs []string
|
||||||
|
for i, arg := range e.args {
|
||||||
|
margs = append(margs, fmt.Sprintf(" %d - %+v", i, arg))
|
||||||
|
}
|
||||||
|
msg += strings.Join(margs, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.result != nil {
|
||||||
|
res, _ := e.result.(*result)
|
||||||
|
msg += "\n - should return Result having:"
|
||||||
|
msg += fmt.Sprintf("\n LastInsertId: %d", res.insertID)
|
||||||
|
msg += fmt.Sprintf("\n RowsAffected: %d", res.rowsAffected)
|
||||||
|
if res.err != nil {
|
||||||
|
msg += fmt.Sprintf("\n Error: %s", res.err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.err != nil {
|
||||||
|
msg += fmt.Sprintf("\n - should return error: %s", e.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillReturnResult arranges for an expected Exec() to return a particular
|
||||||
|
// result, there is sqlmock.NewResult(lastInsertID int64, affectedRows int64) method
|
||||||
|
// to build a corresponding result. Or if actions needs to be tested against errors
|
||||||
|
// sqlmock.NewErrorResult(err error) to return a given error.
|
||||||
|
func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec {
|
||||||
|
e.result = result
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpectedPrepare is used to manage *sql.DB.Prepare or *sql.Tx.Prepare expectations.
|
||||||
|
// Returned by *Sqlmock.ExpectPrepare.
|
||||||
|
type ExpectedPrepare struct {
|
||||||
|
commonExpectation
|
||||||
|
mock *sqlmock
|
||||||
|
expectSQL string
|
||||||
|
statement driver.Stmt
|
||||||
|
closeErr error
|
||||||
|
mustBeClosed bool
|
||||||
|
wasClosed bool
|
||||||
|
delay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillReturnError allows to set an error for the expected *sql.DB.Prepare or *sql.Tx.Prepare action.
|
||||||
|
func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare {
|
||||||
|
e.err = err
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillReturnCloseError allows to set an error for this prepared statement Close action
|
||||||
|
func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare {
|
||||||
|
e.closeErr = err
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillDelayFor allows to specify duration for which it will delay
|
||||||
|
// result. May be used together with Context
|
||||||
|
func (e *ExpectedPrepare) WillDelayFor(duration time.Duration) *ExpectedPrepare {
|
||||||
|
e.delay = duration
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillBeClosed expects this prepared statement to
|
||||||
|
// be closed.
|
||||||
|
func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare {
|
||||||
|
e.mustBeClosed = true
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpectQuery allows to expect Query() or QueryRow() on this prepared statement.
|
||||||
|
// This method is convenient in order to prevent duplicating sql query string matching.
|
||||||
|
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
|
||||||
|
eq := &ExpectedQuery{}
|
||||||
|
eq.expectSQL = e.expectSQL
|
||||||
|
eq.converter = e.mock.converter
|
||||||
|
e.mock.expected = append(e.mock.expected, eq)
|
||||||
|
return eq
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpectExec allows to expect Exec() on this prepared statement.
|
||||||
|
// This method is convenient in order to prevent duplicating sql query string matching.
|
||||||
|
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
|
||||||
|
eq := &ExpectedExec{}
|
||||||
|
eq.expectSQL = e.expectSQL
|
||||||
|
eq.converter = e.mock.converter
|
||||||
|
e.mock.expected = append(e.mock.expected, eq)
|
||||||
|
return eq
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns string representation
|
||||||
|
func (e *ExpectedPrepare) String() string {
|
||||||
|
msg := "ExpectedPrepare => expecting Prepare statement which:"
|
||||||
|
msg += "\n - matches sql: '" + e.expectSQL + "'"
|
||||||
|
|
||||||
|
if e.err != nil {
|
||||||
|
msg += fmt.Sprintf("\n - should return error: %s", e.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.closeErr != nil {
|
||||||
|
msg += fmt.Sprintf("\n - should return error on Close: %s", e.closeErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
// query based expectation
|
||||||
|
// adds a query matching logic
|
||||||
|
type queryBasedExpectation struct {
|
||||||
|
commonExpectation
|
||||||
|
expectSQL string
|
||||||
|
converter driver.ValueConverter
|
||||||
|
args []driver.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpectedPing is used to manage *sql.DB.Ping expectations.
|
||||||
|
// Returned by *Sqlmock.ExpectPing.
|
||||||
|
type ExpectedPing struct {
|
||||||
|
commonExpectation
|
||||||
|
delay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillDelayFor allows to specify duration for which it will delay result. May
|
||||||
|
// be used together with Context.
|
||||||
|
func (e *ExpectedPing) WillDelayFor(duration time.Duration) *ExpectedPing {
|
||||||
|
e.delay = duration
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// WillReturnError allows to set an error for expected database ping
|
||||||
|
func (e *ExpectedPing) WillReturnError(err error) *ExpectedPing {
|
||||||
|
e.err = err
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns string representation
|
||||||
|
func (e *ExpectedPing) String() string {
|
||||||
|
msg := "ExpectedPing => expecting database Ping"
|
||||||
|
if e.err != nil {
|
||||||
|
msg += fmt.Sprintf(", which should return error: %s", e.err)
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
|
@ -0,0 +1,67 @@
|
||||||
|
// +build !go1.8
|
||||||
|
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WillReturnRows specifies the set of resulting rows that will be returned
|
||||||
|
// by the triggered query
|
||||||
|
func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery {
|
||||||
|
e.rows = &rowSets{sets: []*Rows{rows}, ex: e}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
|
||||||
|
if nil == e.args {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(args) != len(e.args) {
|
||||||
|
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
|
||||||
|
}
|
||||||
|
for k, v := range args {
|
||||||
|
// custom argument matcher
|
||||||
|
matcher, ok := e.args[k].(Argument)
|
||||||
|
if ok {
|
||||||
|
// @TODO: does it make sense to pass value instead of named value?
|
||||||
|
if !matcher.Match(v.Value) {
|
||||||
|
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dval := e.args[k]
|
||||||
|
// convert to driver converter
|
||||||
|
darg, err := e.converter.ConvertValue(dval)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !driver.IsValue(darg) {
|
||||||
|
return fmt.Errorf("argument %d: non-subset type %T returned from Value", k, darg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(darg, v.Value) {
|
||||||
|
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *queryBasedExpectation) attemptArgMatch(args []namedValue) (err error) {
|
||||||
|
// catch panic
|
||||||
|
defer func() {
|
||||||
|
if e := recover(); e != nil {
|
||||||
|
_, ok := e.(error)
|
||||||
|
if !ok {
|
||||||
|
err = fmt.Errorf(e.(string))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = e.argsMatches(args)
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,85 @@
|
||||||
|
// +build go1.8
|
||||||
|
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WillReturnRows specifies the set of resulting rows that will be returned
|
||||||
|
// by the triggered query
|
||||||
|
func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery {
|
||||||
|
defs := 0
|
||||||
|
sets := make([]*Rows, len(rows))
|
||||||
|
for i, r := range rows {
|
||||||
|
sets[i] = r
|
||||||
|
if r.def != nil {
|
||||||
|
defs++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if defs > 0 && defs == len(sets) {
|
||||||
|
e.rows = &rowSetsWithDefinition{&rowSets{sets: sets, ex: e}}
|
||||||
|
} else {
|
||||||
|
e.rows = &rowSets{sets: sets, ex: e}
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *queryBasedExpectation) argsMatches(args []driver.NamedValue) error {
|
||||||
|
if nil == e.args {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(args) != len(e.args) {
|
||||||
|
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
|
||||||
|
}
|
||||||
|
// @TODO should we assert either all args are named or ordinal?
|
||||||
|
for k, v := range args {
|
||||||
|
// custom argument matcher
|
||||||
|
matcher, ok := e.args[k].(Argument)
|
||||||
|
if ok {
|
||||||
|
if !matcher.Match(v.Value) {
|
||||||
|
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dval := e.args[k]
|
||||||
|
if named, isNamed := dval.(sql.NamedArg); isNamed {
|
||||||
|
dval = named.Value
|
||||||
|
if v.Name != named.Name {
|
||||||
|
return fmt.Errorf("named argument %d: name: \"%s\" does not match expected: \"%s\"", k, v.Name, named.Name)
|
||||||
|
}
|
||||||
|
} else if k+1 != v.Ordinal {
|
||||||
|
return fmt.Errorf("argument %d: ordinal position: %d does not match expected: %d", k, k+1, v.Ordinal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert to driver converter
|
||||||
|
darg, err := e.converter.ConvertValue(dval)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(darg, v.Value) {
|
||||||
|
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v.Value, v.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *queryBasedExpectation) attemptArgMatch(args []driver.NamedValue) (err error) {
|
||||||
|
// catch panic
|
||||||
|
defer func() {
|
||||||
|
if e := recover(); e != nil {
|
||||||
|
_, ok := e.(error)
|
||||||
|
if !ok {
|
||||||
|
err = fmt.Errorf(e.(string))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = e.argsMatches(args)
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,38 @@
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import "database/sql/driver"
|
||||||
|
|
||||||
|
// ValueConverterOption allows to create a sqlmock connection
|
||||||
|
// with a custom ValueConverter to support drivers with special data types.
|
||||||
|
func ValueConverterOption(converter driver.ValueConverter) func(*sqlmock) error {
|
||||||
|
return func(s *sqlmock) error {
|
||||||
|
s.converter = converter
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryMatcherOption allows to customize SQL query matcher
|
||||||
|
// and match SQL query strings in more sophisticated ways.
|
||||||
|
// The default QueryMatcher is QueryMatcherRegexp.
|
||||||
|
func QueryMatcherOption(queryMatcher QueryMatcher) func(*sqlmock) error {
|
||||||
|
return func(s *sqlmock) error {
|
||||||
|
s.queryMatcher = queryMatcher
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MonitorPingsOption determines whether calls to Ping on the driver should be
|
||||||
|
// observed and mocked.
|
||||||
|
//
|
||||||
|
// If true is passed, we will check these calls were expected. Expectations can
|
||||||
|
// be registered using the ExpectPing() method on the mock.
|
||||||
|
//
|
||||||
|
// If false is passed or this option is omitted, calls to Ping will not be
|
||||||
|
// considered when determining expectations and calls to ExpectPing will have
|
||||||
|
// no effect.
|
||||||
|
func MonitorPingsOption(monitorPings bool) func(*sqlmock) error {
|
||||||
|
return func(s *sqlmock) error {
|
||||||
|
s.monitorPings = monitorPings
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,68 @@
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var re = regexp.MustCompile("\\s+")
|
||||||
|
|
||||||
|
// strip out new lines and trim spaces
|
||||||
|
func stripQuery(q string) (s string) {
|
||||||
|
return strings.TrimSpace(re.ReplaceAllString(q, " "))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryMatcher is an SQL query string matcher interface,
|
||||||
|
// which can be used to customize validation of SQL query strings.
|
||||||
|
// As an example, external library could be used to build
|
||||||
|
// and validate SQL ast, columns selected.
|
||||||
|
//
|
||||||
|
// sqlmock can be customized to implement a different QueryMatcher
|
||||||
|
// configured through an option when sqlmock.New or sqlmock.NewWithDSN
|
||||||
|
// is called, default QueryMatcher is QueryMatcherRegexp.
|
||||||
|
type QueryMatcher interface {
|
||||||
|
|
||||||
|
// Match expected SQL query string without whitespace to
|
||||||
|
// actual SQL.
|
||||||
|
Match(expectedSQL, actualSQL string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryMatcherFunc type is an adapter to allow the use of
|
||||||
|
// ordinary functions as QueryMatcher. If f is a function
|
||||||
|
// with the appropriate signature, QueryMatcherFunc(f) is a
|
||||||
|
// QueryMatcher that calls f.
|
||||||
|
type QueryMatcherFunc func(expectedSQL, actualSQL string) error
|
||||||
|
|
||||||
|
// Match implements the QueryMatcher
|
||||||
|
func (f QueryMatcherFunc) Match(expectedSQL, actualSQL string) error {
|
||||||
|
return f(expectedSQL, actualSQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryMatcherRegexp is the default SQL query matcher
|
||||||
|
// used by sqlmock. It parses expectedSQL to a regular
|
||||||
|
// expression and attempts to match actualSQL.
|
||||||
|
var QueryMatcherRegexp QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error {
|
||||||
|
expect := stripQuery(expectedSQL)
|
||||||
|
actual := stripQuery(actualSQL)
|
||||||
|
re, err := regexp.Compile(expect)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !re.MatchString(actual) {
|
||||||
|
return fmt.Errorf(`could not match actual sql: "%s" with expected regexp "%s"`, actual, re.String())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// QueryMatcherEqual is the SQL query matcher
|
||||||
|
// which simply tries a case sensitive match of
|
||||||
|
// expected and actual SQL strings without whitespace.
|
||||||
|
var QueryMatcherEqual QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error {
|
||||||
|
expect := stripQuery(expectedSQL)
|
||||||
|
actual := stripQuery(actualSQL)
|
||||||
|
if actual != expect {
|
||||||
|
return fmt.Errorf(`actual sql: "%s" does not equal to expected "%s"`, actual, expect)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
|
@ -0,0 +1,39 @@
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Result satisfies sql driver Result, which
|
||||||
|
// holds last insert id and rows affected
|
||||||
|
// by Exec queries
|
||||||
|
type result struct {
|
||||||
|
insertID int64
|
||||||
|
rowsAffected int64
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResult creates a new sql driver Result
|
||||||
|
// for Exec based query mocks.
|
||||||
|
func NewResult(lastInsertID int64, rowsAffected int64) driver.Result {
|
||||||
|
return &result{
|
||||||
|
insertID: lastInsertID,
|
||||||
|
rowsAffected: rowsAffected,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewErrorResult creates a new sql driver Result
|
||||||
|
// which returns an error given for both interface methods
|
||||||
|
func NewErrorResult(err error) driver.Result {
|
||||||
|
return &result{
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *result) LastInsertId() (int64, error) {
|
||||||
|
return r.insertID, r.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *result) RowsAffected() (int64, error) {
|
||||||
|
return r.rowsAffected, r.err
|
||||||
|
}
|
|
@ -0,0 +1,212 @@
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/csv"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ "
|
||||||
|
|
||||||
|
// CSVColumnParser is a function which converts trimmed csv
|
||||||
|
// column string to a []byte representation. Currently
|
||||||
|
// transforms NULL to nil
|
||||||
|
var CSVColumnParser = func(s string) []byte {
|
||||||
|
switch {
|
||||||
|
case strings.ToLower(s) == "null":
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []byte(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
type rowSets struct {
|
||||||
|
sets []*Rows
|
||||||
|
pos int
|
||||||
|
ex *ExpectedQuery
|
||||||
|
raw [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *rowSets) Columns() []string {
|
||||||
|
return rs.sets[rs.pos].cols
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *rowSets) Close() error {
|
||||||
|
rs.invalidateRaw()
|
||||||
|
rs.ex.rowsWereClosed = true
|
||||||
|
return rs.sets[rs.pos].closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// advances to next row
|
||||||
|
func (rs *rowSets) Next(dest []driver.Value) error {
|
||||||
|
r := rs.sets[rs.pos]
|
||||||
|
r.pos++
|
||||||
|
rs.invalidateRaw()
|
||||||
|
if r.pos > len(r.rows) {
|
||||||
|
return io.EOF // per interface spec
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, col := range r.rows[r.pos-1] {
|
||||||
|
if b, ok := rawBytes(col); ok {
|
||||||
|
rs.raw = append(rs.raw, b)
|
||||||
|
dest[i] = b
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dest[i] = col
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.nextErr[r.pos-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// transforms to debuggable printable string
|
||||||
|
func (rs *rowSets) String() string {
|
||||||
|
if rs.empty() {
|
||||||
|
return "with empty rows"
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := "should return rows:\n"
|
||||||
|
if len(rs.sets) == 1 {
|
||||||
|
for n, row := range rs.sets[0].rows {
|
||||||
|
msg += fmt.Sprintf(" row %d - %+v\n", n, row)
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(msg)
|
||||||
|
}
|
||||||
|
for i, set := range rs.sets {
|
||||||
|
msg += fmt.Sprintf(" result set: %d\n", i)
|
||||||
|
for n, row := range set.rows {
|
||||||
|
msg += fmt.Sprintf(" row %d - %+v\n", n, row)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *rowSets) empty() bool {
|
||||||
|
for _, set := range rs.sets {
|
||||||
|
if len(set.rows) > 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func rawBytes(col driver.Value) (_ []byte, ok bool) {
|
||||||
|
val, ok := col.([]byte)
|
||||||
|
if !ok || len(val) == 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
// Copy the bytes from the mocked row into a shared raw buffer, which we'll replace the content of later
|
||||||
|
// This allows scanning into sql.RawBytes to correctly become invalid on subsequent calls to Next(), Scan() or Close()
|
||||||
|
b := make([]byte, len(val))
|
||||||
|
copy(b, val)
|
||||||
|
return b, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes that could have been scanned as sql.RawBytes are only valid until the next call to Next, Scan or Close.
|
||||||
|
// If those occur, we must replace their content to simulate the shared memory to expose misuse of sql.RawBytes
|
||||||
|
func (rs *rowSets) invalidateRaw() {
|
||||||
|
// Replace the content of slices previously returned
|
||||||
|
b := []byte(invalidate)
|
||||||
|
for _, r := range rs.raw {
|
||||||
|
copy(r, bytes.Repeat(b, len(r)/len(b)+1))
|
||||||
|
}
|
||||||
|
// Start with new slices for the next scan
|
||||||
|
rs.raw = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rows is a mocked collection of rows to
|
||||||
|
// return for Query result
|
||||||
|
type Rows struct {
|
||||||
|
converter driver.ValueConverter
|
||||||
|
cols []string
|
||||||
|
def []*Column
|
||||||
|
rows [][]driver.Value
|
||||||
|
pos int
|
||||||
|
nextErr map[int]error
|
||||||
|
closeErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRows allows Rows to be created from a
|
||||||
|
// sql driver.Value slice or from the CSV string and
|
||||||
|
// to be used as sql driver.Rows.
|
||||||
|
// Use Sqlmock.NewRows instead if using a custom converter
|
||||||
|
func NewRows(columns []string) *Rows {
|
||||||
|
return &Rows{
|
||||||
|
cols: columns,
|
||||||
|
nextErr: make(map[int]error),
|
||||||
|
converter: driver.DefaultParameterConverter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseError allows to set an error
|
||||||
|
// which will be returned by rows.Close
|
||||||
|
// function.
|
||||||
|
//
|
||||||
|
// The close error will be triggered only in cases
|
||||||
|
// when rows.Next() EOF was not yet reached, that is
|
||||||
|
// a default sql library behavior
|
||||||
|
func (r *Rows) CloseError(err error) *Rows {
|
||||||
|
r.closeErr = err
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// RowError allows to set an error
|
||||||
|
// which will be returned when a given
|
||||||
|
// row number is read
|
||||||
|
func (r *Rows) RowError(row int, err error) *Rows {
|
||||||
|
r.nextErr[row] = err
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRow composed from database driver.Value slice
|
||||||
|
// return the same instance to perform subsequent actions.
|
||||||
|
// Note that the number of values must match the number
|
||||||
|
// of columns
|
||||||
|
func (r *Rows) AddRow(values ...driver.Value) *Rows {
|
||||||
|
if len(values) != len(r.cols) {
|
||||||
|
panic("Expected number of values to match number of columns")
|
||||||
|
}
|
||||||
|
|
||||||
|
row := make([]driver.Value, len(r.cols))
|
||||||
|
for i, v := range values {
|
||||||
|
// Convert user-friendly values (such as int or driver.Valuer)
|
||||||
|
// to database/sql native value (driver.Value such as int64)
|
||||||
|
var err error
|
||||||
|
v, err = r.converter.ConvertValue(v)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf(
|
||||||
|
"row #%d, column #%d (%q) type %T: %s",
|
||||||
|
len(r.rows)+1, i, r.cols[i], values[i], err,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
row[i] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rows = append(r.rows, row)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromCSVString build rows from csv string.
|
||||||
|
// return the same instance to perform subsequent actions.
|
||||||
|
// Note that the number of values must match the number
|
||||||
|
// of columns
|
||||||
|
func (r *Rows) FromCSVString(s string) *Rows {
|
||||||
|
res := strings.NewReader(strings.TrimSpace(s))
|
||||||
|
csvReader := csv.NewReader(res)
|
||||||
|
|
||||||
|
for {
|
||||||
|
res, err := csvReader.Read()
|
||||||
|
if err != nil || res == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
row := make([]driver.Value, len(r.cols))
|
||||||
|
for i, v := range res {
|
||||||
|
row[i] = CSVColumnParser(strings.TrimSpace(v))
|
||||||
|
}
|
||||||
|
r.rows = append(r.rows, row)
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
|
@ -0,0 +1,74 @@
|
||||||
|
// +build go1.8
|
||||||
|
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Implement the "RowsNextResultSet" interface
|
||||||
|
func (rs *rowSets) HasNextResultSet() bool {
|
||||||
|
return rs.pos+1 < len(rs.sets)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "RowsNextResultSet" interface
|
||||||
|
func (rs *rowSets) NextResultSet() error {
|
||||||
|
if !rs.HasNextResultSet() {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
rs.pos++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// type for rows with columns definition created with sqlmock.NewRowsWithColumnDefinition
|
||||||
|
type rowSetsWithDefinition struct {
|
||||||
|
*rowSets
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "RowsColumnTypeDatabaseTypeName" interface
|
||||||
|
func (rs *rowSetsWithDefinition) ColumnTypeDatabaseTypeName(index int) string {
|
||||||
|
return rs.getDefinition(index).DbType()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "RowsColumnTypeLength" interface
|
||||||
|
func (rs *rowSetsWithDefinition) ColumnTypeLength(index int) (length int64, ok bool) {
|
||||||
|
return rs.getDefinition(index).Length()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "RowsColumnTypeNullable" interface
|
||||||
|
func (rs *rowSetsWithDefinition) ColumnTypeNullable(index int) (nullable, ok bool) {
|
||||||
|
return rs.getDefinition(index).IsNullable()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "RowsColumnTypePrecisionScale" interface
|
||||||
|
func (rs *rowSetsWithDefinition) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
|
||||||
|
return rs.getDefinition(index).PrecisionScale()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ColumnTypeScanType is defined from driver.RowsColumnTypeScanType
|
||||||
|
func (rs *rowSetsWithDefinition) ColumnTypeScanType(index int) reflect.Type {
|
||||||
|
return rs.getDefinition(index).ScanType()
|
||||||
|
}
|
||||||
|
|
||||||
|
// return column definition from current set metadata
|
||||||
|
func (rs *rowSetsWithDefinition) getDefinition(index int) *Column {
|
||||||
|
return rs.sets[rs.pos].def[index]
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRowsWithColumnDefinition return rows with columns metadata
|
||||||
|
func NewRowsWithColumnDefinition(columns ...*Column) *Rows {
|
||||||
|
cols := make([]string, len(columns))
|
||||||
|
for i, column := range columns {
|
||||||
|
cols[i] = column.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Rows{
|
||||||
|
cols: cols,
|
||||||
|
def: columns,
|
||||||
|
nextErr: make(map[int]error),
|
||||||
|
converter: driver.DefaultParameterConverter,
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,439 @@
|
||||||
|
/*
|
||||||
|
Package sqlmock is a mock library implementing sql driver. Which has one and only
|
||||||
|
purpose - to simulate any sql driver behavior in tests, without needing a real
|
||||||
|
database connection. It helps to maintain correct **TDD** workflow.
|
||||||
|
|
||||||
|
It does not require any modifications to your source code in order to test
|
||||||
|
and mock database operations. Supports concurrency and multiple database mocking.
|
||||||
|
|
||||||
|
The driver allows to mock any sql driver method behavior.
|
||||||
|
*/
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sqlmock interface serves to create expectations
|
||||||
|
// for any kind of database action in order to mock
|
||||||
|
// and test real database behavior.
|
||||||
|
type SqlmockCommon interface {
|
||||||
|
// ExpectClose queues an expectation for this database
|
||||||
|
// action to be triggered. the *ExpectedClose allows
|
||||||
|
// to mock database response
|
||||||
|
ExpectClose() *ExpectedClose
|
||||||
|
|
||||||
|
// ExpectationsWereMet checks whether all queued expectations
|
||||||
|
// were met in order. If any of them was not met - an error is returned.
|
||||||
|
ExpectationsWereMet() error
|
||||||
|
|
||||||
|
// ExpectPrepare expects Prepare() to be called with expectedSQL query.
|
||||||
|
// the *ExpectedPrepare allows to mock database response.
|
||||||
|
// Note that you may expect Query() or Exec() on the *ExpectedPrepare
|
||||||
|
// statement to prevent repeating expectedSQL
|
||||||
|
ExpectPrepare(expectedSQL string) *ExpectedPrepare
|
||||||
|
|
||||||
|
// ExpectQuery expects Query() or QueryRow() to be called with expectedSQL query.
|
||||||
|
// the *ExpectedQuery allows to mock database response.
|
||||||
|
ExpectQuery(expectedSQL string) *ExpectedQuery
|
||||||
|
|
||||||
|
// ExpectExec expects Exec() to be called with expectedSQL query.
|
||||||
|
// the *ExpectedExec allows to mock database response
|
||||||
|
ExpectExec(expectedSQL string) *ExpectedExec
|
||||||
|
|
||||||
|
// ExpectBegin expects *sql.DB.Begin to be called.
|
||||||
|
// the *ExpectedBegin allows to mock database response
|
||||||
|
ExpectBegin() *ExpectedBegin
|
||||||
|
|
||||||
|
// ExpectCommit expects *sql.Tx.Commit to be called.
|
||||||
|
// the *ExpectedCommit allows to mock database response
|
||||||
|
ExpectCommit() *ExpectedCommit
|
||||||
|
|
||||||
|
// ExpectRollback expects *sql.Tx.Rollback to be called.
|
||||||
|
// the *ExpectedRollback allows to mock database response
|
||||||
|
ExpectRollback() *ExpectedRollback
|
||||||
|
|
||||||
|
// ExpectPing expected *sql.DB.Ping to be called.
|
||||||
|
// the *ExpectedPing allows to mock database response
|
||||||
|
//
|
||||||
|
// Ping support only exists in the SQL library in Go 1.8 and above.
|
||||||
|
// ExpectPing in Go <=1.7 will return an ExpectedPing but not register
|
||||||
|
// any expectations.
|
||||||
|
//
|
||||||
|
// You must enable pings using MonitorPingsOption for this to register
|
||||||
|
// any expectations.
|
||||||
|
ExpectPing() *ExpectedPing
|
||||||
|
|
||||||
|
// MatchExpectationsInOrder gives an option whether to match all
|
||||||
|
// expectations in the order they were set or not.
|
||||||
|
//
|
||||||
|
// By default it is set to - true. But if you use goroutines
|
||||||
|
// to parallelize your query executation, that option may
|
||||||
|
// be handy.
|
||||||
|
//
|
||||||
|
// This option may be turned on anytime during tests. As soon
|
||||||
|
// as it is switched to false, expectations will be matched
|
||||||
|
// in any order. Or otherwise if switched to true, any unmatched
|
||||||
|
// expectations will be expected in order
|
||||||
|
MatchExpectationsInOrder(bool)
|
||||||
|
|
||||||
|
// NewRows allows Rows to be created from a
|
||||||
|
// sql driver.Value slice or from the CSV string and
|
||||||
|
// to be used as sql driver.Rows.
|
||||||
|
NewRows(columns []string) *Rows
|
||||||
|
}
|
||||||
|
|
||||||
|
type sqlmock struct {
|
||||||
|
ordered bool
|
||||||
|
dsn string
|
||||||
|
opened int
|
||||||
|
drv *mockDriver
|
||||||
|
converter driver.ValueConverter
|
||||||
|
queryMatcher QueryMatcher
|
||||||
|
monitorPings bool
|
||||||
|
|
||||||
|
expected []expectation
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
|
||||||
|
db, err := sql.Open("sqlmock", c.dsn)
|
||||||
|
if err != nil {
|
||||||
|
return db, c, err
|
||||||
|
}
|
||||||
|
for _, option := range options {
|
||||||
|
err := option(c)
|
||||||
|
if err != nil {
|
||||||
|
return db, c, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.converter == nil {
|
||||||
|
c.converter = driver.DefaultParameterConverter
|
||||||
|
}
|
||||||
|
if c.queryMatcher == nil {
|
||||||
|
c.queryMatcher = QueryMatcherRegexp
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.monitorPings {
|
||||||
|
// We call Ping on the driver shortly to verify startup assertions by
|
||||||
|
// driving internal behaviour of the sql standard library. We don't
|
||||||
|
// want this call to ping to be monitored for expectation purposes so
|
||||||
|
// temporarily disable.
|
||||||
|
c.monitorPings = false
|
||||||
|
defer func() { c.monitorPings = true }()
|
||||||
|
}
|
||||||
|
return db, c, db.Ping()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) ExpectClose() *ExpectedClose {
|
||||||
|
e := &ExpectedClose{}
|
||||||
|
c.expected = append(c.expected, e)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) MatchExpectationsInOrder(b bool) {
|
||||||
|
c.ordered = b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close a mock database driver connection. It may or may not
|
||||||
|
// be called depending on the circumstances, but if it is called
|
||||||
|
// there must be an *ExpectedClose expectation satisfied.
|
||||||
|
// meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||||
|
func (c *sqlmock) Close() error {
|
||||||
|
c.drv.Lock()
|
||||||
|
defer c.drv.Unlock()
|
||||||
|
|
||||||
|
c.opened--
|
||||||
|
if c.opened == 0 {
|
||||||
|
delete(c.drv.conns, c.dsn)
|
||||||
|
}
|
||||||
|
|
||||||
|
var expected *ExpectedClose
|
||||||
|
var fulfilled int
|
||||||
|
var ok bool
|
||||||
|
for _, next := range c.expected {
|
||||||
|
next.Lock()
|
||||||
|
if next.fulfilled() {
|
||||||
|
next.Unlock()
|
||||||
|
fulfilled++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected, ok = next.(*ExpectedClose); ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
next.Unlock()
|
||||||
|
if c.ordered {
|
||||||
|
return fmt.Errorf("call to database Close, was not expected, next expectation is: %s", next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected == nil {
|
||||||
|
msg := "call to database Close was not expected"
|
||||||
|
if fulfilled == len(c.expected) {
|
||||||
|
msg = "all expectations were already fulfilled, " + msg
|
||||||
|
}
|
||||||
|
return fmt.Errorf(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected.triggered = true
|
||||||
|
expected.Unlock()
|
||||||
|
return expected.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) ExpectationsWereMet() error {
|
||||||
|
for _, e := range c.expected {
|
||||||
|
e.Lock()
|
||||||
|
fulfilled := e.fulfilled()
|
||||||
|
e.Unlock()
|
||||||
|
|
||||||
|
if !fulfilled {
|
||||||
|
return fmt.Errorf("there is a remaining expectation which was not matched: %s", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
// for expected prepared statement check whether it was closed if expected
|
||||||
|
if prep, ok := e.(*ExpectedPrepare); ok {
|
||||||
|
if prep.mustBeClosed && !prep.wasClosed {
|
||||||
|
return fmt.Errorf("expected prepared statement to be closed, but it was not: %s", prep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// must check whether all expected queried rows are closed
|
||||||
|
if query, ok := e.(*ExpectedQuery); ok {
|
||||||
|
if query.rowsMustBeClosed && !query.rowsWereClosed {
|
||||||
|
return fmt.Errorf("expected query rows to be closed, but it was not: %s", query)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||||
|
func (c *sqlmock) Begin() (driver.Tx, error) {
|
||||||
|
ex, err := c.begin()
|
||||||
|
if ex != nil {
|
||||||
|
time.Sleep(ex.delay)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) begin() (*ExpectedBegin, error) {
|
||||||
|
var expected *ExpectedBegin
|
||||||
|
var ok bool
|
||||||
|
var fulfilled int
|
||||||
|
for _, next := range c.expected {
|
||||||
|
next.Lock()
|
||||||
|
if next.fulfilled() {
|
||||||
|
next.Unlock()
|
||||||
|
fulfilled++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected, ok = next.(*ExpectedBegin); ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
next.Unlock()
|
||||||
|
if c.ordered {
|
||||||
|
return nil, fmt.Errorf("call to database transaction Begin, was not expected, next expectation is: %s", next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if expected == nil {
|
||||||
|
msg := "call to database transaction Begin was not expected"
|
||||||
|
if fulfilled == len(c.expected) {
|
||||||
|
msg = "all expectations were already fulfilled, " + msg
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected.triggered = true
|
||||||
|
expected.Unlock()
|
||||||
|
|
||||||
|
return expected, expected.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) ExpectBegin() *ExpectedBegin {
|
||||||
|
e := &ExpectedBegin{}
|
||||||
|
c.expected = append(c.expected, e)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) ExpectExec(expectedSQL string) *ExpectedExec {
|
||||||
|
e := &ExpectedExec{}
|
||||||
|
e.expectSQL = expectedSQL
|
||||||
|
e.converter = c.converter
|
||||||
|
c.expected = append(c.expected, e)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
|
||||||
|
func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
|
||||||
|
ex, err := c.prepare(query)
|
||||||
|
if ex != nil {
|
||||||
|
time.Sleep(ex.delay)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &statement{c, ex, query}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
|
||||||
|
var expected *ExpectedPrepare
|
||||||
|
var fulfilled int
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
for _, next := range c.expected {
|
||||||
|
next.Lock()
|
||||||
|
if next.fulfilled() {
|
||||||
|
next.Unlock()
|
||||||
|
fulfilled++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.ordered {
|
||||||
|
if expected, ok = next.(*ExpectedPrepare); ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
next.Unlock()
|
||||||
|
return nil, fmt.Errorf("call to Prepare statement with query '%s', was not expected, next expectation is: %s", query, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pr, ok := next.(*ExpectedPrepare); ok {
|
||||||
|
if err := c.queryMatcher.Match(pr.expectSQL, query); err == nil {
|
||||||
|
expected = pr
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
next.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected == nil {
|
||||||
|
msg := "call to Prepare '%s' query was not expected"
|
||||||
|
if fulfilled == len(c.expected) {
|
||||||
|
msg = "all expectations were already fulfilled, " + msg
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf(msg, query)
|
||||||
|
}
|
||||||
|
defer expected.Unlock()
|
||||||
|
if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
|
||||||
|
return nil, fmt.Errorf("Prepare: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected.triggered = true
|
||||||
|
return expected, expected.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) ExpectPrepare(expectedSQL string) *ExpectedPrepare {
|
||||||
|
e := &ExpectedPrepare{expectSQL: expectedSQL, mock: c}
|
||||||
|
c.expected = append(c.expected, e)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) ExpectQuery(expectedSQL string) *ExpectedQuery {
|
||||||
|
e := &ExpectedQuery{}
|
||||||
|
e.expectSQL = expectedSQL
|
||||||
|
e.converter = c.converter
|
||||||
|
c.expected = append(c.expected, e)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) ExpectCommit() *ExpectedCommit {
|
||||||
|
e := &ExpectedCommit{}
|
||||||
|
c.expected = append(c.expected, e)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) ExpectRollback() *ExpectedRollback {
|
||||||
|
e := &ExpectedRollback{}
|
||||||
|
c.expected = append(c.expected, e)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit meets http://golang.org/pkg/database/sql/driver/#Tx
|
||||||
|
func (c *sqlmock) Commit() error {
|
||||||
|
var expected *ExpectedCommit
|
||||||
|
var fulfilled int
|
||||||
|
var ok bool
|
||||||
|
for _, next := range c.expected {
|
||||||
|
next.Lock()
|
||||||
|
if next.fulfilled() {
|
||||||
|
next.Unlock()
|
||||||
|
fulfilled++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected, ok = next.(*ExpectedCommit); ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
next.Unlock()
|
||||||
|
if c.ordered {
|
||||||
|
return fmt.Errorf("call to Commit transaction, was not expected, next expectation is: %s", next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if expected == nil {
|
||||||
|
msg := "call to Commit transaction was not expected"
|
||||||
|
if fulfilled == len(c.expected) {
|
||||||
|
msg = "all expectations were already fulfilled, " + msg
|
||||||
|
}
|
||||||
|
return fmt.Errorf(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected.triggered = true
|
||||||
|
expected.Unlock()
|
||||||
|
return expected.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rollback meets http://golang.org/pkg/database/sql/driver/#Tx
|
||||||
|
func (c *sqlmock) Rollback() error {
|
||||||
|
var expected *ExpectedRollback
|
||||||
|
var fulfilled int
|
||||||
|
var ok bool
|
||||||
|
for _, next := range c.expected {
|
||||||
|
next.Lock()
|
||||||
|
if next.fulfilled() {
|
||||||
|
next.Unlock()
|
||||||
|
fulfilled++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected, ok = next.(*ExpectedRollback); ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
next.Unlock()
|
||||||
|
if c.ordered {
|
||||||
|
return fmt.Errorf("call to Rollback transaction, was not expected, next expectation is: %s", next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if expected == nil {
|
||||||
|
msg := "call to Rollback transaction was not expected"
|
||||||
|
if fulfilled == len(c.expected) {
|
||||||
|
msg = "all expectations were already fulfilled, " + msg
|
||||||
|
}
|
||||||
|
return fmt.Errorf(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected.triggered = true
|
||||||
|
expected.Unlock()
|
||||||
|
return expected.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRows allows Rows to be created from a
|
||||||
|
// sql driver.Value slice or from the CSV string and
|
||||||
|
// to be used as sql driver.Rows.
|
||||||
|
func (c *sqlmock) NewRows(columns []string) *Rows {
|
||||||
|
r := NewRows(columns)
|
||||||
|
r.converter = c.converter
|
||||||
|
return r
|
||||||
|
}
|
|
@ -0,0 +1,191 @@
|
||||||
|
// +build !go1.8
|
||||||
|
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sqlmock interface for Go up to 1.7
|
||||||
|
type Sqlmock interface {
|
||||||
|
// Embed common methods
|
||||||
|
SqlmockCommon
|
||||||
|
}
|
||||||
|
|
||||||
|
type namedValue struct {
|
||||||
|
Name string
|
||||||
|
Ordinal int
|
||||||
|
Value driver.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) ExpectPing() *ExpectedPing {
|
||||||
|
log.Println("ExpectPing has no effect on Go 1.7 or below")
|
||||||
|
return &ExpectedPing{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer
|
||||||
|
func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||||
|
namedArgs := make([]namedValue, len(args))
|
||||||
|
for i, v := range args {
|
||||||
|
namedArgs[i] = namedValue{
|
||||||
|
Ordinal: i + 1,
|
||||||
|
Value: v,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ex, err := c.query(query, namedArgs)
|
||||||
|
if ex != nil {
|
||||||
|
time.Sleep(ex.delay)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ex.rows, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) {
|
||||||
|
var expected *ExpectedQuery
|
||||||
|
var fulfilled int
|
||||||
|
var ok bool
|
||||||
|
for _, next := range c.expected {
|
||||||
|
next.Lock()
|
||||||
|
if next.fulfilled() {
|
||||||
|
next.Unlock()
|
||||||
|
fulfilled++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.ordered {
|
||||||
|
if expected, ok = next.(*ExpectedQuery); ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
next.Unlock()
|
||||||
|
return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
|
||||||
|
}
|
||||||
|
if qr, ok := next.(*ExpectedQuery); ok {
|
||||||
|
if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil {
|
||||||
|
next.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := qr.attemptArgMatch(args); err == nil {
|
||||||
|
expected = qr
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
next.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected == nil {
|
||||||
|
msg := "call to Query '%s' with args %+v was not expected"
|
||||||
|
if fulfilled == len(c.expected) {
|
||||||
|
msg = "all expectations were already fulfilled, " + msg
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf(msg, query, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer expected.Unlock()
|
||||||
|
|
||||||
|
if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
|
||||||
|
return nil, fmt.Errorf("Query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := expected.argsMatches(args); err != nil {
|
||||||
|
return nil, fmt.Errorf("Query '%s', arguments do not match: %s", query, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected.triggered = true
|
||||||
|
if expected.err != nil {
|
||||||
|
return expected, expected.err // mocked to return error
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected.rows == nil {
|
||||||
|
return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
|
||||||
|
}
|
||||||
|
return expected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec meets http://golang.org/pkg/database/sql/driver/#Execer
|
||||||
|
func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) {
|
||||||
|
namedArgs := make([]namedValue, len(args))
|
||||||
|
for i, v := range args {
|
||||||
|
namedArgs[i] = namedValue{
|
||||||
|
Ordinal: i + 1,
|
||||||
|
Value: v,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ex, err := c.exec(query, namedArgs)
|
||||||
|
if ex != nil {
|
||||||
|
time.Sleep(ex.delay)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ex.result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) {
|
||||||
|
var expected *ExpectedExec
|
||||||
|
var fulfilled int
|
||||||
|
var ok bool
|
||||||
|
for _, next := range c.expected {
|
||||||
|
next.Lock()
|
||||||
|
if next.fulfilled() {
|
||||||
|
next.Unlock()
|
||||||
|
fulfilled++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.ordered {
|
||||||
|
if expected, ok = next.(*ExpectedExec); ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
next.Unlock()
|
||||||
|
return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
|
||||||
|
}
|
||||||
|
if exec, ok := next.(*ExpectedExec); ok {
|
||||||
|
if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil {
|
||||||
|
next.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := exec.attemptArgMatch(args); err == nil {
|
||||||
|
expected = exec
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
next.Unlock()
|
||||||
|
}
|
||||||
|
if expected == nil {
|
||||||
|
msg := "call to ExecQuery '%s' with args %+v was not expected"
|
||||||
|
if fulfilled == len(c.expected) {
|
||||||
|
msg = "all expectations were already fulfilled, " + msg
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf(msg, query, args)
|
||||||
|
}
|
||||||
|
defer expected.Unlock()
|
||||||
|
|
||||||
|
if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
|
||||||
|
return nil, fmt.Errorf("ExecQuery: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := expected.argsMatches(args); err != nil {
|
||||||
|
return nil, fmt.Errorf("ExecQuery '%s', arguments do not match: %s", query, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected.triggered = true
|
||||||
|
if expected.err != nil {
|
||||||
|
return expected, expected.err // mocked to return error
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected.result == nil {
|
||||||
|
return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
return expected, nil
|
||||||
|
}
|
|
@ -0,0 +1,356 @@
|
||||||
|
// +build go1.8
|
||||||
|
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql/driver"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sqlmock interface for Go 1.8+
|
||||||
|
type Sqlmock interface {
|
||||||
|
// Embed common methods
|
||||||
|
SqlmockCommon
|
||||||
|
|
||||||
|
// NewRowsWithColumnDefinition allows Rows to be created from a
|
||||||
|
// sql driver.Value slice with a definition of sql metadata
|
||||||
|
NewRowsWithColumnDefinition(columns ...*Column) *Rows
|
||||||
|
|
||||||
|
// New Column allows to create a Column
|
||||||
|
NewColumn(name string) *Column
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrCancelled defines an error value, which can be expected in case of
|
||||||
|
// such cancellation error.
|
||||||
|
var ErrCancelled = errors.New("canceling query due to user request")
|
||||||
|
|
||||||
|
// Implement the "QueryerContext" interface
|
||||||
|
func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||||
|
ex, err := c.query(query, args)
|
||||||
|
if ex != nil {
|
||||||
|
select {
|
||||||
|
case <-time.After(ex.delay):
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ex.rows, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ErrCancelled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "ExecerContext" interface
|
||||||
|
func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||||
|
ex, err := c.exec(query, args)
|
||||||
|
if ex != nil {
|
||||||
|
select {
|
||||||
|
case <-time.After(ex.delay):
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ex.result, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ErrCancelled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "ConnBeginTx" interface
|
||||||
|
func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||||
|
ex, err := c.begin()
|
||||||
|
if ex != nil {
|
||||||
|
select {
|
||||||
|
case <-time.After(ex.delay):
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ErrCancelled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "ConnPrepareContext" interface
|
||||||
|
func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||||
|
ex, err := c.prepare(query)
|
||||||
|
if ex != nil {
|
||||||
|
select {
|
||||||
|
case <-time.After(ex.delay):
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &statement{c, ex, query}, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ErrCancelled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "Pinger" interface - the explicit DB driver ping was only added to database/sql in Go 1.8
|
||||||
|
func (c *sqlmock) Ping(ctx context.Context) error {
|
||||||
|
if !c.monitorPings {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ex, err := c.ping()
|
||||||
|
if ex != nil {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ErrCancelled
|
||||||
|
case <-time.After(ex.delay):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) ping() (*ExpectedPing, error) {
|
||||||
|
var expected *ExpectedPing
|
||||||
|
var fulfilled int
|
||||||
|
var ok bool
|
||||||
|
for _, next := range c.expected {
|
||||||
|
next.Lock()
|
||||||
|
if next.fulfilled() {
|
||||||
|
next.Unlock()
|
||||||
|
fulfilled++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected, ok = next.(*ExpectedPing); ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
next.Unlock()
|
||||||
|
if c.ordered {
|
||||||
|
return nil, fmt.Errorf("call to database Ping, was not expected, next expectation is: %s", next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected == nil {
|
||||||
|
msg := "call to database Ping was not expected"
|
||||||
|
if fulfilled == len(c.expected) {
|
||||||
|
msg = "all expectations were already fulfilled, " + msg
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected.triggered = true
|
||||||
|
expected.Unlock()
|
||||||
|
return expected, expected.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "StmtExecContext" interface
|
||||||
|
func (stmt *statement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||||
|
return stmt.conn.ExecContext(ctx, stmt.query, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement the "StmtQueryContext" interface
|
||||||
|
func (stmt *statement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||||
|
return stmt.conn.QueryContext(ctx, stmt.query, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) ExpectPing() *ExpectedPing {
|
||||||
|
if !c.monitorPings {
|
||||||
|
log.Println("ExpectPing will have no effect as monitoring pings is disabled. Use MonitorPingsOption to enable.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
e := &ExpectedPing{}
|
||||||
|
c.expected = append(c.expected, e)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer
|
||||||
|
// Deprecated: Drivers should implement QueryerContext instead.
|
||||||
|
func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||||
|
namedArgs := make([]driver.NamedValue, len(args))
|
||||||
|
for i, v := range args {
|
||||||
|
namedArgs[i] = driver.NamedValue{
|
||||||
|
Ordinal: i + 1,
|
||||||
|
Value: v,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ex, err := c.query(query, namedArgs)
|
||||||
|
if ex != nil {
|
||||||
|
time.Sleep(ex.delay)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ex.rows, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) query(query string, args []driver.NamedValue) (*ExpectedQuery, error) {
|
||||||
|
var expected *ExpectedQuery
|
||||||
|
var fulfilled int
|
||||||
|
var ok bool
|
||||||
|
for _, next := range c.expected {
|
||||||
|
next.Lock()
|
||||||
|
if next.fulfilled() {
|
||||||
|
next.Unlock()
|
||||||
|
fulfilled++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.ordered {
|
||||||
|
if expected, ok = next.(*ExpectedQuery); ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
next.Unlock()
|
||||||
|
return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
|
||||||
|
}
|
||||||
|
if qr, ok := next.(*ExpectedQuery); ok {
|
||||||
|
if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil {
|
||||||
|
next.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := qr.attemptArgMatch(args); err == nil {
|
||||||
|
expected = qr
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
next.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected == nil {
|
||||||
|
msg := "call to Query '%s' with args %+v was not expected"
|
||||||
|
if fulfilled == len(c.expected) {
|
||||||
|
msg = "all expectations were already fulfilled, " + msg
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf(msg, query, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer expected.Unlock()
|
||||||
|
|
||||||
|
if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
|
||||||
|
return nil, fmt.Errorf("Query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := expected.argsMatches(args); err != nil {
|
||||||
|
return nil, fmt.Errorf("Query '%s', arguments do not match: %s", query, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected.triggered = true
|
||||||
|
if expected.err != nil {
|
||||||
|
return expected, expected.err // mocked to return error
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected.rows == nil {
|
||||||
|
return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
|
||||||
|
}
|
||||||
|
return expected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec meets http://golang.org/pkg/database/sql/driver/#Execer
|
||||||
|
// Deprecated: Drivers should implement ExecerContext instead.
|
||||||
|
func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) {
|
||||||
|
namedArgs := make([]driver.NamedValue, len(args))
|
||||||
|
for i, v := range args {
|
||||||
|
namedArgs[i] = driver.NamedValue{
|
||||||
|
Ordinal: i + 1,
|
||||||
|
Value: v,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ex, err := c.exec(query, namedArgs)
|
||||||
|
if ex != nil {
|
||||||
|
time.Sleep(ex.delay)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ex.result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlmock) exec(query string, args []driver.NamedValue) (*ExpectedExec, error) {
|
||||||
|
var expected *ExpectedExec
|
||||||
|
var fulfilled int
|
||||||
|
var ok bool
|
||||||
|
for _, next := range c.expected {
|
||||||
|
next.Lock()
|
||||||
|
if next.fulfilled() {
|
||||||
|
next.Unlock()
|
||||||
|
fulfilled++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.ordered {
|
||||||
|
if expected, ok = next.(*ExpectedExec); ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
next.Unlock()
|
||||||
|
return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
|
||||||
|
}
|
||||||
|
if exec, ok := next.(*ExpectedExec); ok {
|
||||||
|
if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil {
|
||||||
|
next.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := exec.attemptArgMatch(args); err == nil {
|
||||||
|
expected = exec
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
next.Unlock()
|
||||||
|
}
|
||||||
|
if expected == nil {
|
||||||
|
msg := "call to ExecQuery '%s' with args %+v was not expected"
|
||||||
|
if fulfilled == len(c.expected) {
|
||||||
|
msg = "all expectations were already fulfilled, " + msg
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf(msg, query, args)
|
||||||
|
}
|
||||||
|
defer expected.Unlock()
|
||||||
|
|
||||||
|
if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
|
||||||
|
return nil, fmt.Errorf("ExecQuery: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := expected.argsMatches(args); err != nil {
|
||||||
|
return nil, fmt.Errorf("ExecQuery '%s', arguments do not match: %s", query, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected.triggered = true
|
||||||
|
if expected.err != nil {
|
||||||
|
return expected, expected.err // mocked to return error
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected.result == nil {
|
||||||
|
return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
return expected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions)
|
||||||
|
|
||||||
|
// NewRowsWithColumnDefinition allows Rows to be created from a
|
||||||
|
// sql driver.Value slice with a definition of sql metadata
|
||||||
|
func (c *sqlmock) NewRowsWithColumnDefinition(columns ...*Column) *Rows {
|
||||||
|
r := NewRowsWithColumnDefinition(columns...)
|
||||||
|
r.converter = c.converter
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewColumn allows to create a Column that can be enhanced with metadata
|
||||||
|
// using OfType/Nullable/WithLength/WithPrecisionAndScale methods.
|
||||||
|
func (c *sqlmock) NewColumn(name string) *Column {
|
||||||
|
return NewColumn(name)
|
||||||
|
}
|
|
@ -0,0 +1,11 @@
|
||||||
|
// +build go1.8,!go1.9
|
||||||
|
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import "database/sql/driver"
|
||||||
|
|
||||||
|
// CheckNamedValue meets https://golang.org/pkg/database/sql/driver/#NamedValueChecker
|
||||||
|
func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) {
|
||||||
|
nv.Value, err = c.converter.ConvertValue(nv.Value)
|
||||||
|
return err
|
||||||
|
}
|
|
@ -0,0 +1,19 @@
|
||||||
|
// +build go1.9
|
||||||
|
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CheckNamedValue meets https://golang.org/pkg/database/sql/driver/#NamedValueChecker
|
||||||
|
func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) {
|
||||||
|
switch nv.Value.(type) {
|
||||||
|
case sql.Out:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
nv.Value, err = c.converter.ConvertValue(nv.Value)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,16 @@
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
type statement struct {
|
||||||
|
conn *sqlmock
|
||||||
|
ex *ExpectedPrepare
|
||||||
|
query string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stmt *statement) Close() error {
|
||||||
|
stmt.ex.wasClosed = true
|
||||||
|
return stmt.ex.closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stmt *statement) NumInput() int {
|
||||||
|
return -1
|
||||||
|
}
|
|
@ -0,0 +1,17 @@
|
||||||
|
// +build !go1.8
|
||||||
|
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Deprecated: Drivers should implement ExecerContext instead.
|
||||||
|
func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) {
|
||||||
|
return stmt.conn.Exec(stmt.query, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
|
||||||
|
func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) {
|
||||||
|
return stmt.conn.Query(stmt.query, args)
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
// +build go1.8
|
||||||
|
|
||||||
|
package sqlmock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql/driver"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Deprecated: Drivers should implement ExecerContext instead.
|
||||||
|
func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) {
|
||||||
|
return stmt.conn.ExecContext(context.Background(), stmt.query, convertValueToNamedValue(args))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
|
||||||
|
func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) {
|
||||||
|
return stmt.conn.QueryContext(context.Background(), stmt.query, convertValueToNamedValue(args))
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertValueToNamedValue(args []driver.Value) []driver.NamedValue {
|
||||||
|
namedArgs := make([]driver.NamedValue, len(args))
|
||||||
|
for i, v := range args {
|
||||||
|
namedArgs[i] = driver.NamedValue{Ordinal: i + 1, Value: v}
|
||||||
|
}
|
||||||
|
return namedArgs
|
||||||
|
}
|
|
@ -0,0 +1,4 @@
|
||||||
|
coverage.txt
|
||||||
|
bin
|
||||||
|
card.png
|
||||||
|
dist
|
|
@ -0,0 +1,8 @@
|
||||||
|
linters:
|
||||||
|
enable:
|
||||||
|
- thelper
|
||||||
|
- gofumpt
|
||||||
|
- tparallel
|
||||||
|
- unconvert
|
||||||
|
- unparam
|
||||||
|
- wastedassign
|
|
@ -0,0 +1,3 @@
|
||||||
|
includes:
|
||||||
|
- from_url:
|
||||||
|
url: https://raw.githubusercontent.com/caarlos0/.goreleaserfiles/main/lib.yml
|
|
@ -0,0 +1,21 @@
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2015-2022 Carlos Alexandro Becker
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
|
@ -0,0 +1,37 @@
|
||||||
|
SOURCE_FILES?=./...
|
||||||
|
TEST_PATTERN?=.
|
||||||
|
|
||||||
|
export GO111MODULE := on
|
||||||
|
|
||||||
|
setup:
|
||||||
|
go mod tidy
|
||||||
|
.PHONY: setup
|
||||||
|
|
||||||
|
build:
|
||||||
|
go build
|
||||||
|
.PHONY: build
|
||||||
|
|
||||||
|
test:
|
||||||
|
go test -v -failfast -race -coverpkg=./... -covermode=atomic -coverprofile=coverage.txt $(SOURCE_FILES) -run $(TEST_PATTERN) -timeout=2m
|
||||||
|
.PHONY: test
|
||||||
|
|
||||||
|
cover: test
|
||||||
|
go tool cover -html=coverage.txt
|
||||||
|
.PHONY: cover
|
||||||
|
|
||||||
|
fmt:
|
||||||
|
gofumpt -w -l .
|
||||||
|
.PHONY: fmt
|
||||||
|
|
||||||
|
lint:
|
||||||
|
golangci-lint run ./...
|
||||||
|
.PHONY: lint
|
||||||
|
|
||||||
|
ci: build test
|
||||||
|
.PHONY: ci
|
||||||
|
|
||||||
|
card:
|
||||||
|
wget -O card.png -c "https://og.caarlos0.dev/**env**: parse envs to structs.png?theme=light&md=1&fontSize=100px&images=https://github.com/caarlos0.png"
|
||||||
|
.PHONY: card
|
||||||
|
|
||||||
|
.DEFAULT_GOAL := ci
|
|
@ -0,0 +1,551 @@
|
||||||
|
# env
|
||||||
|
|
||||||
|
[![Build Status](https://img.shields.io/github/actions/workflow/status/caarlos0/env/build.yml?branch=main&style=for-the-badge)](https://github.com/caarlos0/env/actions?workflow=build)
|
||||||
|
[![Coverage Status](https://img.shields.io/codecov/c/gh/caarlos0/env.svg?logo=codecov&style=for-the-badge)](https://codecov.io/gh/caarlos0/env)
|
||||||
|
[![](http://img.shields.io/badge/godoc-reference-5272B4.svg?style=for-the-badge)](https://pkg.go.dev/github.com/caarlos0/env/v7)
|
||||||
|
|
||||||
|
A simple and zero-dependencies library to parse environment variables into structs.
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
Get the module with:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
go get github.com/caarlos0/env/v7
|
||||||
|
```
|
||||||
|
|
||||||
|
The usage looks like this:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/caarlos0/env/v7"
|
||||||
|
)
|
||||||
|
|
||||||
|
type config struct {
|
||||||
|
Home string `env:"HOME"`
|
||||||
|
Port int `env:"PORT" envDefault:"3000"`
|
||||||
|
Password string `env:"PASSWORD,unset"`
|
||||||
|
IsProduction bool `env:"PRODUCTION"`
|
||||||
|
Hosts []string `env:"HOSTS" envSeparator:":"`
|
||||||
|
Duration time.Duration `env:"DURATION"`
|
||||||
|
TempFolder string `env:"TEMP_FOLDER" envDefault:"${HOME}/tmp" envExpand:"true"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfg := config{}
|
||||||
|
if err := env.Parse(&cfg); err != nil {
|
||||||
|
fmt.Printf("%+v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("%+v\n", cfg)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
You can run it like this:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
$ PRODUCTION=true HOSTS="host1:host2:host3" DURATION=1s go run main.go
|
||||||
|
{Home:/your/home Port:3000 IsProduction:true Hosts:[host1 host2 host3] Duration:1s}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Caveats
|
||||||
|
|
||||||
|
> **Warning**
|
||||||
|
>
|
||||||
|
> **This is important!**
|
||||||
|
|
||||||
|
- _Unexported fields_ are **ignored**
|
||||||
|
|
||||||
|
|
||||||
|
## Supported types and defaults
|
||||||
|
|
||||||
|
Out of the box all built-in types are supported, plus a few others that
|
||||||
|
are commonly used.
|
||||||
|
|
||||||
|
Complete list:
|
||||||
|
|
||||||
|
- `string`
|
||||||
|
- `bool`
|
||||||
|
- `int`
|
||||||
|
- `int8`
|
||||||
|
- `int16`
|
||||||
|
- `int32`
|
||||||
|
- `int64`
|
||||||
|
- `uint`
|
||||||
|
- `uint8`
|
||||||
|
- `uint16`
|
||||||
|
- `uint32`
|
||||||
|
- `uint64`
|
||||||
|
- `float32`
|
||||||
|
- `float64`
|
||||||
|
- `time.Duration`
|
||||||
|
- `encoding.TextUnmarshaler`
|
||||||
|
- `url.URL`
|
||||||
|
|
||||||
|
Pointers, slices and slices of pointers, and maps of those types are also
|
||||||
|
supported.
|
||||||
|
|
||||||
|
You can also use/define a [custom parser func](#custom-parser-funcs) for any
|
||||||
|
other type you want.
|
||||||
|
|
||||||
|
You can also use custom keys and values in your maps, as long as you provide a
|
||||||
|
parser function for them.
|
||||||
|
|
||||||
|
If you set the `envDefault` tag for something, this value will be used in the
|
||||||
|
case of absence of it in the environment.
|
||||||
|
|
||||||
|
By default, slice types will split the environment value on `,`; you can change
|
||||||
|
this behavior by setting the `envSeparator` tag.
|
||||||
|
|
||||||
|
If you set the `envExpand` tag, environment variables (either in `${var}` or
|
||||||
|
`$var` format) in the string will be replaced according with the actual value
|
||||||
|
of the variable.
|
||||||
|
|
||||||
|
## Custom Parser Funcs
|
||||||
|
|
||||||
|
If you have a type that is not supported out of the box by the lib, you are able
|
||||||
|
to use (or define) and pass custom parsers (and their associated `reflect.Type`)
|
||||||
|
to the `env.ParseWithFuncs()` function.
|
||||||
|
|
||||||
|
In addition to accepting a struct pointer (same as `Parse()`), this function
|
||||||
|
also accepts a `map[reflect.Type]env.ParserFunc`.
|
||||||
|
|
||||||
|
If you add a custom parser for, say `Foo`, it will also be used to parse
|
||||||
|
`*Foo` and `[]Foo` types.
|
||||||
|
|
||||||
|
Check the examples in the [go doc](http://pkg.go.dev/github.com/caarlos0/env/v7)
|
||||||
|
for more info.
|
||||||
|
|
||||||
|
### A note about `TextUnmarshaler` and `time.Time`
|
||||||
|
|
||||||
|
Env supports by default anything that implements the `TextUnmarshaler` interface.
|
||||||
|
That includes things like `time.Time` for example.
|
||||||
|
The upside is that depending on the format you need, you don't need to change anything.
|
||||||
|
The downside is that if you do need time in another format, you'll need to create your own type.
|
||||||
|
|
||||||
|
Its fairly straightforward:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type MyTime time.Time
|
||||||
|
|
||||||
|
func (t *MyTime) UnmarshalText(text []byte) error {
|
||||||
|
tt, err := time.Parse("2006-01-02", string(text))
|
||||||
|
*t = MyTime(tt)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
SomeTime MyTime `env:"SOME_TIME"`
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
And then you can parse `Config` with `env.Parse`.
|
||||||
|
|
||||||
|
## Required fields
|
||||||
|
|
||||||
|
The `env` tag option `required` (e.g., `env:"tagKey,required"`) can be added to ensure that some environment variable is set.
|
||||||
|
In the example above, an error is returned if the `config` struct is changed to:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type config struct {
|
||||||
|
SecretKey string `env:"SECRET_KEY,required"`
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Not Empty fields
|
||||||
|
|
||||||
|
While `required` demands the environment variable to be set, it doesn't check its value.
|
||||||
|
If you want to make sure the environment is set and not empty, you need to use the `notEmpty` tag option instead (`env:"SOME_ENV,notEmpty"`).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type config struct {
|
||||||
|
SecretKey string `env:"SECRET_KEY,notEmpty"`
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Unset environment variable after reading it
|
||||||
|
|
||||||
|
The `env` tag option `unset` (e.g., `env:"tagKey,unset"`) can be added
|
||||||
|
to ensure that some environment variable is unset after reading it.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type config struct {
|
||||||
|
SecretKey string `env:"SECRET_KEY,unset"`
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## From file
|
||||||
|
|
||||||
|
The `env` tag option `file` (e.g., `env:"tagKey,file"`) can be added
|
||||||
|
to in order to indicate that the value of the variable shall be loaded from a file. The path of that file is given
|
||||||
|
by the environment variable associated with it
|
||||||
|
Example below
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
"github.com/caarlos0/env/v7"
|
||||||
|
)
|
||||||
|
|
||||||
|
type config struct {
|
||||||
|
Secret string `env:"SECRET,file"`
|
||||||
|
Password string `env:"PASSWORD,file" envDefault:"/tmp/password"`
|
||||||
|
Certificate string `env:"CERTIFICATE,file" envDefault:"${CERTIFICATE_FILE}" envExpand:"true"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfg := config{}
|
||||||
|
if err := env.Parse(&cfg); err != nil {
|
||||||
|
fmt.Printf("%+v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("%+v\n", cfg)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
```sh
|
||||||
|
$ echo qwerty > /tmp/secret
|
||||||
|
$ echo dvorak > /tmp/password
|
||||||
|
$ echo coleman > /tmp/certificate
|
||||||
|
|
||||||
|
$ SECRET=/tmp/secret \
|
||||||
|
CERTIFICATE_FILE=/tmp/certificate \
|
||||||
|
go run main.go
|
||||||
|
{Secret:qwerty Password:dvorak Certificate:coleman}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Options
|
||||||
|
|
||||||
|
### Use field names as environment variables by default
|
||||||
|
|
||||||
|
If you don't want to set the `env` tag on every field, you can use the
|
||||||
|
`UseFieldNameByDefault` option.
|
||||||
|
|
||||||
|
It will use the field name as environment variable name.
|
||||||
|
|
||||||
|
Here's an example:
|
||||||
|
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/caarlos0/env/v7"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Username string // will use $USERNAME
|
||||||
|
Password string // will use $PASSWORD
|
||||||
|
UserFullName string // will use $USER_FULL_NAME
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfg := &Config{}
|
||||||
|
opts := &env.Options{UseFieldNameByDefault: true}
|
||||||
|
|
||||||
|
// Load env vars.
|
||||||
|
if err := env.Parse(cfg, opts); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print the loaded data.
|
||||||
|
fmt.Printf("%+v\n", cfg)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Environment
|
||||||
|
|
||||||
|
By setting the `Options.Environment` map you can tell `Parse` to add those `keys` and `values`
|
||||||
|
as env vars before parsing is done. These envs are stored in the map and never actually set by `os.Setenv`.
|
||||||
|
This option effectively makes `env` ignore the OS environment variables: only the ones provided in the option are used.
|
||||||
|
|
||||||
|
This can make your testing scenarios a bit more clean and easy to handle.
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/caarlos0/env/v7"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Password string `env:"PASSWORD"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfg := &Config{}
|
||||||
|
opts := &env.Options{Environment: map[string]string{
|
||||||
|
"PASSWORD": "MY_PASSWORD",
|
||||||
|
}}
|
||||||
|
|
||||||
|
// Load env vars.
|
||||||
|
if err := env.Parse(cfg, opts); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print the loaded data.
|
||||||
|
fmt.Printf("%+v\n", cfg)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Changing default tag name
|
||||||
|
|
||||||
|
You can change what tag name to use for setting the env vars by setting the `Options.TagName`
|
||||||
|
variable.
|
||||||
|
|
||||||
|
For example
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/caarlos0/env/v7"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Password string `json:"PASSWORD"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfg := &Config{}
|
||||||
|
opts := &env.Options{TagName: "json"}
|
||||||
|
|
||||||
|
// Load env vars.
|
||||||
|
if err := env.Parse(cfg, opts); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print the loaded data.
|
||||||
|
fmt.Printf("%+v\n", cfg)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Prefixes
|
||||||
|
|
||||||
|
You can prefix sub-structs env tags, as well as a whole `env.Parse` call.
|
||||||
|
|
||||||
|
Here's an example flexing it a bit:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/caarlos0/env/v7"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Home string `env:"HOME"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ComplexConfig struct {
|
||||||
|
Foo Config `envPrefix:"FOO_"`
|
||||||
|
Clean Config
|
||||||
|
Bar Config `envPrefix:"BAR_"`
|
||||||
|
Blah string `env:"BLAH"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfg := ComplexConfig{}
|
||||||
|
if err := Parse(&cfg, Options{
|
||||||
|
Prefix: "T_",
|
||||||
|
Environment: map[string]string{
|
||||||
|
"T_FOO_HOME": "/foo",
|
||||||
|
"T_BAR_HOME": "/bar",
|
||||||
|
"T_BLAH": "blahhh",
|
||||||
|
"T_HOME": "/clean",
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load env vars.
|
||||||
|
if err := env.Parse(cfg, opts); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print the loaded data.
|
||||||
|
fmt.Printf("%+v\n", cfg)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### On set hooks
|
||||||
|
|
||||||
|
You might want to listen to value sets and, for example, log something or do some other kind of logic.
|
||||||
|
You can do this by passing a `OnSet` option:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/caarlos0/env/v7"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Username string `env:"USERNAME" envDefault:"admin"`
|
||||||
|
Password string `env:"PASSWORD"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfg := &Config{}
|
||||||
|
opts := &env.Options{
|
||||||
|
OnSet: func(tag string, value interface{}, isDefault bool) {
|
||||||
|
fmt.Printf("Set %s to %v (default? %v)\n", tag, value, isDefault)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load env vars.
|
||||||
|
if err := env.Parse(cfg, opts); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print the loaded data.
|
||||||
|
fmt.Printf("%+v\n", cfg)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Making all fields to required
|
||||||
|
|
||||||
|
You can make all fields that don't have a default value be required by setting the `RequiredIfNoDef: true` in the `Options`.
|
||||||
|
|
||||||
|
For example
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/caarlos0/env/v7"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Username string `env:"USERNAME" envDefault:"admin"`
|
||||||
|
Password string `env:"PASSWORD"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfg := &Config{}
|
||||||
|
opts := &env.Options{RequiredIfNoDef: true}
|
||||||
|
|
||||||
|
// Load env vars.
|
||||||
|
if err := env.Parse(cfg, opts); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print the loaded data.
|
||||||
|
fmt.Printf("%+v\n", cfg)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Defaults from code
|
||||||
|
|
||||||
|
You may define default value also in code, by initialising the config data before it's filled by `env.Parse`.
|
||||||
|
Default values defined as struct tags will overwrite existing values during Parse.
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/caarlos0/env/v7"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Username string `env:"USERNAME" envDefault:"admin"`
|
||||||
|
Password string `env:"PASSWORD"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var cfg = Config{
|
||||||
|
Username: "test",
|
||||||
|
Password: "123456",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := env.Parse(&cfg); err != nil {
|
||||||
|
fmt.Println("failed:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("%+v", cfg) // {Username:admin Password:123456}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Error handling
|
||||||
|
|
||||||
|
You can handle the errors the library throws like so:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/caarlos0/env/v7"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Username string `env:"USERNAME" envDefault:"admin"`
|
||||||
|
Password string `env:"PASSWORD"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var cfg Config
|
||||||
|
err := env.Parse(&cfg)
|
||||||
|
if e, ok := err.(*env.AggregateError); ok {
|
||||||
|
for _, er := range e.Errors {
|
||||||
|
switch v := er.(type) {
|
||||||
|
case env.ParseError:
|
||||||
|
// handle it
|
||||||
|
case env.NotStructPtrError:
|
||||||
|
// handle it
|
||||||
|
case env.NoParserError:
|
||||||
|
// handle it
|
||||||
|
case env.NoSupportedTagOptionError:
|
||||||
|
// handle it
|
||||||
|
default:
|
||||||
|
fmt.Printf("Unknown error type %v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("%+v", cfg) // {Username:admin Password:123456}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Info**
|
||||||
|
>
|
||||||
|
> If you want to check if an specific error is in the chain, you can also use
|
||||||
|
> `errors.Is()`.
|
||||||
|
|
||||||
|
## Stargazers over time
|
||||||
|
|
||||||
|
[![Stargazers over time](https://starchart.cc/caarlos0/env.svg)](https://starchart.cc/caarlos0/env)
|
|
@ -0,0 +1,540 @@
|
||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
// nolint: gochecknoglobals
|
||||||
|
var (
|
||||||
|
defaultBuiltInParsers = map[reflect.Kind]ParserFunc{
|
||||||
|
reflect.Bool: func(v string) (interface{}, error) {
|
||||||
|
return strconv.ParseBool(v)
|
||||||
|
},
|
||||||
|
reflect.String: func(v string) (interface{}, error) {
|
||||||
|
return v, nil
|
||||||
|
},
|
||||||
|
reflect.Int: func(v string) (interface{}, error) {
|
||||||
|
i, err := strconv.ParseInt(v, 10, 32)
|
||||||
|
return int(i), err
|
||||||
|
},
|
||||||
|
reflect.Int16: func(v string) (interface{}, error) {
|
||||||
|
i, err := strconv.ParseInt(v, 10, 16)
|
||||||
|
return int16(i), err
|
||||||
|
},
|
||||||
|
reflect.Int32: func(v string) (interface{}, error) {
|
||||||
|
i, err := strconv.ParseInt(v, 10, 32)
|
||||||
|
return int32(i), err
|
||||||
|
},
|
||||||
|
reflect.Int64: func(v string) (interface{}, error) {
|
||||||
|
return strconv.ParseInt(v, 10, 64)
|
||||||
|
},
|
||||||
|
reflect.Int8: func(v string) (interface{}, error) {
|
||||||
|
i, err := strconv.ParseInt(v, 10, 8)
|
||||||
|
return int8(i), err
|
||||||
|
},
|
||||||
|
reflect.Uint: func(v string) (interface{}, error) {
|
||||||
|
i, err := strconv.ParseUint(v, 10, 32)
|
||||||
|
return uint(i), err
|
||||||
|
},
|
||||||
|
reflect.Uint16: func(v string) (interface{}, error) {
|
||||||
|
i, err := strconv.ParseUint(v, 10, 16)
|
||||||
|
return uint16(i), err
|
||||||
|
},
|
||||||
|
reflect.Uint32: func(v string) (interface{}, error) {
|
||||||
|
i, err := strconv.ParseUint(v, 10, 32)
|
||||||
|
return uint32(i), err
|
||||||
|
},
|
||||||
|
reflect.Uint64: func(v string) (interface{}, error) {
|
||||||
|
i, err := strconv.ParseUint(v, 10, 64)
|
||||||
|
return i, err
|
||||||
|
},
|
||||||
|
reflect.Uint8: func(v string) (interface{}, error) {
|
||||||
|
i, err := strconv.ParseUint(v, 10, 8)
|
||||||
|
return uint8(i), err
|
||||||
|
},
|
||||||
|
reflect.Float64: func(v string) (interface{}, error) {
|
||||||
|
return strconv.ParseFloat(v, 64)
|
||||||
|
},
|
||||||
|
reflect.Float32: func(v string) (interface{}, error) {
|
||||||
|
f, err := strconv.ParseFloat(v, 32)
|
||||||
|
return float32(f), err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func defaultTypeParsers() map[reflect.Type]ParserFunc {
|
||||||
|
return map[reflect.Type]ParserFunc{
|
||||||
|
reflect.TypeOf(url.URL{}): func(v string) (interface{}, error) {
|
||||||
|
u, err := url.Parse(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, newParseValueError("unable to parse URL", err)
|
||||||
|
}
|
||||||
|
return *u, nil
|
||||||
|
},
|
||||||
|
reflect.TypeOf(time.Nanosecond): func(v string) (interface{}, error) {
|
||||||
|
s, err := time.ParseDuration(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, newParseValueError("unable to parse duration", err)
|
||||||
|
}
|
||||||
|
return s, err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParserFunc defines the signature of a function that can be used within `CustomParsers`.
|
||||||
|
type ParserFunc func(v string) (interface{}, error)
|
||||||
|
|
||||||
|
// OnSetFn is a hook that can be run when a value is set.
|
||||||
|
type OnSetFn func(tag string, value interface{}, isDefault bool)
|
||||||
|
|
||||||
|
// Options for the parser.
|
||||||
|
type Options struct {
|
||||||
|
// Environment keys and values that will be accessible for the service.
|
||||||
|
Environment map[string]string
|
||||||
|
|
||||||
|
// TagName specifies another tagname to use rather than the default env.
|
||||||
|
TagName string
|
||||||
|
|
||||||
|
// RequiredIfNoDef automatically sets all env as required if they do not
|
||||||
|
// declare 'envDefault'.
|
||||||
|
RequiredIfNoDef bool
|
||||||
|
|
||||||
|
// OnSet allows to run a function when a value is set.
|
||||||
|
OnSet OnSetFn
|
||||||
|
|
||||||
|
// Prefix define a prefix for each key.
|
||||||
|
Prefix string
|
||||||
|
|
||||||
|
// UseFieldNameByDefault defines whether or not env should use the field
|
||||||
|
// name by default if the `env` key is missing.
|
||||||
|
UseFieldNameByDefault bool
|
||||||
|
|
||||||
|
// Sets to true if we have already configured once.
|
||||||
|
configured bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// configure will do the basic configurations and defaults.
|
||||||
|
func configure(opts []Options) []Options {
|
||||||
|
// If we have already configured the first item
|
||||||
|
// of options will have been configured set to true.
|
||||||
|
if len(opts) > 0 && opts[0].configured {
|
||||||
|
return opts
|
||||||
|
}
|
||||||
|
|
||||||
|
// Created options with defaults.
|
||||||
|
opt := Options{
|
||||||
|
TagName: "env",
|
||||||
|
Environment: toMap(os.Environ()),
|
||||||
|
configured: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loop over all opts structs and set
|
||||||
|
// to opt if value is not default/empty.
|
||||||
|
for _, item := range opts {
|
||||||
|
if item.Environment != nil {
|
||||||
|
opt.Environment = item.Environment
|
||||||
|
}
|
||||||
|
if item.TagName != "" {
|
||||||
|
opt.TagName = item.TagName
|
||||||
|
}
|
||||||
|
if item.OnSet != nil {
|
||||||
|
opt.OnSet = item.OnSet
|
||||||
|
}
|
||||||
|
if item.Prefix != "" {
|
||||||
|
opt.Prefix = item.Prefix
|
||||||
|
}
|
||||||
|
opt.UseFieldNameByDefault = item.UseFieldNameByDefault
|
||||||
|
opt.RequiredIfNoDef = item.RequiredIfNoDef
|
||||||
|
}
|
||||||
|
|
||||||
|
return []Options{opt}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOnSetFn(opts []Options) OnSetFn {
|
||||||
|
return opts[0].OnSet
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTagName returns the tag name.
|
||||||
|
func getTagName(opts []Options) string {
|
||||||
|
return opts[0].TagName
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEnvironment returns the environment map.
|
||||||
|
func getEnvironment(opts []Options) map[string]string {
|
||||||
|
return opts[0].Environment
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse parses a struct containing `env` tags and loads its values from
|
||||||
|
// environment variables.
|
||||||
|
func Parse(v interface{}, opts ...Options) error {
|
||||||
|
return ParseWithFuncs(v, map[reflect.Type]ParserFunc{}, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseWithFuncs is the same as `Parse` except it also allows the user to pass
|
||||||
|
// in custom parsers.
|
||||||
|
func ParseWithFuncs(v interface{}, funcMap map[reflect.Type]ParserFunc, opts ...Options) error {
|
||||||
|
opts = configure(opts)
|
||||||
|
|
||||||
|
ptrRef := reflect.ValueOf(v)
|
||||||
|
if ptrRef.Kind() != reflect.Ptr {
|
||||||
|
return newAggregateError(NotStructPtrError{})
|
||||||
|
}
|
||||||
|
ref := ptrRef.Elem()
|
||||||
|
if ref.Kind() != reflect.Struct {
|
||||||
|
return newAggregateError(NotStructPtrError{})
|
||||||
|
}
|
||||||
|
parsers := defaultTypeParsers()
|
||||||
|
for k, v := range funcMap {
|
||||||
|
parsers[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return doParse(ref, parsers, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func doParse(ref reflect.Value, funcMap map[reflect.Type]ParserFunc, opts []Options) error {
|
||||||
|
refType := ref.Type()
|
||||||
|
|
||||||
|
var agrErr AggregateError
|
||||||
|
|
||||||
|
for i := 0; i < refType.NumField(); i++ {
|
||||||
|
refField := ref.Field(i)
|
||||||
|
refTypeField := refType.Field(i)
|
||||||
|
|
||||||
|
if err := doParseField(refField, refTypeField, funcMap, opts); err != nil {
|
||||||
|
if val, ok := err.(AggregateError); ok {
|
||||||
|
agrErr.Errors = append(agrErr.Errors, val.Errors...)
|
||||||
|
} else {
|
||||||
|
agrErr.Errors = append(agrErr.Errors, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(agrErr.Errors) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return agrErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func doParseField(refField reflect.Value, refTypeField reflect.StructField, funcMap map[reflect.Type]ParserFunc, opts []Options) error {
|
||||||
|
if !refField.CanSet() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if reflect.Ptr == refField.Kind() && !refField.IsNil() {
|
||||||
|
return ParseWithFuncs(refField.Interface(), funcMap, optsWithPrefix(refTypeField, opts)...)
|
||||||
|
}
|
||||||
|
if reflect.Struct == refField.Kind() && refField.CanAddr() && refField.Type().Name() == "" {
|
||||||
|
return ParseWithFuncs(refField.Addr().Interface(), funcMap, optsWithPrefix(refTypeField, opts)...)
|
||||||
|
}
|
||||||
|
value, err := get(refTypeField, opts)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if value != "" {
|
||||||
|
return set(refField, refTypeField, value, funcMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
if reflect.Struct == refField.Kind() {
|
||||||
|
return doParse(refField, funcMap, optsWithPrefix(refTypeField, opts))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const underscore rune = '_'
|
||||||
|
|
||||||
|
func toEnvName(input string) string {
|
||||||
|
var output []rune
|
||||||
|
for i, c := range input {
|
||||||
|
if i > 0 && output[i-1] != underscore && c != underscore && unicode.ToUpper(c) == c {
|
||||||
|
output = append(output, underscore)
|
||||||
|
}
|
||||||
|
output = append(output, unicode.ToUpper(c))
|
||||||
|
}
|
||||||
|
return string(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
func get(field reflect.StructField, opts []Options) (val string, err error) {
|
||||||
|
var exists bool
|
||||||
|
var isDefault bool
|
||||||
|
var loadFile bool
|
||||||
|
var unset bool
|
||||||
|
var notEmpty bool
|
||||||
|
|
||||||
|
required := opts[0].RequiredIfNoDef
|
||||||
|
prefix := opts[0].Prefix
|
||||||
|
ownKey, tags := parseKeyForOption(field.Tag.Get(getTagName(opts)))
|
||||||
|
if ownKey == "" && opts[0].UseFieldNameByDefault {
|
||||||
|
ownKey = toEnvName(field.Name)
|
||||||
|
}
|
||||||
|
key := prefix + ownKey
|
||||||
|
for _, tag := range tags {
|
||||||
|
switch tag {
|
||||||
|
case "":
|
||||||
|
continue
|
||||||
|
case "file":
|
||||||
|
loadFile = true
|
||||||
|
case "required":
|
||||||
|
required = true
|
||||||
|
case "unset":
|
||||||
|
unset = true
|
||||||
|
case "notEmpty":
|
||||||
|
notEmpty = true
|
||||||
|
default:
|
||||||
|
return "", newNoSupportedTagOptionError(tag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
expand := strings.EqualFold(field.Tag.Get("envExpand"), "true")
|
||||||
|
defaultValue, defExists := field.Tag.Lookup("envDefault")
|
||||||
|
val, exists, isDefault = getOr(key, defaultValue, defExists, getEnvironment(opts))
|
||||||
|
|
||||||
|
if expand {
|
||||||
|
val = os.ExpandEnv(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
if unset {
|
||||||
|
defer os.Unsetenv(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
if required && !exists && len(ownKey) > 0 {
|
||||||
|
return "", newEnvVarIsNotSet(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
if notEmpty && val == "" {
|
||||||
|
return "", newEmptyEnvVarError(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
if loadFile && val != "" {
|
||||||
|
filename := val
|
||||||
|
val, err = getFromFile(filename)
|
||||||
|
if err != nil {
|
||||||
|
return "", newLoadFileContentError(filename, key, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if onSetFn := getOnSetFn(opts); onSetFn != nil {
|
||||||
|
onSetFn(key, val, isDefault)
|
||||||
|
}
|
||||||
|
return val, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// split the env tag's key into the expected key and desired option, if any.
|
||||||
|
func parseKeyForOption(key string) (string, []string) {
|
||||||
|
opts := strings.Split(key, ",")
|
||||||
|
return opts[0], opts[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFromFile(filename string) (value string, err error) {
|
||||||
|
b, err := os.ReadFile(filename)
|
||||||
|
return string(b), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOr(key, defaultValue string, defExists bool, envs map[string]string) (string, bool, bool) {
|
||||||
|
value, exists := envs[key]
|
||||||
|
switch {
|
||||||
|
case (!exists || key == "") && defExists:
|
||||||
|
return defaultValue, true, true
|
||||||
|
case exists && value == "" && defExists:
|
||||||
|
return defaultValue, true, true
|
||||||
|
case !exists:
|
||||||
|
return "", false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return value, true, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func set(field reflect.Value, sf reflect.StructField, value string, funcMap map[reflect.Type]ParserFunc) error {
|
||||||
|
if tm := asTextUnmarshaler(field); tm != nil {
|
||||||
|
if err := tm.UnmarshalText([]byte(value)); err != nil {
|
||||||
|
return newParseError(sf, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
typee := sf.Type
|
||||||
|
fieldee := field
|
||||||
|
if typee.Kind() == reflect.Ptr {
|
||||||
|
typee = typee.Elem()
|
||||||
|
fieldee = field.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
parserFunc, ok := funcMap[typee]
|
||||||
|
if ok {
|
||||||
|
val, err := parserFunc(value)
|
||||||
|
if err != nil {
|
||||||
|
return newParseError(sf, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldee.Set(reflect.ValueOf(val))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parserFunc, ok = defaultBuiltInParsers[typee.Kind()]
|
||||||
|
if ok {
|
||||||
|
val, err := parserFunc(value)
|
||||||
|
if err != nil {
|
||||||
|
return newParseError(sf, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldee.Set(reflect.ValueOf(val).Convert(typee))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch field.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
return handleSlice(field, value, sf, funcMap)
|
||||||
|
case reflect.Map:
|
||||||
|
return handleMap(field, value, sf, funcMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newNoParserError(sf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleSlice(field reflect.Value, value string, sf reflect.StructField, funcMap map[reflect.Type]ParserFunc) error {
|
||||||
|
separator := sf.Tag.Get("envSeparator")
|
||||||
|
if separator == "" {
|
||||||
|
separator = ","
|
||||||
|
}
|
||||||
|
parts := strings.Split(value, separator)
|
||||||
|
|
||||||
|
typee := sf.Type.Elem()
|
||||||
|
if typee.Kind() == reflect.Ptr {
|
||||||
|
typee = typee.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := reflect.New(typee).Interface().(encoding.TextUnmarshaler); ok {
|
||||||
|
return parseTextUnmarshalers(field, parts, sf)
|
||||||
|
}
|
||||||
|
|
||||||
|
parserFunc, ok := funcMap[typee]
|
||||||
|
if !ok {
|
||||||
|
parserFunc, ok = defaultBuiltInParsers[typee.Kind()]
|
||||||
|
if !ok {
|
||||||
|
return newNoParserError(sf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := reflect.MakeSlice(sf.Type, 0, len(parts))
|
||||||
|
for _, part := range parts {
|
||||||
|
r, err := parserFunc(part)
|
||||||
|
if err != nil {
|
||||||
|
return newParseError(sf, err)
|
||||||
|
}
|
||||||
|
v := reflect.ValueOf(r).Convert(typee)
|
||||||
|
if sf.Type.Elem().Kind() == reflect.Ptr {
|
||||||
|
v = reflect.New(typee)
|
||||||
|
v.Elem().Set(reflect.ValueOf(r).Convert(typee))
|
||||||
|
}
|
||||||
|
result = reflect.Append(result, v)
|
||||||
|
}
|
||||||
|
field.Set(result)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleMap(field reflect.Value, value string, sf reflect.StructField, funcMap map[reflect.Type]ParserFunc) error {
|
||||||
|
keyType := sf.Type.Key()
|
||||||
|
keyParserFunc, ok := funcMap[keyType]
|
||||||
|
if !ok {
|
||||||
|
keyParserFunc, ok = defaultBuiltInParsers[keyType.Kind()]
|
||||||
|
if !ok {
|
||||||
|
return newNoParserError(sf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
elemType := sf.Type.Elem()
|
||||||
|
elemParserFunc, ok := funcMap[elemType]
|
||||||
|
if !ok {
|
||||||
|
elemParserFunc, ok = defaultBuiltInParsers[elemType.Kind()]
|
||||||
|
if !ok {
|
||||||
|
return newNoParserError(sf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
separator := sf.Tag.Get("envSeparator")
|
||||||
|
if separator == "" {
|
||||||
|
separator = ","
|
||||||
|
}
|
||||||
|
|
||||||
|
result := reflect.MakeMap(sf.Type)
|
||||||
|
for _, part := range strings.Split(value, separator) {
|
||||||
|
pairs := strings.Split(part, ":")
|
||||||
|
if len(pairs) != 2 {
|
||||||
|
return newParseError(sf, fmt.Errorf(`%q should be in "key:value" format`, part))
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := keyParserFunc(pairs[0])
|
||||||
|
if err != nil {
|
||||||
|
return newParseError(sf, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
elem, err := elemParserFunc(pairs[1])
|
||||||
|
if err != nil {
|
||||||
|
return newParseError(sf, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result.SetMapIndex(reflect.ValueOf(key).Convert(keyType), reflect.ValueOf(elem).Convert(elemType))
|
||||||
|
}
|
||||||
|
|
||||||
|
field.Set(result)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func asTextUnmarshaler(field reflect.Value) encoding.TextUnmarshaler {
|
||||||
|
if reflect.Ptr == field.Kind() {
|
||||||
|
if field.IsNil() {
|
||||||
|
field.Set(reflect.New(field.Type().Elem()))
|
||||||
|
}
|
||||||
|
} else if field.CanAddr() {
|
||||||
|
field = field.Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
tm, ok := field.Interface().(encoding.TextUnmarshaler)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return tm
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTextUnmarshalers(field reflect.Value, data []string, sf reflect.StructField) error {
|
||||||
|
s := len(data)
|
||||||
|
elemType := field.Type().Elem()
|
||||||
|
slice := reflect.MakeSlice(reflect.SliceOf(elemType), s, s)
|
||||||
|
for i, v := range data {
|
||||||
|
sv := slice.Index(i)
|
||||||
|
kind := sv.Kind()
|
||||||
|
if kind == reflect.Ptr {
|
||||||
|
sv = reflect.New(elemType.Elem())
|
||||||
|
} else {
|
||||||
|
sv = sv.Addr()
|
||||||
|
}
|
||||||
|
tm := sv.Interface().(encoding.TextUnmarshaler)
|
||||||
|
if err := tm.UnmarshalText([]byte(v)); err != nil {
|
||||||
|
return newParseError(sf, err)
|
||||||
|
}
|
||||||
|
if kind == reflect.Ptr {
|
||||||
|
slice.Index(i).Set(sv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
field.Set(slice)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func optsWithPrefix(field reflect.StructField, opts []Options) []Options {
|
||||||
|
subOpts := make([]Options, len(opts))
|
||||||
|
copy(subOpts, opts)
|
||||||
|
if prefix := field.Tag.Get("envPrefix"); prefix != "" {
|
||||||
|
subOpts[0].Prefix += prefix
|
||||||
|
}
|
||||||
|
return subOpts
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
|
||||||
|
// +build darwin dragonfly freebsd linux netbsd openbsd solaris
|
||||||
|
|
||||||
|
package env
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
func toMap(env []string) map[string]string {
|
||||||
|
r := map[string]string{}
|
||||||
|
for _, e := range env {
|
||||||
|
p := strings.SplitN(e, "=", 2)
|
||||||
|
r[p[0]] = p[1]
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package env
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
func toMap(env []string) map[string]string {
|
||||||
|
r := map[string]string{}
|
||||||
|
for _, e := range env {
|
||||||
|
p := strings.SplitN(e, "=", 2)
|
||||||
|
|
||||||
|
// On Windows, environment variables can start with '='. If so, Split at next character.
|
||||||
|
// See env_windows.go in the Go source: https://github.com/golang/go/blob/master/src/syscall/env_windows.go#L58
|
||||||
|
prefixEqualSign := false
|
||||||
|
if len(e) > 0 && e[0] == '=' {
|
||||||
|
e = e[1:]
|
||||||
|
prefixEqualSign = true
|
||||||
|
}
|
||||||
|
p = strings.SplitN(e, "=", 2)
|
||||||
|
if prefixEqualSign {
|
||||||
|
p[0] = "=" + p[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
r[p[0]] = p[1]
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
|
@ -0,0 +1,164 @@
|
||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// An aggregated error wrapper to combine gathered errors. This allows either to display all errors or convert them individually
|
||||||
|
// List of the available errors
|
||||||
|
// ParseError
|
||||||
|
// NotStructPtrError
|
||||||
|
// NoParserError
|
||||||
|
// NoSupportedTagOptionError
|
||||||
|
// EnvVarIsNotSetError
|
||||||
|
// EmptyEnvVarError
|
||||||
|
// LoadFileContentError
|
||||||
|
// ParseValueError
|
||||||
|
type AggregateError struct {
|
||||||
|
Errors []error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAggregateError(initErr error) error {
|
||||||
|
return AggregateError{
|
||||||
|
[]error{
|
||||||
|
initErr,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e AggregateError) Error() string {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
sb.WriteString("env:")
|
||||||
|
|
||||||
|
for _, err := range e.Errors {
|
||||||
|
sb.WriteString(fmt.Sprintf(" %v;", err.Error()))
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimRight(sb.String(), ";")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is conforms with errors.Is.
|
||||||
|
func (e AggregateError) Is(err error) bool {
|
||||||
|
for _, ie := range e.Errors {
|
||||||
|
if reflect.TypeOf(ie) == reflect.TypeOf(err) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// The error occurs when it's impossible to convert the value for given type.
|
||||||
|
type ParseError struct {
|
||||||
|
Name string
|
||||||
|
Type reflect.Type
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newParseError(sf reflect.StructField, err error) error {
|
||||||
|
return ParseError{sf.Name, sf.Type, err}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ParseError) Error() string {
|
||||||
|
return fmt.Sprintf(`parse error on field "%s" of type "%s": %v`, e.Name, e.Type, e.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The error occurs when pass something that is not a pointer to a Struct to Parse
|
||||||
|
type NotStructPtrError struct{}
|
||||||
|
|
||||||
|
func (e NotStructPtrError) Error() string {
|
||||||
|
return "expected a pointer to a Struct"
|
||||||
|
}
|
||||||
|
|
||||||
|
// This error occurs when there is no parser provided for given type
|
||||||
|
// Supported types and defaults: https://github.com/caarlos0/env#supported-types-and-defaults
|
||||||
|
// How to create a custom parser: https://github.com/caarlos0/env#custom-parser-funcs
|
||||||
|
type NoParserError struct {
|
||||||
|
Name string
|
||||||
|
Type reflect.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNoParserError(sf reflect.StructField) error {
|
||||||
|
return NoParserError{sf.Name, sf.Type}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e NoParserError) Error() string {
|
||||||
|
return fmt.Sprintf(`no parser found for field "%s" of type "%s"`, e.Name, e.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This error occurs when the given tag is not supported
|
||||||
|
// In-built supported tags: "", "file", "required", "unset", "notEmpty", "envDefault", "envExpand", "envSeparator"
|
||||||
|
// How to create a custom tag: https://github.com/caarlos0/env#changing-default-tag-name
|
||||||
|
type NoSupportedTagOptionError struct {
|
||||||
|
Tag string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNoSupportedTagOptionError(tag string) error {
|
||||||
|
return NoSupportedTagOptionError{tag}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e NoSupportedTagOptionError) Error() string {
|
||||||
|
return fmt.Sprintf("tag option %q not supported", e.Tag)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This error occurs when the required variable is not set
|
||||||
|
// Read about required fields: https://github.com/caarlos0/env#required-fields
|
||||||
|
type EnvVarIsNotSetError struct {
|
||||||
|
Key string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newEnvVarIsNotSet(key string) error {
|
||||||
|
return EnvVarIsNotSetError{key}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e EnvVarIsNotSetError) Error() string {
|
||||||
|
return fmt.Sprintf(`required environment variable %q is not set`, e.Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This error occurs when the variable which must be not empty is existing but has an empty value
|
||||||
|
// Read about not empty fields: https://github.com/caarlos0/env#not-empty-fields
|
||||||
|
type EmptyEnvVarError struct {
|
||||||
|
Key string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newEmptyEnvVarError(key string) error {
|
||||||
|
return EmptyEnvVarError{key}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e EmptyEnvVarError) Error() string {
|
||||||
|
return fmt.Sprintf("environment variable %q should not be empty", e.Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This error occurs when it's impossible to load the value from the file
|
||||||
|
// Read about From file feature: https://github.com/caarlos0/env#from-file
|
||||||
|
type LoadFileContentError struct {
|
||||||
|
Filename string
|
||||||
|
Key string
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLoadFileContentError(filename, key string, err error) error {
|
||||||
|
return LoadFileContentError{filename, key, err}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e LoadFileContentError) Error() string {
|
||||||
|
return fmt.Sprintf(`could not load content of file "%s" from variable %s: %v`, e.Filename, e.Key, e.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This error occurs when it's impossible to convert value using given parser
|
||||||
|
// Supported types and defaults: https://github.com/caarlos0/env#supported-types-and-defaults
|
||||||
|
// How to create a custom parser: https://github.com/caarlos0/env#custom-parser-funcs
|
||||||
|
type ParseValueError struct {
|
||||||
|
Msg string
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newParseValueError(message string, err error) error {
|
||||||
|
return ParseValueError{message, err}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ParseValueError) Error() string {
|
||||||
|
return fmt.Sprintf("%s: %v", e.Msg, e.Err)
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
sudo: false
|
||||||
|
language: go
|
||||||
|
go_import_path: github.com/dustin/go-humanize
|
||||||
|
go:
|
||||||
|
- 1.13.x
|
||||||
|
- 1.14.x
|
||||||
|
- 1.15.x
|
||||||
|
- 1.16.x
|
||||||
|
- stable
|
||||||
|
- master
|
||||||
|
matrix:
|
||||||
|
allow_failures:
|
||||||
|
- go: master
|
||||||
|
fast_finish: true
|
||||||
|
install:
|
||||||
|
- # Do nothing. This is needed to prevent default install action "go get -t -v ./..." from happening here (we want it to happen inside script step).
|
||||||
|
script:
|
||||||
|
- diff -u <(echo -n) <(gofmt -d -s .)
|
||||||
|
- go vet .
|
||||||
|
- go install -v -race ./...
|
||||||
|
- go test -v -race ./...
|
|
@ -0,0 +1,21 @@
|
||||||
|
Copyright (c) 2005-2008 Dustin Sallings <dustin@spy.net>
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in
|
||||||
|
all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
|
||||||
|
<http://www.opensource.org/licenses/mit-license.php>
|
|
@ -0,0 +1,124 @@
|
||||||
|
# Humane Units [![Build Status](https://travis-ci.org/dustin/go-humanize.svg?branch=master)](https://travis-ci.org/dustin/go-humanize) [![GoDoc](https://godoc.org/github.com/dustin/go-humanize?status.svg)](https://godoc.org/github.com/dustin/go-humanize)
|
||||||
|
|
||||||
|
Just a few functions for helping humanize times and sizes.
|
||||||
|
|
||||||
|
`go get` it as `github.com/dustin/go-humanize`, import it as
|
||||||
|
`"github.com/dustin/go-humanize"`, use it as `humanize`.
|
||||||
|
|
||||||
|
See [godoc](https://pkg.go.dev/github.com/dustin/go-humanize) for
|
||||||
|
complete documentation.
|
||||||
|
|
||||||
|
## Sizes
|
||||||
|
|
||||||
|
This lets you take numbers like `82854982` and convert them to useful
|
||||||
|
strings like, `83 MB` or `79 MiB` (whichever you prefer).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
fmt.Printf("That file is %s.", humanize.Bytes(82854982)) // That file is 83 MB.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Times
|
||||||
|
|
||||||
|
This lets you take a `time.Time` and spit it out in relative terms.
|
||||||
|
For example, `12 seconds ago` or `3 days from now`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
fmt.Printf("This was touched %s.", humanize.Time(someTimeInstance)) // This was touched 7 hours ago.
|
||||||
|
```
|
||||||
|
|
||||||
|
Thanks to Kyle Lemons for the time implementation from an IRC
|
||||||
|
conversation one day. It's pretty neat.
|
||||||
|
|
||||||
|
## Ordinals
|
||||||
|
|
||||||
|
From a [mailing list discussion][odisc] where a user wanted to be able
|
||||||
|
to label ordinals.
|
||||||
|
|
||||||
|
0 -> 0th
|
||||||
|
1 -> 1st
|
||||||
|
2 -> 2nd
|
||||||
|
3 -> 3rd
|
||||||
|
4 -> 4th
|
||||||
|
[...]
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
fmt.Printf("You're my %s best friend.", humanize.Ordinal(193)) // You are my 193rd best friend.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Commas
|
||||||
|
|
||||||
|
Want to shove commas into numbers? Be my guest.
|
||||||
|
|
||||||
|
0 -> 0
|
||||||
|
100 -> 100
|
||||||
|
1000 -> 1,000
|
||||||
|
1000000000 -> 1,000,000,000
|
||||||
|
-100000 -> -100,000
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
fmt.Printf("You owe $%s.\n", humanize.Comma(6582491)) // You owe $6,582,491.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Ftoa
|
||||||
|
|
||||||
|
Nicer float64 formatter that removes trailing zeros.
|
||||||
|
|
||||||
|
```go
|
||||||
|
fmt.Printf("%f", 2.24) // 2.240000
|
||||||
|
fmt.Printf("%s", humanize.Ftoa(2.24)) // 2.24
|
||||||
|
fmt.Printf("%f", 2.0) // 2.000000
|
||||||
|
fmt.Printf("%s", humanize.Ftoa(2.0)) // 2
|
||||||
|
```
|
||||||
|
|
||||||
|
## SI notation
|
||||||
|
|
||||||
|
Format numbers with [SI notation][sinotation].
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
humanize.SI(0.00000000223, "M") // 2.23 nM
|
||||||
|
```
|
||||||
|
|
||||||
|
## English-specific functions
|
||||||
|
|
||||||
|
The following functions are in the `humanize/english` subpackage.
|
||||||
|
|
||||||
|
### Plurals
|
||||||
|
|
||||||
|
Simple English pluralization
|
||||||
|
|
||||||
|
```go
|
||||||
|
english.PluralWord(1, "object", "") // object
|
||||||
|
english.PluralWord(42, "object", "") // objects
|
||||||
|
english.PluralWord(2, "bus", "") // buses
|
||||||
|
english.PluralWord(99, "locus", "loci") // loci
|
||||||
|
|
||||||
|
english.Plural(1, "object", "") // 1 object
|
||||||
|
english.Plural(42, "object", "") // 42 objects
|
||||||
|
english.Plural(2, "bus", "") // 2 buses
|
||||||
|
english.Plural(99, "locus", "loci") // 99 loci
|
||||||
|
```
|
||||||
|
|
||||||
|
### Word series
|
||||||
|
|
||||||
|
Format comma-separated words lists with conjuctions:
|
||||||
|
|
||||||
|
```go
|
||||||
|
english.WordSeries([]string{"foo"}, "and") // foo
|
||||||
|
english.WordSeries([]string{"foo", "bar"}, "and") // foo and bar
|
||||||
|
english.WordSeries([]string{"foo", "bar", "baz"}, "and") // foo, bar and baz
|
||||||
|
|
||||||
|
english.OxfordWordSeries([]string{"foo", "bar", "baz"}, "and") // foo, bar, and baz
|
||||||
|
```
|
||||||
|
|
||||||
|
[odisc]: https://groups.google.com/d/topic/golang-nuts/l8NhI74jl-4/discussion
|
||||||
|
[sinotation]: http://en.wikipedia.org/wiki/Metric_prefix
|
|
@ -0,0 +1,31 @@
|
||||||
|
package humanize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/big"
|
||||||
|
)
|
||||||
|
|
||||||
|
// order of magnitude (to a max order)
|
||||||
|
func oomm(n, b *big.Int, maxmag int) (float64, int) {
|
||||||
|
mag := 0
|
||||||
|
m := &big.Int{}
|
||||||
|
for n.Cmp(b) >= 0 {
|
||||||
|
n.DivMod(n, b, m)
|
||||||
|
mag++
|
||||||
|
if mag == maxmag && maxmag >= 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return float64(n.Int64()) + (float64(m.Int64()) / float64(b.Int64())), mag
|
||||||
|
}
|
||||||
|
|
||||||
|
// total order of magnitude
|
||||||
|
// (same as above, but with no upper limit)
|
||||||
|
func oom(n, b *big.Int) (float64, int) {
|
||||||
|
mag := 0
|
||||||
|
m := &big.Int{}
|
||||||
|
for n.Cmp(b) >= 0 {
|
||||||
|
n.DivMod(n, b, m)
|
||||||
|
mag++
|
||||||
|
}
|
||||||
|
return float64(n.Int64()) + (float64(m.Int64()) / float64(b.Int64())), mag
|
||||||
|
}
|
|
@ -0,0 +1,189 @@
|
||||||
|
package humanize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
bigIECExp = big.NewInt(1024)
|
||||||
|
|
||||||
|
// BigByte is one byte in bit.Ints
|
||||||
|
BigByte = big.NewInt(1)
|
||||||
|
// BigKiByte is 1,024 bytes in bit.Ints
|
||||||
|
BigKiByte = (&big.Int{}).Mul(BigByte, bigIECExp)
|
||||||
|
// BigMiByte is 1,024 k bytes in bit.Ints
|
||||||
|
BigMiByte = (&big.Int{}).Mul(BigKiByte, bigIECExp)
|
||||||
|
// BigGiByte is 1,024 m bytes in bit.Ints
|
||||||
|
BigGiByte = (&big.Int{}).Mul(BigMiByte, bigIECExp)
|
||||||
|
// BigTiByte is 1,024 g bytes in bit.Ints
|
||||||
|
BigTiByte = (&big.Int{}).Mul(BigGiByte, bigIECExp)
|
||||||
|
// BigPiByte is 1,024 t bytes in bit.Ints
|
||||||
|
BigPiByte = (&big.Int{}).Mul(BigTiByte, bigIECExp)
|
||||||
|
// BigEiByte is 1,024 p bytes in bit.Ints
|
||||||
|
BigEiByte = (&big.Int{}).Mul(BigPiByte, bigIECExp)
|
||||||
|
// BigZiByte is 1,024 e bytes in bit.Ints
|
||||||
|
BigZiByte = (&big.Int{}).Mul(BigEiByte, bigIECExp)
|
||||||
|
// BigYiByte is 1,024 z bytes in bit.Ints
|
||||||
|
BigYiByte = (&big.Int{}).Mul(BigZiByte, bigIECExp)
|
||||||
|
// BigRiByte is 1,024 y bytes in bit.Ints
|
||||||
|
BigRiByte = (&big.Int{}).Mul(BigYiByte, bigIECExp)
|
||||||
|
// BigQiByte is 1,024 r bytes in bit.Ints
|
||||||
|
BigQiByte = (&big.Int{}).Mul(BigRiByte, bigIECExp)
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
bigSIExp = big.NewInt(1000)
|
||||||
|
|
||||||
|
// BigSIByte is one SI byte in big.Ints
|
||||||
|
BigSIByte = big.NewInt(1)
|
||||||
|
// BigKByte is 1,000 SI bytes in big.Ints
|
||||||
|
BigKByte = (&big.Int{}).Mul(BigSIByte, bigSIExp)
|
||||||
|
// BigMByte is 1,000 SI k bytes in big.Ints
|
||||||
|
BigMByte = (&big.Int{}).Mul(BigKByte, bigSIExp)
|
||||||
|
// BigGByte is 1,000 SI m bytes in big.Ints
|
||||||
|
BigGByte = (&big.Int{}).Mul(BigMByte, bigSIExp)
|
||||||
|
// BigTByte is 1,000 SI g bytes in big.Ints
|
||||||
|
BigTByte = (&big.Int{}).Mul(BigGByte, bigSIExp)
|
||||||
|
// BigPByte is 1,000 SI t bytes in big.Ints
|
||||||
|
BigPByte = (&big.Int{}).Mul(BigTByte, bigSIExp)
|
||||||
|
// BigEByte is 1,000 SI p bytes in big.Ints
|
||||||
|
BigEByte = (&big.Int{}).Mul(BigPByte, bigSIExp)
|
||||||
|
// BigZByte is 1,000 SI e bytes in big.Ints
|
||||||
|
BigZByte = (&big.Int{}).Mul(BigEByte, bigSIExp)
|
||||||
|
// BigYByte is 1,000 SI z bytes in big.Ints
|
||||||
|
BigYByte = (&big.Int{}).Mul(BigZByte, bigSIExp)
|
||||||
|
// BigRByte is 1,000 SI y bytes in big.Ints
|
||||||
|
BigRByte = (&big.Int{}).Mul(BigYByte, bigSIExp)
|
||||||
|
// BigQByte is 1,000 SI r bytes in big.Ints
|
||||||
|
BigQByte = (&big.Int{}).Mul(BigRByte, bigSIExp)
|
||||||
|
)
|
||||||
|
|
||||||
|
var bigBytesSizeTable = map[string]*big.Int{
|
||||||
|
"b": BigByte,
|
||||||
|
"kib": BigKiByte,
|
||||||
|
"kb": BigKByte,
|
||||||
|
"mib": BigMiByte,
|
||||||
|
"mb": BigMByte,
|
||||||
|
"gib": BigGiByte,
|
||||||
|
"gb": BigGByte,
|
||||||
|
"tib": BigTiByte,
|
||||||
|
"tb": BigTByte,
|
||||||
|
"pib": BigPiByte,
|
||||||
|
"pb": BigPByte,
|
||||||
|
"eib": BigEiByte,
|
||||||
|
"eb": BigEByte,
|
||||||
|
"zib": BigZiByte,
|
||||||
|
"zb": BigZByte,
|
||||||
|
"yib": BigYiByte,
|
||||||
|
"yb": BigYByte,
|
||||||
|
"rib": BigRiByte,
|
||||||
|
"rb": BigRByte,
|
||||||
|
"qib": BigQiByte,
|
||||||
|
"qb": BigQByte,
|
||||||
|
// Without suffix
|
||||||
|
"": BigByte,
|
||||||
|
"ki": BigKiByte,
|
||||||
|
"k": BigKByte,
|
||||||
|
"mi": BigMiByte,
|
||||||
|
"m": BigMByte,
|
||||||
|
"gi": BigGiByte,
|
||||||
|
"g": BigGByte,
|
||||||
|
"ti": BigTiByte,
|
||||||
|
"t": BigTByte,
|
||||||
|
"pi": BigPiByte,
|
||||||
|
"p": BigPByte,
|
||||||
|
"ei": BigEiByte,
|
||||||
|
"e": BigEByte,
|
||||||
|
"z": BigZByte,
|
||||||
|
"zi": BigZiByte,
|
||||||
|
"y": BigYByte,
|
||||||
|
"yi": BigYiByte,
|
||||||
|
"r": BigRByte,
|
||||||
|
"ri": BigRiByte,
|
||||||
|
"q": BigQByte,
|
||||||
|
"qi": BigQiByte,
|
||||||
|
}
|
||||||
|
|
||||||
|
var ten = big.NewInt(10)
|
||||||
|
|
||||||
|
func humanateBigBytes(s, base *big.Int, sizes []string) string {
|
||||||
|
if s.Cmp(ten) < 0 {
|
||||||
|
return fmt.Sprintf("%d B", s)
|
||||||
|
}
|
||||||
|
c := (&big.Int{}).Set(s)
|
||||||
|
val, mag := oomm(c, base, len(sizes)-1)
|
||||||
|
suffix := sizes[mag]
|
||||||
|
f := "%.0f %s"
|
||||||
|
if val < 10 {
|
||||||
|
f = "%.1f %s"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(f, val, suffix)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// BigBytes produces a human readable representation of an SI size.
|
||||||
|
//
|
||||||
|
// See also: ParseBigBytes.
|
||||||
|
//
|
||||||
|
// BigBytes(82854982) -> 83 MB
|
||||||
|
func BigBytes(s *big.Int) string {
|
||||||
|
sizes := []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB", "RB", "QB"}
|
||||||
|
return humanateBigBytes(s, bigSIExp, sizes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BigIBytes produces a human readable representation of an IEC size.
|
||||||
|
//
|
||||||
|
// See also: ParseBigBytes.
|
||||||
|
//
|
||||||
|
// BigIBytes(82854982) -> 79 MiB
|
||||||
|
func BigIBytes(s *big.Int) string {
|
||||||
|
sizes := []string{"B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB", "RiB", "QiB"}
|
||||||
|
return humanateBigBytes(s, bigIECExp, sizes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseBigBytes parses a string representation of bytes into the number
|
||||||
|
// of bytes it represents.
|
||||||
|
//
|
||||||
|
// See also: BigBytes, BigIBytes.
|
||||||
|
//
|
||||||
|
// ParseBigBytes("42 MB") -> 42000000, nil
|
||||||
|
// ParseBigBytes("42 mib") -> 44040192, nil
|
||||||
|
func ParseBigBytes(s string) (*big.Int, error) {
|
||||||
|
lastDigit := 0
|
||||||
|
hasComma := false
|
||||||
|
for _, r := range s {
|
||||||
|
if !(unicode.IsDigit(r) || r == '.' || r == ',') {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if r == ',' {
|
||||||
|
hasComma = true
|
||||||
|
}
|
||||||
|
lastDigit++
|
||||||
|
}
|
||||||
|
|
||||||
|
num := s[:lastDigit]
|
||||||
|
if hasComma {
|
||||||
|
num = strings.Replace(num, ",", "", -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
val := &big.Rat{}
|
||||||
|
_, err := fmt.Sscanf(num, "%f", val)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
extra := strings.ToLower(strings.TrimSpace(s[lastDigit:]))
|
||||||
|
if m, ok := bigBytesSizeTable[extra]; ok {
|
||||||
|
mv := (&big.Rat{}).SetInt(m)
|
||||||
|
val.Mul(val, mv)
|
||||||
|
rv := &big.Int{}
|
||||||
|
rv.Div(val.Num(), val.Denom())
|
||||||
|
return rv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unhandled size name: %v", extra)
|
||||||
|
}
|
|
@ -0,0 +1,143 @@
|
||||||
|
package humanize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IEC Sizes.
|
||||||
|
// kibis of bits
|
||||||
|
const (
|
||||||
|
Byte = 1 << (iota * 10)
|
||||||
|
KiByte
|
||||||
|
MiByte
|
||||||
|
GiByte
|
||||||
|
TiByte
|
||||||
|
PiByte
|
||||||
|
EiByte
|
||||||
|
)
|
||||||
|
|
||||||
|
// SI Sizes.
|
||||||
|
const (
|
||||||
|
IByte = 1
|
||||||
|
KByte = IByte * 1000
|
||||||
|
MByte = KByte * 1000
|
||||||
|
GByte = MByte * 1000
|
||||||
|
TByte = GByte * 1000
|
||||||
|
PByte = TByte * 1000
|
||||||
|
EByte = PByte * 1000
|
||||||
|
)
|
||||||
|
|
||||||
|
var bytesSizeTable = map[string]uint64{
|
||||||
|
"b": Byte,
|
||||||
|
"kib": KiByte,
|
||||||
|
"kb": KByte,
|
||||||
|
"mib": MiByte,
|
||||||
|
"mb": MByte,
|
||||||
|
"gib": GiByte,
|
||||||
|
"gb": GByte,
|
||||||
|
"tib": TiByte,
|
||||||
|
"tb": TByte,
|
||||||
|
"pib": PiByte,
|
||||||
|
"pb": PByte,
|
||||||
|
"eib": EiByte,
|
||||||
|
"eb": EByte,
|
||||||
|
// Without suffix
|
||||||
|
"": Byte,
|
||||||
|
"ki": KiByte,
|
||||||
|
"k": KByte,
|
||||||
|
"mi": MiByte,
|
||||||
|
"m": MByte,
|
||||||
|
"gi": GiByte,
|
||||||
|
"g": GByte,
|
||||||
|
"ti": TiByte,
|
||||||
|
"t": TByte,
|
||||||
|
"pi": PiByte,
|
||||||
|
"p": PByte,
|
||||||
|
"ei": EiByte,
|
||||||
|
"e": EByte,
|
||||||
|
}
|
||||||
|
|
||||||
|
func logn(n, b float64) float64 {
|
||||||
|
return math.Log(n) / math.Log(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func humanateBytes(s uint64, base float64, sizes []string) string {
|
||||||
|
if s < 10 {
|
||||||
|
return fmt.Sprintf("%d B", s)
|
||||||
|
}
|
||||||
|
e := math.Floor(logn(float64(s), base))
|
||||||
|
suffix := sizes[int(e)]
|
||||||
|
val := math.Floor(float64(s)/math.Pow(base, e)*10+0.5) / 10
|
||||||
|
f := "%.0f %s"
|
||||||
|
if val < 10 {
|
||||||
|
f = "%.1f %s"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(f, val, suffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes produces a human readable representation of an SI size.
|
||||||
|
//
|
||||||
|
// See also: ParseBytes.
|
||||||
|
//
|
||||||
|
// Bytes(82854982) -> 83 MB
|
||||||
|
func Bytes(s uint64) string {
|
||||||
|
sizes := []string{"B", "kB", "MB", "GB", "TB", "PB", "EB"}
|
||||||
|
return humanateBytes(s, 1000, sizes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IBytes produces a human readable representation of an IEC size.
|
||||||
|
//
|
||||||
|
// See also: ParseBytes.
|
||||||
|
//
|
||||||
|
// IBytes(82854982) -> 79 MiB
|
||||||
|
func IBytes(s uint64) string {
|
||||||
|
sizes := []string{"B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB"}
|
||||||
|
return humanateBytes(s, 1024, sizes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseBytes parses a string representation of bytes into the number
|
||||||
|
// of bytes it represents.
|
||||||
|
//
|
||||||
|
// See Also: Bytes, IBytes.
|
||||||
|
//
|
||||||
|
// ParseBytes("42 MB") -> 42000000, nil
|
||||||
|
// ParseBytes("42 mib") -> 44040192, nil
|
||||||
|
func ParseBytes(s string) (uint64, error) {
|
||||||
|
lastDigit := 0
|
||||||
|
hasComma := false
|
||||||
|
for _, r := range s {
|
||||||
|
if !(unicode.IsDigit(r) || r == '.' || r == ',') {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if r == ',' {
|
||||||
|
hasComma = true
|
||||||
|
}
|
||||||
|
lastDigit++
|
||||||
|
}
|
||||||
|
|
||||||
|
num := s[:lastDigit]
|
||||||
|
if hasComma {
|
||||||
|
num = strings.Replace(num, ",", "", -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := strconv.ParseFloat(num, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
extra := strings.ToLower(strings.TrimSpace(s[lastDigit:]))
|
||||||
|
if m, ok := bytesSizeTable[extra]; ok {
|
||||||
|
f *= float64(m)
|
||||||
|
if f >= math.MaxUint64 {
|
||||||
|
return 0, fmt.Errorf("too large: %v", s)
|
||||||
|
}
|
||||||
|
return uint64(f), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, fmt.Errorf("unhandled size name: %v", extra)
|
||||||
|
}
|
|
@ -0,0 +1,116 @@
|
||||||
|
package humanize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"math"
|
||||||
|
"math/big"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Comma produces a string form of the given number in base 10 with
|
||||||
|
// commas after every three orders of magnitude.
|
||||||
|
//
|
||||||
|
// e.g. Comma(834142) -> 834,142
|
||||||
|
func Comma(v int64) string {
|
||||||
|
sign := ""
|
||||||
|
|
||||||
|
// Min int64 can't be negated to a usable value, so it has to be special cased.
|
||||||
|
if v == math.MinInt64 {
|
||||||
|
return "-9,223,372,036,854,775,808"
|
||||||
|
}
|
||||||
|
|
||||||
|
if v < 0 {
|
||||||
|
sign = "-"
|
||||||
|
v = 0 - v
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := []string{"", "", "", "", "", "", ""}
|
||||||
|
j := len(parts) - 1
|
||||||
|
|
||||||
|
for v > 999 {
|
||||||
|
parts[j] = strconv.FormatInt(v%1000, 10)
|
||||||
|
switch len(parts[j]) {
|
||||||
|
case 2:
|
||||||
|
parts[j] = "0" + parts[j]
|
||||||
|
case 1:
|
||||||
|
parts[j] = "00" + parts[j]
|
||||||
|
}
|
||||||
|
v = v / 1000
|
||||||
|
j--
|
||||||
|
}
|
||||||
|
parts[j] = strconv.Itoa(int(v))
|
||||||
|
return sign + strings.Join(parts[j:], ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commaf produces a string form of the given number in base 10 with
|
||||||
|
// commas after every three orders of magnitude.
|
||||||
|
//
|
||||||
|
// e.g. Commaf(834142.32) -> 834,142.32
|
||||||
|
func Commaf(v float64) string {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
if v < 0 {
|
||||||
|
buf.Write([]byte{'-'})
|
||||||
|
v = 0 - v
|
||||||
|
}
|
||||||
|
|
||||||
|
comma := []byte{','}
|
||||||
|
|
||||||
|
parts := strings.Split(strconv.FormatFloat(v, 'f', -1, 64), ".")
|
||||||
|
pos := 0
|
||||||
|
if len(parts[0])%3 != 0 {
|
||||||
|
pos += len(parts[0]) % 3
|
||||||
|
buf.WriteString(parts[0][:pos])
|
||||||
|
buf.Write(comma)
|
||||||
|
}
|
||||||
|
for ; pos < len(parts[0]); pos += 3 {
|
||||||
|
buf.WriteString(parts[0][pos : pos+3])
|
||||||
|
buf.Write(comma)
|
||||||
|
}
|
||||||
|
buf.Truncate(buf.Len() - 1)
|
||||||
|
|
||||||
|
if len(parts) > 1 {
|
||||||
|
buf.Write([]byte{'.'})
|
||||||
|
buf.WriteString(parts[1])
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CommafWithDigits works like the Commaf but limits the resulting
|
||||||
|
// string to the given number of decimal places.
|
||||||
|
//
|
||||||
|
// e.g. CommafWithDigits(834142.32, 1) -> 834,142.3
|
||||||
|
func CommafWithDigits(f float64, decimals int) string {
|
||||||
|
return stripTrailingDigits(Commaf(f), decimals)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BigComma produces a string form of the given big.Int in base 10
|
||||||
|
// with commas after every three orders of magnitude.
|
||||||
|
func BigComma(b *big.Int) string {
|
||||||
|
sign := ""
|
||||||
|
if b.Sign() < 0 {
|
||||||
|
sign = "-"
|
||||||
|
b.Abs(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
athousand := big.NewInt(1000)
|
||||||
|
c := (&big.Int{}).Set(b)
|
||||||
|
_, m := oom(c, athousand)
|
||||||
|
parts := make([]string, m+1)
|
||||||
|
j := len(parts) - 1
|
||||||
|
|
||||||
|
mod := &big.Int{}
|
||||||
|
for b.Cmp(athousand) >= 0 {
|
||||||
|
b.DivMod(b, athousand, mod)
|
||||||
|
parts[j] = strconv.FormatInt(mod.Int64(), 10)
|
||||||
|
switch len(parts[j]) {
|
||||||
|
case 2:
|
||||||
|
parts[j] = "0" + parts[j]
|
||||||
|
case 1:
|
||||||
|
parts[j] = "00" + parts[j]
|
||||||
|
}
|
||||||
|
j--
|
||||||
|
}
|
||||||
|
parts[j] = strconv.Itoa(int(b.Int64()))
|
||||||
|
return sign + strings.Join(parts[j:], ",")
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
//go:build go1.6
|
||||||
|
// +build go1.6
|
||||||
|
|
||||||
|
package humanize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"math/big"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BigCommaf produces a string form of the given big.Float in base 10
|
||||||
|
// with commas after every three orders of magnitude.
|
||||||
|
func BigCommaf(v *big.Float) string {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
if v.Sign() < 0 {
|
||||||
|
buf.Write([]byte{'-'})
|
||||||
|
v.Abs(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
comma := []byte{','}
|
||||||
|
|
||||||
|
parts := strings.Split(v.Text('f', -1), ".")
|
||||||
|
pos := 0
|
||||||
|
if len(parts[0])%3 != 0 {
|
||||||
|
pos += len(parts[0]) % 3
|
||||||
|
buf.WriteString(parts[0][:pos])
|
||||||
|
buf.Write(comma)
|
||||||
|
}
|
||||||
|
for ; pos < len(parts[0]); pos += 3 {
|
||||||
|
buf.WriteString(parts[0][pos : pos+3])
|
||||||
|
buf.Write(comma)
|
||||||
|
}
|
||||||
|
buf.Truncate(buf.Len() - 1)
|
||||||
|
|
||||||
|
if len(parts) > 1 {
|
||||||
|
buf.Write([]byte{'.'})
|
||||||
|
buf.WriteString(parts[1])
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
|
@ -0,0 +1,49 @@
|
||||||
|
package humanize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func stripTrailingZeros(s string) string {
|
||||||
|
if !strings.ContainsRune(s, '.') {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
offset := len(s) - 1
|
||||||
|
for offset > 0 {
|
||||||
|
if s[offset] == '.' {
|
||||||
|
offset--
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if s[offset] != '0' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
offset--
|
||||||
|
}
|
||||||
|
return s[:offset+1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripTrailingDigits(s string, digits int) string {
|
||||||
|
if i := strings.Index(s, "."); i >= 0 {
|
||||||
|
if digits <= 0 {
|
||||||
|
return s[:i]
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
if i+digits >= len(s) {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:i+digits]
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ftoa converts a float to a string with no trailing zeros.
|
||||||
|
func Ftoa(num float64) string {
|
||||||
|
return stripTrailingZeros(strconv.FormatFloat(num, 'f', 6, 64))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FtoaWithDigits converts a float to a string but limits the resulting string
|
||||||
|
// to the given number of decimal places, and no trailing zeros.
|
||||||
|
func FtoaWithDigits(num float64, digits int) string {
|
||||||
|
return stripTrailingZeros(stripTrailingDigits(strconv.FormatFloat(num, 'f', 6, 64), digits))
|
||||||
|
}
|
|
@ -0,0 +1,8 @@
|
||||||
|
/*
|
||||||
|
Package humanize converts boring ugly numbers to human-friendly strings and back.
|
||||||
|
|
||||||
|
Durations can be turned into strings such as "3 days ago", numbers
|
||||||
|
representing sizes like 82854982 into useful strings like, "83 MB" or
|
||||||
|
"79 MiB" (whichever you prefer).
|
||||||
|
*/
|
||||||
|
package humanize
|
|
@ -0,0 +1,192 @@
|
||||||
|
package humanize
|
||||||
|
|
||||||
|
/*
|
||||||
|
Slightly adapted from the source to fit go-humanize.
|
||||||
|
|
||||||
|
Author: https://github.com/gorhill
|
||||||
|
Source: https://gist.github.com/gorhill/5285193
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
renderFloatPrecisionMultipliers = [...]float64{
|
||||||
|
1,
|
||||||
|
10,
|
||||||
|
100,
|
||||||
|
1000,
|
||||||
|
10000,
|
||||||
|
100000,
|
||||||
|
1000000,
|
||||||
|
10000000,
|
||||||
|
100000000,
|
||||||
|
1000000000,
|
||||||
|
}
|
||||||
|
|
||||||
|
renderFloatPrecisionRounders = [...]float64{
|
||||||
|
0.5,
|
||||||
|
0.05,
|
||||||
|
0.005,
|
||||||
|
0.0005,
|
||||||
|
0.00005,
|
||||||
|
0.000005,
|
||||||
|
0.0000005,
|
||||||
|
0.00000005,
|
||||||
|
0.000000005,
|
||||||
|
0.0000000005,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// FormatFloat produces a formatted number as string based on the following user-specified criteria:
|
||||||
|
// * thousands separator
|
||||||
|
// * decimal separator
|
||||||
|
// * decimal precision
|
||||||
|
//
|
||||||
|
// Usage: s := RenderFloat(format, n)
|
||||||
|
// The format parameter tells how to render the number n.
|
||||||
|
//
|
||||||
|
// See examples: http://play.golang.org/p/LXc1Ddm1lJ
|
||||||
|
//
|
||||||
|
// Examples of format strings, given n = 12345.6789:
|
||||||
|
// "#,###.##" => "12,345.67"
|
||||||
|
// "#,###." => "12,345"
|
||||||
|
// "#,###" => "12345,678"
|
||||||
|
// "#\u202F###,##" => "12 345,68"
|
||||||
|
// "#.###,###### => 12.345,678900
|
||||||
|
// "" (aka default format) => 12,345.67
|
||||||
|
//
|
||||||
|
// The highest precision allowed is 9 digits after the decimal symbol.
|
||||||
|
// There is also a version for integer number, FormatInteger(),
|
||||||
|
// which is convenient for calls within template.
|
||||||
|
func FormatFloat(format string, n float64) string {
|
||||||
|
// Special cases:
|
||||||
|
// NaN = "NaN"
|
||||||
|
// +Inf = "+Infinity"
|
||||||
|
// -Inf = "-Infinity"
|
||||||
|
if math.IsNaN(n) {
|
||||||
|
return "NaN"
|
||||||
|
}
|
||||||
|
if n > math.MaxFloat64 {
|
||||||
|
return "Infinity"
|
||||||
|
}
|
||||||
|
if n < (0.0 - math.MaxFloat64) {
|
||||||
|
return "-Infinity"
|
||||||
|
}
|
||||||
|
|
||||||
|
// default format
|
||||||
|
precision := 2
|
||||||
|
decimalStr := "."
|
||||||
|
thousandStr := ","
|
||||||
|
positiveStr := ""
|
||||||
|
negativeStr := "-"
|
||||||
|
|
||||||
|
if len(format) > 0 {
|
||||||
|
format := []rune(format)
|
||||||
|
|
||||||
|
// If there is an explicit format directive,
|
||||||
|
// then default values are these:
|
||||||
|
precision = 9
|
||||||
|
thousandStr = ""
|
||||||
|
|
||||||
|
// collect indices of meaningful formatting directives
|
||||||
|
formatIndx := []int{}
|
||||||
|
for i, char := range format {
|
||||||
|
if char != '#' && char != '0' {
|
||||||
|
formatIndx = append(formatIndx, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(formatIndx) > 0 {
|
||||||
|
// Directive at index 0:
|
||||||
|
// Must be a '+'
|
||||||
|
// Raise an error if not the case
|
||||||
|
// index: 0123456789
|
||||||
|
// +0.000,000
|
||||||
|
// +000,000.0
|
||||||
|
// +0000.00
|
||||||
|
// +0000
|
||||||
|
if formatIndx[0] == 0 {
|
||||||
|
if format[formatIndx[0]] != '+' {
|
||||||
|
panic("RenderFloat(): invalid positive sign directive")
|
||||||
|
}
|
||||||
|
positiveStr = "+"
|
||||||
|
formatIndx = formatIndx[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Two directives:
|
||||||
|
// First is thousands separator
|
||||||
|
// Raise an error if not followed by 3-digit
|
||||||
|
// 0123456789
|
||||||
|
// 0.000,000
|
||||||
|
// 000,000.00
|
||||||
|
if len(formatIndx) == 2 {
|
||||||
|
if (formatIndx[1] - formatIndx[0]) != 4 {
|
||||||
|
panic("RenderFloat(): thousands separator directive must be followed by 3 digit-specifiers")
|
||||||
|
}
|
||||||
|
thousandStr = string(format[formatIndx[0]])
|
||||||
|
formatIndx = formatIndx[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// One directive:
|
||||||
|
// Directive is decimal separator
|
||||||
|
// The number of digit-specifier following the separator indicates wanted precision
|
||||||
|
// 0123456789
|
||||||
|
// 0.00
|
||||||
|
// 000,0000
|
||||||
|
if len(formatIndx) == 1 {
|
||||||
|
decimalStr = string(format[formatIndx[0]])
|
||||||
|
precision = len(format) - formatIndx[0] - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate sign part
|
||||||
|
var signStr string
|
||||||
|
if n >= 0.000000001 {
|
||||||
|
signStr = positiveStr
|
||||||
|
} else if n <= -0.000000001 {
|
||||||
|
signStr = negativeStr
|
||||||
|
n = -n
|
||||||
|
} else {
|
||||||
|
signStr = ""
|
||||||
|
n = 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// split number into integer and fractional parts
|
||||||
|
intf, fracf := math.Modf(n + renderFloatPrecisionRounders[precision])
|
||||||
|
|
||||||
|
// generate integer part string
|
||||||
|
intStr := strconv.FormatInt(int64(intf), 10)
|
||||||
|
|
||||||
|
// add thousand separator if required
|
||||||
|
if len(thousandStr) > 0 {
|
||||||
|
for i := len(intStr); i > 3; {
|
||||||
|
i -= 3
|
||||||
|
intStr = intStr[:i] + thousandStr + intStr[i:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// no fractional part, we can leave now
|
||||||
|
if precision == 0 {
|
||||||
|
return signStr + intStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate fractional part
|
||||||
|
fracStr := strconv.Itoa(int(fracf * renderFloatPrecisionMultipliers[precision]))
|
||||||
|
// may need padding
|
||||||
|
if len(fracStr) < precision {
|
||||||
|
fracStr = "000000000000000"[:precision-len(fracStr)] + fracStr
|
||||||
|
}
|
||||||
|
|
||||||
|
return signStr + intStr + decimalStr + fracStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatInteger produces a formatted number as string.
|
||||||
|
// See FormatFloat.
|
||||||
|
func FormatInteger(format string, n int) string {
|
||||||
|
return FormatFloat(format, float64(n))
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package humanize
|
||||||
|
|
||||||
|
import "strconv"
|
||||||
|
|
||||||
|
// Ordinal gives you the input number in a rank/ordinal format.
|
||||||
|
//
|
||||||
|
// Ordinal(3) -> 3rd
|
||||||
|
func Ordinal(x int) string {
|
||||||
|
suffix := "th"
|
||||||
|
switch x % 10 {
|
||||||
|
case 1:
|
||||||
|
if x%100 != 11 {
|
||||||
|
suffix = "st"
|
||||||
|
}
|
||||||
|
case 2:
|
||||||
|
if x%100 != 12 {
|
||||||
|
suffix = "nd"
|
||||||
|
}
|
||||||
|
case 3:
|
||||||
|
if x%100 != 13 {
|
||||||
|
suffix = "rd"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strconv.Itoa(x) + suffix
|
||||||
|
}
|
|
@ -0,0 +1,127 @@
|
||||||
|
package humanize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
var siPrefixTable = map[float64]string{
|
||||||
|
-30: "q", // quecto
|
||||||
|
-27: "r", // ronto
|
||||||
|
-24: "y", // yocto
|
||||||
|
-21: "z", // zepto
|
||||||
|
-18: "a", // atto
|
||||||
|
-15: "f", // femto
|
||||||
|
-12: "p", // pico
|
||||||
|
-9: "n", // nano
|
||||||
|
-6: "µ", // micro
|
||||||
|
-3: "m", // milli
|
||||||
|
0: "",
|
||||||
|
3: "k", // kilo
|
||||||
|
6: "M", // mega
|
||||||
|
9: "G", // giga
|
||||||
|
12: "T", // tera
|
||||||
|
15: "P", // peta
|
||||||
|
18: "E", // exa
|
||||||
|
21: "Z", // zetta
|
||||||
|
24: "Y", // yotta
|
||||||
|
27: "R", // ronna
|
||||||
|
30: "Q", // quetta
|
||||||
|
}
|
||||||
|
|
||||||
|
var revSIPrefixTable = revfmap(siPrefixTable)
|
||||||
|
|
||||||
|
// revfmap reverses the map and precomputes the power multiplier
|
||||||
|
func revfmap(in map[float64]string) map[string]float64 {
|
||||||
|
rv := map[string]float64{}
|
||||||
|
for k, v := range in {
|
||||||
|
rv[v] = math.Pow(10, k)
|
||||||
|
}
|
||||||
|
return rv
|
||||||
|
}
|
||||||
|
|
||||||
|
var riParseRegex *regexp.Regexp
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
ri := `^([\-0-9.]+)\s?([`
|
||||||
|
for _, v := range siPrefixTable {
|
||||||
|
ri += v
|
||||||
|
}
|
||||||
|
ri += `]?)(.*)`
|
||||||
|
|
||||||
|
riParseRegex = regexp.MustCompile(ri)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ComputeSI finds the most appropriate SI prefix for the given number
|
||||||
|
// and returns the prefix along with the value adjusted to be within
|
||||||
|
// that prefix.
|
||||||
|
//
|
||||||
|
// See also: SI, ParseSI.
|
||||||
|
//
|
||||||
|
// e.g. ComputeSI(2.2345e-12) -> (2.2345, "p")
|
||||||
|
func ComputeSI(input float64) (float64, string) {
|
||||||
|
if input == 0 {
|
||||||
|
return 0, ""
|
||||||
|
}
|
||||||
|
mag := math.Abs(input)
|
||||||
|
exponent := math.Floor(logn(mag, 10))
|
||||||
|
exponent = math.Floor(exponent/3) * 3
|
||||||
|
|
||||||
|
value := mag / math.Pow(10, exponent)
|
||||||
|
|
||||||
|
// Handle special case where value is exactly 1000.0
|
||||||
|
// Should return 1 M instead of 1000 k
|
||||||
|
if value == 1000.0 {
|
||||||
|
exponent += 3
|
||||||
|
value = mag / math.Pow(10, exponent)
|
||||||
|
}
|
||||||
|
|
||||||
|
value = math.Copysign(value, input)
|
||||||
|
|
||||||
|
prefix := siPrefixTable[exponent]
|
||||||
|
return value, prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
// SI returns a string with default formatting.
|
||||||
|
//
|
||||||
|
// SI uses Ftoa to format float value, removing trailing zeros.
|
||||||
|
//
|
||||||
|
// See also: ComputeSI, ParseSI.
|
||||||
|
//
|
||||||
|
// e.g. SI(1000000, "B") -> 1 MB
|
||||||
|
// e.g. SI(2.2345e-12, "F") -> 2.2345 pF
|
||||||
|
func SI(input float64, unit string) string {
|
||||||
|
value, prefix := ComputeSI(input)
|
||||||
|
return Ftoa(value) + " " + prefix + unit
|
||||||
|
}
|
||||||
|
|
||||||
|
// SIWithDigits works like SI but limits the resulting string to the
|
||||||
|
// given number of decimal places.
|
||||||
|
//
|
||||||
|
// e.g. SIWithDigits(1000000, 0, "B") -> 1 MB
|
||||||
|
// e.g. SIWithDigits(2.2345e-12, 2, "F") -> 2.23 pF
|
||||||
|
func SIWithDigits(input float64, decimals int, unit string) string {
|
||||||
|
value, prefix := ComputeSI(input)
|
||||||
|
return FtoaWithDigits(value, decimals) + " " + prefix + unit
|
||||||
|
}
|
||||||
|
|
||||||
|
var errInvalid = errors.New("invalid input")
|
||||||
|
|
||||||
|
// ParseSI parses an SI string back into the number and unit.
|
||||||
|
//
|
||||||
|
// See also: SI, ComputeSI.
|
||||||
|
//
|
||||||
|
// e.g. ParseSI("2.2345 pF") -> (2.2345e-12, "F", nil)
|
||||||
|
func ParseSI(input string) (float64, string, error) {
|
||||||
|
found := riParseRegex.FindStringSubmatch(input)
|
||||||
|
if len(found) != 4 {
|
||||||
|
return 0, "", errInvalid
|
||||||
|
}
|
||||||
|
mag := revSIPrefixTable[found[2]]
|
||||||
|
unit := found[3]
|
||||||
|
|
||||||
|
base, err := strconv.ParseFloat(found[1], 64)
|
||||||
|
return base * mag, unit, err
|
||||||
|
}
|
|
@ -0,0 +1,117 @@
|
||||||
|
package humanize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"sort"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Seconds-based time units
|
||||||
|
const (
|
||||||
|
Day = 24 * time.Hour
|
||||||
|
Week = 7 * Day
|
||||||
|
Month = 30 * Day
|
||||||
|
Year = 12 * Month
|
||||||
|
LongTime = 37 * Year
|
||||||
|
)
|
||||||
|
|
||||||
|
// Time formats a time into a relative string.
|
||||||
|
//
|
||||||
|
// Time(someT) -> "3 weeks ago"
|
||||||
|
func Time(then time.Time) string {
|
||||||
|
return RelTime(then, time.Now(), "ago", "from now")
|
||||||
|
}
|
||||||
|
|
||||||
|
// A RelTimeMagnitude struct contains a relative time point at which
|
||||||
|
// the relative format of time will switch to a new format string. A
|
||||||
|
// slice of these in ascending order by their "D" field is passed to
|
||||||
|
// CustomRelTime to format durations.
|
||||||
|
//
|
||||||
|
// The Format field is a string that may contain a "%s" which will be
|
||||||
|
// replaced with the appropriate signed label (e.g. "ago" or "from
|
||||||
|
// now") and a "%d" that will be replaced by the quantity.
|
||||||
|
//
|
||||||
|
// The DivBy field is the amount of time the time difference must be
|
||||||
|
// divided by in order to display correctly.
|
||||||
|
//
|
||||||
|
// e.g. if D is 2*time.Minute and you want to display "%d minutes %s"
|
||||||
|
// DivBy should be time.Minute so whatever the duration is will be
|
||||||
|
// expressed in minutes.
|
||||||
|
type RelTimeMagnitude struct {
|
||||||
|
D time.Duration
|
||||||
|
Format string
|
||||||
|
DivBy time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultMagnitudes = []RelTimeMagnitude{
|
||||||
|
{time.Second, "now", time.Second},
|
||||||
|
{2 * time.Second, "1 second %s", 1},
|
||||||
|
{time.Minute, "%d seconds %s", time.Second},
|
||||||
|
{2 * time.Minute, "1 minute %s", 1},
|
||||||
|
{time.Hour, "%d minutes %s", time.Minute},
|
||||||
|
{2 * time.Hour, "1 hour %s", 1},
|
||||||
|
{Day, "%d hours %s", time.Hour},
|
||||||
|
{2 * Day, "1 day %s", 1},
|
||||||
|
{Week, "%d days %s", Day},
|
||||||
|
{2 * Week, "1 week %s", 1},
|
||||||
|
{Month, "%d weeks %s", Week},
|
||||||
|
{2 * Month, "1 month %s", 1},
|
||||||
|
{Year, "%d months %s", Month},
|
||||||
|
{18 * Month, "1 year %s", 1},
|
||||||
|
{2 * Year, "2 years %s", 1},
|
||||||
|
{LongTime, "%d years %s", Year},
|
||||||
|
{math.MaxInt64, "a long while %s", 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
// RelTime formats a time into a relative string.
|
||||||
|
//
|
||||||
|
// It takes two times and two labels. In addition to the generic time
|
||||||
|
// delta string (e.g. 5 minutes), the labels are used applied so that
|
||||||
|
// the label corresponding to the smaller time is applied.
|
||||||
|
//
|
||||||
|
// RelTime(timeInPast, timeInFuture, "earlier", "later") -> "3 weeks earlier"
|
||||||
|
func RelTime(a, b time.Time, albl, blbl string) string {
|
||||||
|
return CustomRelTime(a, b, albl, blbl, defaultMagnitudes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CustomRelTime formats a time into a relative string.
|
||||||
|
//
|
||||||
|
// It takes two times two labels and a table of relative time formats.
|
||||||
|
// In addition to the generic time delta string (e.g. 5 minutes), the
|
||||||
|
// labels are used applied so that the label corresponding to the
|
||||||
|
// smaller time is applied.
|
||||||
|
func CustomRelTime(a, b time.Time, albl, blbl string, magnitudes []RelTimeMagnitude) string {
|
||||||
|
lbl := albl
|
||||||
|
diff := b.Sub(a)
|
||||||
|
|
||||||
|
if a.After(b) {
|
||||||
|
lbl = blbl
|
||||||
|
diff = a.Sub(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
n := sort.Search(len(magnitudes), func(i int) bool {
|
||||||
|
return magnitudes[i].D > diff
|
||||||
|
})
|
||||||
|
|
||||||
|
if n >= len(magnitudes) {
|
||||||
|
n = len(magnitudes) - 1
|
||||||
|
}
|
||||||
|
mag := magnitudes[n]
|
||||||
|
args := []interface{}{}
|
||||||
|
escaped := false
|
||||||
|
for _, ch := range mag.Format {
|
||||||
|
if escaped {
|
||||||
|
switch ch {
|
||||||
|
case 's':
|
||||||
|
args = append(args, lbl)
|
||||||
|
case 'd':
|
||||||
|
args = append(args, diff/mag.DivBy)
|
||||||
|
}
|
||||||
|
escaped = false
|
||||||
|
} else {
|
||||||
|
escaped = ch == '%'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(mag.Format, args...)
|
||||||
|
}
|
|
@ -0,0 +1 @@
|
||||||
|
.vscode/
|
|
@ -0,0 +1,82 @@
|
||||||
|
# Changelog
|
||||||
|
|
||||||
|
All notable changes to this project will be documented in this file.
|
||||||
|
|
||||||
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||||
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
## [0.6.0] - 2023-01-30
|
||||||
|
|
||||||
|
[0.6.0]: https://github.com/go-logfmt/logfmt/compare/v0.5.1...v0.6.0
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- NewDecoderSize by [@alexanderjophus]
|
||||||
|
|
||||||
|
## [0.5.1] - 2021-08-18
|
||||||
|
|
||||||
|
[0.5.1]: https://github.com/go-logfmt/logfmt/compare/v0.5.0...v0.5.1
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Update the `go.mod` file for Go 1.17 as described in the [Go 1.17 release
|
||||||
|
notes](https://golang.org/doc/go1.17#go-command)
|
||||||
|
|
||||||
|
## [0.5.0] - 2020-01-03
|
||||||
|
|
||||||
|
[0.5.0]: https://github.com/go-logfmt/logfmt/compare/v0.4.0...v0.5.0
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Remove the dependency on github.com/kr/logfmt by [@ChrisHines]
|
||||||
|
- Move fuzz code to github.com/go-logfmt/fuzzlogfmt by [@ChrisHines]
|
||||||
|
|
||||||
|
## [0.4.0] - 2018-11-21
|
||||||
|
|
||||||
|
[0.4.0]: https://github.com/go-logfmt/logfmt/compare/v0.3.0...v0.4.0
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- Go module support by [@ChrisHines]
|
||||||
|
- CHANGELOG by [@ChrisHines]
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Drop invalid runes from keys instead of returning ErrInvalidKey by [@ChrisHines]
|
||||||
|
- On panic while printing, attempt to print panic value by [@bboreham]
|
||||||
|
|
||||||
|
## [0.3.0] - 2016-11-15
|
||||||
|
|
||||||
|
[0.3.0]: https://github.com/go-logfmt/logfmt/compare/v0.2.0...v0.3.0
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- Pool buffers for quoted strings and byte slices by [@nussjustin]
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Fuzz fix, quote invalid UTF-8 values by [@judwhite]
|
||||||
|
|
||||||
|
## [0.2.0] - 2016-05-08
|
||||||
|
|
||||||
|
[0.2.0]: https://github.com/go-logfmt/logfmt/compare/v0.1.0...v0.2.0
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- Encoder.EncodeKeyvals by [@ChrisHines]
|
||||||
|
|
||||||
|
## [0.1.0] - 2016-03-28
|
||||||
|
|
||||||
|
[0.1.0]: https://github.com/go-logfmt/logfmt/commits/v0.1.0
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- Encoder by [@ChrisHines]
|
||||||
|
- Decoder by [@ChrisHines]
|
||||||
|
- MarshalKeyvals by [@ChrisHines]
|
||||||
|
|
||||||
|
[@ChrisHines]: https://github.com/ChrisHines
|
||||||
|
[@bboreham]: https://github.com/bboreham
|
||||||
|
[@judwhite]: https://github.com/judwhite
|
||||||
|
[@nussjustin]: https://github.com/nussjustin
|
||||||
|
[@alexanderjophus]: https://github.com/alexanderjophus
|
|
@ -0,0 +1,22 @@
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2015 go-logfmt
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
# logfmt
|
||||||
|
|
||||||
|
[![Go Reference](https://pkg.go.dev/badge/github.com/go-logfmt/logfmt.svg)](https://pkg.go.dev/github.com/go-logfmt/logfmt)
|
||||||
|
[![Go Report Card](https://goreportcard.com/badge/go-logfmt/logfmt)](https://goreportcard.com/report/go-logfmt/logfmt)
|
||||||
|
[![Github Actions](https://github.com/go-logfmt/logfmt/actions/workflows/test.yml/badge.svg)](https://github.com/go-logfmt/logfmt/actions/workflows/test.yml)
|
||||||
|
[![Coverage Status](https://coveralls.io/repos/github/go-logfmt/logfmt/badge.svg?branch=master)](https://coveralls.io/github/go-logfmt/logfmt?branch=main)
|
||||||
|
|
||||||
|
Package logfmt implements utilities to marshal and unmarshal data in the [logfmt
|
||||||
|
format][fmt]. It provides an API similar to [encoding/json][json] and
|
||||||
|
[encoding/xml][xml].
|
||||||
|
|
||||||
|
[fmt]: https://brandur.org/logfmt
|
||||||
|
[json]: https://pkg.go.dev/encoding/json
|
||||||
|
[xml]: https://pkg.go.dev/encoding/xml
|
||||||
|
|
||||||
|
The logfmt format was first documented by Brandur Leach in [this
|
||||||
|
article][origin]. The format has not been formally standardized. The most
|
||||||
|
authoritative public specification to date has been the documentation of a Go
|
||||||
|
Language [package][parser] written by Blake Mizerany and Keith Rarick.
|
||||||
|
|
||||||
|
[origin]: https://brandur.org/logfmt
|
||||||
|
[parser]: https://pkg.go.dev/github.com/kr/logfmt
|
||||||
|
|
||||||
|
## Goals
|
||||||
|
|
||||||
|
This project attempts to conform as closely as possible to the prior art, while
|
||||||
|
also removing ambiguity where necessary to provide well behaved encoder and
|
||||||
|
decoder implementations.
|
||||||
|
|
||||||
|
## Non-goals
|
||||||
|
|
||||||
|
This project does not attempt to formally standardize the logfmt format. In the
|
||||||
|
event that logfmt is standardized this project would take conforming to the
|
||||||
|
standard as a goal.
|
||||||
|
|
||||||
|
## Versioning
|
||||||
|
|
||||||
|
This project publishes releases according to the Go language guidelines for
|
||||||
|
[developing and publishing modules][pub].
|
||||||
|
|
||||||
|
[pub]: https://go.dev/doc/modules/developing
|
|
@ -0,0 +1,254 @@
|
||||||
|
package logfmt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Decoder reads and decodes logfmt records from an input stream.
|
||||||
|
type Decoder struct {
|
||||||
|
pos int
|
||||||
|
key []byte
|
||||||
|
value []byte
|
||||||
|
lineNum int
|
||||||
|
s *bufio.Scanner
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDecoder returns a new decoder that reads from r.
|
||||||
|
//
|
||||||
|
// The decoder introduces its own buffering and may read data from r beyond
|
||||||
|
// the logfmt records requested.
|
||||||
|
func NewDecoder(r io.Reader) *Decoder {
|
||||||
|
dec := &Decoder{
|
||||||
|
s: bufio.NewScanner(r),
|
||||||
|
}
|
||||||
|
return dec
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDecoderSize returns a new decoder that reads from r.
|
||||||
|
//
|
||||||
|
// The decoder introduces its own buffering and may read data from r beyond
|
||||||
|
// the logfmt records requested.
|
||||||
|
// The size argument specifies the size of the initial buffer that the
|
||||||
|
// Decoder will use to read records from r.
|
||||||
|
// If a log line is longer than the size argument, the Decoder will return
|
||||||
|
// a bufio.ErrTooLong error.
|
||||||
|
func NewDecoderSize(r io.Reader, size int) *Decoder {
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
scanner.Buffer(make([]byte, 0, size), size)
|
||||||
|
dec := &Decoder{
|
||||||
|
s: scanner,
|
||||||
|
}
|
||||||
|
return dec
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanRecord advances the Decoder to the next record, which can then be
|
||||||
|
// parsed with the ScanKeyval method. It returns false when decoding stops,
|
||||||
|
// either by reaching the end of the input or an error. After ScanRecord
|
||||||
|
// returns false, the Err method will return any error that occurred during
|
||||||
|
// decoding, except that if it was io.EOF, Err will return nil.
|
||||||
|
func (dec *Decoder) ScanRecord() bool {
|
||||||
|
if dec.err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !dec.s.Scan() {
|
||||||
|
dec.err = dec.s.Err()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
dec.lineNum++
|
||||||
|
dec.pos = 0
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanKeyval advances the Decoder to the next key/value pair of the current
|
||||||
|
// record, which can then be retrieved with the Key and Value methods. It
|
||||||
|
// returns false when decoding stops, either by reaching the end of the
|
||||||
|
// current record or an error.
|
||||||
|
func (dec *Decoder) ScanKeyval() bool {
|
||||||
|
dec.key, dec.value = nil, nil
|
||||||
|
if dec.err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
line := dec.s.Bytes()
|
||||||
|
|
||||||
|
// garbage
|
||||||
|
for p, c := range line[dec.pos:] {
|
||||||
|
if c > ' ' {
|
||||||
|
dec.pos += p
|
||||||
|
goto key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dec.pos = len(line)
|
||||||
|
return false
|
||||||
|
|
||||||
|
key:
|
||||||
|
const invalidKeyError = "invalid key"
|
||||||
|
|
||||||
|
start, multibyte := dec.pos, false
|
||||||
|
for p, c := range line[dec.pos:] {
|
||||||
|
switch {
|
||||||
|
case c == '=':
|
||||||
|
dec.pos += p
|
||||||
|
if dec.pos > start {
|
||||||
|
dec.key = line[start:dec.pos]
|
||||||
|
if multibyte && bytes.ContainsRune(dec.key, utf8.RuneError) {
|
||||||
|
dec.syntaxError(invalidKeyError)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if dec.key == nil {
|
||||||
|
dec.unexpectedByte(c)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
goto equal
|
||||||
|
case c == '"':
|
||||||
|
dec.pos += p
|
||||||
|
dec.unexpectedByte(c)
|
||||||
|
return false
|
||||||
|
case c <= ' ':
|
||||||
|
dec.pos += p
|
||||||
|
if dec.pos > start {
|
||||||
|
dec.key = line[start:dec.pos]
|
||||||
|
if multibyte && bytes.ContainsRune(dec.key, utf8.RuneError) {
|
||||||
|
dec.syntaxError(invalidKeyError)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
case c >= utf8.RuneSelf:
|
||||||
|
multibyte = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dec.pos = len(line)
|
||||||
|
if dec.pos > start {
|
||||||
|
dec.key = line[start:dec.pos]
|
||||||
|
if multibyte && bytes.ContainsRune(dec.key, utf8.RuneError) {
|
||||||
|
dec.syntaxError(invalidKeyError)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
|
||||||
|
equal:
|
||||||
|
dec.pos++
|
||||||
|
if dec.pos >= len(line) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch c := line[dec.pos]; {
|
||||||
|
case c <= ' ':
|
||||||
|
return true
|
||||||
|
case c == '"':
|
||||||
|
goto qvalue
|
||||||
|
}
|
||||||
|
|
||||||
|
// value
|
||||||
|
start = dec.pos
|
||||||
|
for p, c := range line[dec.pos:] {
|
||||||
|
switch {
|
||||||
|
case c == '=' || c == '"':
|
||||||
|
dec.pos += p
|
||||||
|
dec.unexpectedByte(c)
|
||||||
|
return false
|
||||||
|
case c <= ' ':
|
||||||
|
dec.pos += p
|
||||||
|
if dec.pos > start {
|
||||||
|
dec.value = line[start:dec.pos]
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dec.pos = len(line)
|
||||||
|
if dec.pos > start {
|
||||||
|
dec.value = line[start:dec.pos]
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
|
||||||
|
qvalue:
|
||||||
|
const (
|
||||||
|
untermQuote = "unterminated quoted value"
|
||||||
|
invalidQuote = "invalid quoted value"
|
||||||
|
)
|
||||||
|
|
||||||
|
hasEsc, esc := false, false
|
||||||
|
start = dec.pos
|
||||||
|
for p, c := range line[dec.pos+1:] {
|
||||||
|
switch {
|
||||||
|
case esc:
|
||||||
|
esc = false
|
||||||
|
case c == '\\':
|
||||||
|
hasEsc, esc = true, true
|
||||||
|
case c == '"':
|
||||||
|
dec.pos += p + 2
|
||||||
|
if hasEsc {
|
||||||
|
v, ok := unquoteBytes(line[start:dec.pos])
|
||||||
|
if !ok {
|
||||||
|
dec.syntaxError(invalidQuote)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
dec.value = v
|
||||||
|
} else {
|
||||||
|
start++
|
||||||
|
end := dec.pos - 1
|
||||||
|
if end > start {
|
||||||
|
dec.value = line[start:end]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dec.pos = len(line)
|
||||||
|
dec.syntaxError(untermQuote)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Key returns the most recent key found by a call to ScanKeyval. The returned
|
||||||
|
// slice may point to internal buffers and is only valid until the next call
|
||||||
|
// to ScanRecord. It does no allocation.
|
||||||
|
func (dec *Decoder) Key() []byte {
|
||||||
|
return dec.key
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value returns the most recent value found by a call to ScanKeyval. The
|
||||||
|
// returned slice may point to internal buffers and is only valid until the
|
||||||
|
// next call to ScanRecord. It does no allocation when the value has no
|
||||||
|
// escape sequences.
|
||||||
|
func (dec *Decoder) Value() []byte {
|
||||||
|
return dec.value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Err returns the first non-EOF error that was encountered by the Scanner.
|
||||||
|
func (dec *Decoder) Err() error {
|
||||||
|
return dec.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dec *Decoder) syntaxError(msg string) {
|
||||||
|
dec.err = &SyntaxError{
|
||||||
|
Msg: msg,
|
||||||
|
Line: dec.lineNum,
|
||||||
|
Pos: dec.pos + 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dec *Decoder) unexpectedByte(c byte) {
|
||||||
|
dec.err = &SyntaxError{
|
||||||
|
Msg: fmt.Sprintf("unexpected %q", c),
|
||||||
|
Line: dec.lineNum,
|
||||||
|
Pos: dec.pos + 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A SyntaxError represents a syntax error in the logfmt input stream.
|
||||||
|
type SyntaxError struct {
|
||||||
|
Msg string
|
||||||
|
Line int
|
||||||
|
Pos int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SyntaxError) Error() string {
|
||||||
|
return fmt.Sprintf("logfmt syntax error at pos %d on line %d: %s", e.Pos, e.Line, e.Msg)
|
||||||
|
}
|
|
@ -0,0 +1,6 @@
|
||||||
|
// Package logfmt implements utilities to marshal and unmarshal data in the
|
||||||
|
// logfmt format. The logfmt format records key/value pairs in a way that
|
||||||
|
// balances readability for humans and simplicity of computer parsing. It is
|
||||||
|
// most commonly used as a more human friendly alternative to JSON for
|
||||||
|
// structured logging.
|
||||||
|
package logfmt
|
|
@ -0,0 +1,322 @@
|
||||||
|
package logfmt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MarshalKeyvals returns the logfmt encoding of keyvals, a variadic sequence
|
||||||
|
// of alternating keys and values.
|
||||||
|
func MarshalKeyvals(keyvals ...interface{}) ([]byte, error) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
if err := NewEncoder(buf).EncodeKeyvals(keyvals...); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// An Encoder writes logfmt data to an output stream.
|
||||||
|
type Encoder struct {
|
||||||
|
w io.Writer
|
||||||
|
scratch bytes.Buffer
|
||||||
|
needSep bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEncoder returns a new encoder that writes to w.
|
||||||
|
func NewEncoder(w io.Writer) *Encoder {
|
||||||
|
return &Encoder{
|
||||||
|
w: w,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
space = []byte(" ")
|
||||||
|
equals = []byte("=")
|
||||||
|
newline = []byte("\n")
|
||||||
|
null = []byte("null")
|
||||||
|
)
|
||||||
|
|
||||||
|
// EncodeKeyval writes the logfmt encoding of key and value to the stream. A
|
||||||
|
// single space is written before the second and subsequent keys in a record.
|
||||||
|
// Nothing is written if a non-nil error is returned.
|
||||||
|
func (enc *Encoder) EncodeKeyval(key, value interface{}) error {
|
||||||
|
enc.scratch.Reset()
|
||||||
|
if enc.needSep {
|
||||||
|
if _, err := enc.scratch.Write(space); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := writeKey(&enc.scratch, key); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := enc.scratch.Write(equals); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := writeValue(&enc.scratch, value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err := enc.w.Write(enc.scratch.Bytes())
|
||||||
|
enc.needSep = true
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeKeyvals writes the logfmt encoding of keyvals to the stream. Keyvals
|
||||||
|
// is a variadic sequence of alternating keys and values. Keys of unsupported
|
||||||
|
// type are skipped along with their corresponding value. Values of
|
||||||
|
// unsupported type or that cause a MarshalerError are replaced by their error
|
||||||
|
// but do not cause EncodeKeyvals to return an error. If a non-nil error is
|
||||||
|
// returned some key/value pairs may not have be written.
|
||||||
|
func (enc *Encoder) EncodeKeyvals(keyvals ...interface{}) error {
|
||||||
|
if len(keyvals) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(keyvals)%2 == 1 {
|
||||||
|
keyvals = append(keyvals, nil)
|
||||||
|
}
|
||||||
|
for i := 0; i < len(keyvals); i += 2 {
|
||||||
|
k, v := keyvals[i], keyvals[i+1]
|
||||||
|
err := enc.EncodeKeyval(k, v)
|
||||||
|
if err == ErrUnsupportedKeyType {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := err.(*MarshalerError); ok || err == ErrUnsupportedValueType {
|
||||||
|
v = err
|
||||||
|
err = enc.EncodeKeyval(k, v)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalerError represents an error encountered while marshaling a value.
|
||||||
|
type MarshalerError struct {
|
||||||
|
Type reflect.Type
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *MarshalerError) Error() string {
|
||||||
|
return "error marshaling value of type " + e.Type.String() + ": " + e.Err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrNilKey is returned by Marshal functions and Encoder methods if a key is
|
||||||
|
// a nil interface or pointer value.
|
||||||
|
var ErrNilKey = errors.New("nil key")
|
||||||
|
|
||||||
|
// ErrInvalidKey is returned by Marshal functions and Encoder methods if, after
|
||||||
|
// dropping invalid runes, a key is empty.
|
||||||
|
var ErrInvalidKey = errors.New("invalid key")
|
||||||
|
|
||||||
|
// ErrUnsupportedKeyType is returned by Encoder methods if a key has an
|
||||||
|
// unsupported type.
|
||||||
|
var ErrUnsupportedKeyType = errors.New("unsupported key type")
|
||||||
|
|
||||||
|
// ErrUnsupportedValueType is returned by Encoder methods if a value has an
|
||||||
|
// unsupported type.
|
||||||
|
var ErrUnsupportedValueType = errors.New("unsupported value type")
|
||||||
|
|
||||||
|
func writeKey(w io.Writer, key interface{}) error {
|
||||||
|
if key == nil {
|
||||||
|
return ErrNilKey
|
||||||
|
}
|
||||||
|
|
||||||
|
switch k := key.(type) {
|
||||||
|
case string:
|
||||||
|
return writeStringKey(w, k)
|
||||||
|
case []byte:
|
||||||
|
if k == nil {
|
||||||
|
return ErrNilKey
|
||||||
|
}
|
||||||
|
return writeBytesKey(w, k)
|
||||||
|
case encoding.TextMarshaler:
|
||||||
|
kb, err := safeMarshal(k)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if kb == nil {
|
||||||
|
return ErrNilKey
|
||||||
|
}
|
||||||
|
return writeBytesKey(w, kb)
|
||||||
|
case fmt.Stringer:
|
||||||
|
ks, ok := safeString(k)
|
||||||
|
if !ok {
|
||||||
|
return ErrNilKey
|
||||||
|
}
|
||||||
|
return writeStringKey(w, ks)
|
||||||
|
default:
|
||||||
|
rkey := reflect.ValueOf(key)
|
||||||
|
switch rkey.Kind() {
|
||||||
|
case reflect.Array, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice, reflect.Struct:
|
||||||
|
return ErrUnsupportedKeyType
|
||||||
|
case reflect.Ptr:
|
||||||
|
if rkey.IsNil() {
|
||||||
|
return ErrNilKey
|
||||||
|
}
|
||||||
|
return writeKey(w, rkey.Elem().Interface())
|
||||||
|
}
|
||||||
|
return writeStringKey(w, fmt.Sprint(k))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// keyRuneFilter returns r for all valid key runes, and -1 for all invalid key
|
||||||
|
// runes. When used as the mapping function for strings.Map and bytes.Map
|
||||||
|
// functions it causes them to remove invalid key runes from strings or byte
|
||||||
|
// slices respectively.
|
||||||
|
func keyRuneFilter(r rune) rune {
|
||||||
|
if r <= ' ' || r == '=' || r == '"' || r == utf8.RuneError {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeStringKey(w io.Writer, key string) error {
|
||||||
|
k := strings.Map(keyRuneFilter, key)
|
||||||
|
if k == "" {
|
||||||
|
return ErrInvalidKey
|
||||||
|
}
|
||||||
|
_, err := io.WriteString(w, k)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeBytesKey(w io.Writer, key []byte) error {
|
||||||
|
k := bytes.Map(keyRuneFilter, key)
|
||||||
|
if len(k) == 0 {
|
||||||
|
return ErrInvalidKey
|
||||||
|
}
|
||||||
|
_, err := w.Write(k)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeValue(w io.Writer, value interface{}) error {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case nil:
|
||||||
|
return writeBytesValue(w, null)
|
||||||
|
case string:
|
||||||
|
return writeStringValue(w, v, true)
|
||||||
|
case []byte:
|
||||||
|
return writeBytesValue(w, v)
|
||||||
|
case encoding.TextMarshaler:
|
||||||
|
vb, err := safeMarshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if vb == nil {
|
||||||
|
vb = null
|
||||||
|
}
|
||||||
|
return writeBytesValue(w, vb)
|
||||||
|
case error:
|
||||||
|
se, ok := safeError(v)
|
||||||
|
return writeStringValue(w, se, ok)
|
||||||
|
case fmt.Stringer:
|
||||||
|
ss, ok := safeString(v)
|
||||||
|
return writeStringValue(w, ss, ok)
|
||||||
|
default:
|
||||||
|
rvalue := reflect.ValueOf(value)
|
||||||
|
switch rvalue.Kind() {
|
||||||
|
case reflect.Array, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice, reflect.Struct:
|
||||||
|
return ErrUnsupportedValueType
|
||||||
|
case reflect.Ptr:
|
||||||
|
if rvalue.IsNil() {
|
||||||
|
return writeBytesValue(w, null)
|
||||||
|
}
|
||||||
|
return writeValue(w, rvalue.Elem().Interface())
|
||||||
|
}
|
||||||
|
return writeStringValue(w, fmt.Sprint(v), true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func needsQuotedValueRune(r rune) bool {
|
||||||
|
return r <= ' ' || r == '=' || r == '"' || r == utf8.RuneError
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeStringValue(w io.Writer, value string, ok bool) error {
|
||||||
|
var err error
|
||||||
|
if ok && value == "null" {
|
||||||
|
_, err = io.WriteString(w, `"null"`)
|
||||||
|
} else if strings.IndexFunc(value, needsQuotedValueRune) != -1 {
|
||||||
|
_, err = writeQuotedString(w, value)
|
||||||
|
} else {
|
||||||
|
_, err = io.WriteString(w, value)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeBytesValue(w io.Writer, value []byte) error {
|
||||||
|
var err error
|
||||||
|
if bytes.IndexFunc(value, needsQuotedValueRune) != -1 {
|
||||||
|
_, err = writeQuotedBytes(w, value)
|
||||||
|
} else {
|
||||||
|
_, err = w.Write(value)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// EndRecord writes a newline character to the stream and resets the encoder
|
||||||
|
// to the beginning of a new record.
|
||||||
|
func (enc *Encoder) EndRecord() error {
|
||||||
|
_, err := enc.w.Write(newline)
|
||||||
|
if err == nil {
|
||||||
|
enc.needSep = false
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset resets the encoder to the beginning of a new record.
|
||||||
|
func (enc *Encoder) Reset() {
|
||||||
|
enc.needSep = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func safeError(err error) (s string, ok bool) {
|
||||||
|
defer func() {
|
||||||
|
if panicVal := recover(); panicVal != nil {
|
||||||
|
if v := reflect.ValueOf(err); v.Kind() == reflect.Ptr && v.IsNil() {
|
||||||
|
s, ok = "null", false
|
||||||
|
} else {
|
||||||
|
s, ok = fmt.Sprintf("PANIC:%v", panicVal), false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
s, ok = err.Error(), true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func safeString(str fmt.Stringer) (s string, ok bool) {
|
||||||
|
defer func() {
|
||||||
|
if panicVal := recover(); panicVal != nil {
|
||||||
|
if v := reflect.ValueOf(str); v.Kind() == reflect.Ptr && v.IsNil() {
|
||||||
|
s, ok = "null", false
|
||||||
|
} else {
|
||||||
|
s, ok = fmt.Sprintf("PANIC:%v", panicVal), true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
s, ok = str.String(), true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func safeMarshal(tm encoding.TextMarshaler) (b []byte, err error) {
|
||||||
|
defer func() {
|
||||||
|
if panicVal := recover(); panicVal != nil {
|
||||||
|
if v := reflect.ValueOf(tm); v.Kind() == reflect.Ptr && v.IsNil() {
|
||||||
|
b, err = nil, nil
|
||||||
|
} else {
|
||||||
|
b, err = nil, fmt.Errorf("panic when marshalling: %s", panicVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
b, err = tm.MarshalText()
|
||||||
|
if err != nil {
|
||||||
|
return nil, &MarshalerError{
|
||||||
|
Type: reflect.TypeOf(tm),
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,277 @@
|
||||||
|
package logfmt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"unicode"
|
||||||
|
"unicode/utf16"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Taken from Go's encoding/json and modified for use here.
|
||||||
|
|
||||||
|
// Copyright 2010 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
var hex = "0123456789abcdef"
|
||||||
|
|
||||||
|
var bufferPool = sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return &bytes.Buffer{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBuffer() *bytes.Buffer {
|
||||||
|
return bufferPool.Get().(*bytes.Buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func poolBuffer(buf *bytes.Buffer) {
|
||||||
|
buf.Reset()
|
||||||
|
bufferPool.Put(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: keep in sync with writeQuotedBytes below.
|
||||||
|
func writeQuotedString(w io.Writer, s string) (int, error) {
|
||||||
|
buf := getBuffer()
|
||||||
|
buf.WriteByte('"')
|
||||||
|
start := 0
|
||||||
|
for i := 0; i < len(s); {
|
||||||
|
if b := s[i]; b < utf8.RuneSelf {
|
||||||
|
if 0x20 <= b && b != '\\' && b != '"' {
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if start < i {
|
||||||
|
buf.WriteString(s[start:i])
|
||||||
|
}
|
||||||
|
switch b {
|
||||||
|
case '\\', '"':
|
||||||
|
buf.WriteByte('\\')
|
||||||
|
buf.WriteByte(b)
|
||||||
|
case '\n':
|
||||||
|
buf.WriteByte('\\')
|
||||||
|
buf.WriteByte('n')
|
||||||
|
case '\r':
|
||||||
|
buf.WriteByte('\\')
|
||||||
|
buf.WriteByte('r')
|
||||||
|
case '\t':
|
||||||
|
buf.WriteByte('\\')
|
||||||
|
buf.WriteByte('t')
|
||||||
|
default:
|
||||||
|
// This encodes bytes < 0x20 except for \n, \r, and \t.
|
||||||
|
buf.WriteString(`\u00`)
|
||||||
|
buf.WriteByte(hex[b>>4])
|
||||||
|
buf.WriteByte(hex[b&0xF])
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
start = i
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c, size := utf8.DecodeRuneInString(s[i:])
|
||||||
|
if c == utf8.RuneError {
|
||||||
|
if start < i {
|
||||||
|
buf.WriteString(s[start:i])
|
||||||
|
}
|
||||||
|
buf.WriteString(`\ufffd`)
|
||||||
|
i += size
|
||||||
|
start = i
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
i += size
|
||||||
|
}
|
||||||
|
if start < len(s) {
|
||||||
|
buf.WriteString(s[start:])
|
||||||
|
}
|
||||||
|
buf.WriteByte('"')
|
||||||
|
n, err := w.Write(buf.Bytes())
|
||||||
|
poolBuffer(buf)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: keep in sync with writeQuoteString above.
|
||||||
|
func writeQuotedBytes(w io.Writer, s []byte) (int, error) {
|
||||||
|
buf := getBuffer()
|
||||||
|
buf.WriteByte('"')
|
||||||
|
start := 0
|
||||||
|
for i := 0; i < len(s); {
|
||||||
|
if b := s[i]; b < utf8.RuneSelf {
|
||||||
|
if 0x20 <= b && b != '\\' && b != '"' {
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if start < i {
|
||||||
|
buf.Write(s[start:i])
|
||||||
|
}
|
||||||
|
switch b {
|
||||||
|
case '\\', '"':
|
||||||
|
buf.WriteByte('\\')
|
||||||
|
buf.WriteByte(b)
|
||||||
|
case '\n':
|
||||||
|
buf.WriteByte('\\')
|
||||||
|
buf.WriteByte('n')
|
||||||
|
case '\r':
|
||||||
|
buf.WriteByte('\\')
|
||||||
|
buf.WriteByte('r')
|
||||||
|
case '\t':
|
||||||
|
buf.WriteByte('\\')
|
||||||
|
buf.WriteByte('t')
|
||||||
|
default:
|
||||||
|
// This encodes bytes < 0x20 except for \n, \r, and \t.
|
||||||
|
buf.WriteString(`\u00`)
|
||||||
|
buf.WriteByte(hex[b>>4])
|
||||||
|
buf.WriteByte(hex[b&0xF])
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
start = i
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c, size := utf8.DecodeRune(s[i:])
|
||||||
|
if c == utf8.RuneError {
|
||||||
|
if start < i {
|
||||||
|
buf.Write(s[start:i])
|
||||||
|
}
|
||||||
|
buf.WriteString(`\ufffd`)
|
||||||
|
i += size
|
||||||
|
start = i
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
i += size
|
||||||
|
}
|
||||||
|
if start < len(s) {
|
||||||
|
buf.Write(s[start:])
|
||||||
|
}
|
||||||
|
buf.WriteByte('"')
|
||||||
|
n, err := w.Write(buf.Bytes())
|
||||||
|
poolBuffer(buf)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// getu4 decodes \uXXXX from the beginning of s, returning the hex value,
|
||||||
|
// or it returns -1.
|
||||||
|
func getu4(s []byte) rune {
|
||||||
|
if len(s) < 6 || s[0] != '\\' || s[1] != 'u' {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
r, err := strconv.ParseUint(string(s[2:6]), 16, 64)
|
||||||
|
if err != nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return rune(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func unquoteBytes(s []byte) (t []byte, ok bool) {
|
||||||
|
if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s = s[1 : len(s)-1]
|
||||||
|
|
||||||
|
// Check for unusual characters. If there are none,
|
||||||
|
// then no unquoting is needed, so return a slice of the
|
||||||
|
// original bytes.
|
||||||
|
r := 0
|
||||||
|
for r < len(s) {
|
||||||
|
c := s[r]
|
||||||
|
if c == '\\' || c == '"' || c < ' ' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if c < utf8.RuneSelf {
|
||||||
|
r++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rr, size := utf8.DecodeRune(s[r:])
|
||||||
|
if rr == utf8.RuneError {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
r += size
|
||||||
|
}
|
||||||
|
if r == len(s) {
|
||||||
|
return s, true
|
||||||
|
}
|
||||||
|
|
||||||
|
b := make([]byte, len(s)+2*utf8.UTFMax)
|
||||||
|
w := copy(b, s[0:r])
|
||||||
|
for r < len(s) {
|
||||||
|
// Out of room? Can only happen if s is full of
|
||||||
|
// malformed UTF-8 and we're replacing each
|
||||||
|
// byte with RuneError.
|
||||||
|
if w >= len(b)-2*utf8.UTFMax {
|
||||||
|
nb := make([]byte, (len(b)+utf8.UTFMax)*2)
|
||||||
|
copy(nb, b[0:w])
|
||||||
|
b = nb
|
||||||
|
}
|
||||||
|
switch c := s[r]; {
|
||||||
|
case c == '\\':
|
||||||
|
r++
|
||||||
|
if r >= len(s) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch s[r] {
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
case '"', '\\', '/', '\'':
|
||||||
|
b[w] = s[r]
|
||||||
|
r++
|
||||||
|
w++
|
||||||
|
case 'b':
|
||||||
|
b[w] = '\b'
|
||||||
|
r++
|
||||||
|
w++
|
||||||
|
case 'f':
|
||||||
|
b[w] = '\f'
|
||||||
|
r++
|
||||||
|
w++
|
||||||
|
case 'n':
|
||||||
|
b[w] = '\n'
|
||||||
|
r++
|
||||||
|
w++
|
||||||
|
case 'r':
|
||||||
|
b[w] = '\r'
|
||||||
|
r++
|
||||||
|
w++
|
||||||
|
case 't':
|
||||||
|
b[w] = '\t'
|
||||||
|
r++
|
||||||
|
w++
|
||||||
|
case 'u':
|
||||||
|
r--
|
||||||
|
rr := getu4(s[r:])
|
||||||
|
if rr < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r += 6
|
||||||
|
if utf16.IsSurrogate(rr) {
|
||||||
|
rr1 := getu4(s[r:])
|
||||||
|
if dec := utf16.DecodeRune(rr, rr1); dec != unicode.ReplacementChar {
|
||||||
|
// A valid pair; consume.
|
||||||
|
r += 6
|
||||||
|
w += utf8.EncodeRune(b[w:], dec)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Invalid surrogate; fall back to replacement rune.
|
||||||
|
rr = unicode.ReplacementChar
|
||||||
|
}
|
||||||
|
w += utf8.EncodeRune(b[w:], rr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quote, control characters are invalid.
|
||||||
|
case c == '"', c < ' ':
|
||||||
|
return
|
||||||
|
|
||||||
|
// ASCII
|
||||||
|
case c < utf8.RuneSelf:
|
||||||
|
b[w] = c
|
||||||
|
r++
|
||||||
|
w++
|
||||||
|
|
||||||
|
// Coerce to well-formed UTF-8.
|
||||||
|
default:
|
||||||
|
rr, size := utf8.DecodeRune(s[r:])
|
||||||
|
r += size
|
||||||
|
w += utf8.EncodeRune(b[w:], rr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b[0:w], true
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
language: go
|
||||||
|
|
||||||
|
go:
|
||||||
|
- 1.4.3
|
||||||
|
- 1.5.3
|
||||||
|
- tip
|
||||||
|
|
||||||
|
script:
|
||||||
|
- go test -v ./...
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue