Commit c5df649f authored by zauberstuhl's avatar zauberstuhl

Implement new logger interface and tests

parent 6f5c4296
...@@ -37,7 +37,7 @@ func FetchEntityOrder(entityXML string) (order string) { ...@@ -37,7 +37,7 @@ func FetchEntityOrder(entityXML string) (order string) {
} }
} }
if len(order) <= 0 { if len(order) <= 0 {
warn("Entity order is empty") logger.Warn("Entity order is empty")
return return
} }
return order[:len(order)-1] // trim space return order[:len(order)-1] // trim space
......
...@@ -55,7 +55,7 @@ func push(host, endpoint, proto, contentType string, body io.Reader) error { ...@@ -55,7 +55,7 @@ func push(host, endpoint, proto, contentType string, body io.Reader) error {
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
if proto == PROTO_HTTPS { if proto == PROTO_HTTPS {
info("Retry with", PROTO_HTTP, "on", host, err) logger.Info("Retry with", PROTO_HTTP, "on", host, err)
return push(host, endpoint, PROTO_HTTP, contentType, body) return push(host, endpoint, PROTO_HTTP, contentType, body)
} }
return err return err
......
...@@ -21,9 +21,23 @@ import ( ...@@ -21,9 +21,23 @@ import (
"fmt" "fmt"
"runtime" "runtime"
"regexp" "regexp"
"os"
"log"
) )
func log(label string, msgs... interface{}) { const (
LOG_C_RED = "\033[31m"
LOG_C_YELLOW = "\033[33m"
LOG_C_RESET = "\033[0m"
)
var (
logger Log
defaultLogger Logger
defaultPrefix string
)
func init() {
pc := make([]uintptr, 10) // at least 1 entry needed pc := make([]uintptr, 10) // at least 1 entry needed
runtime.Callers(3, pc) runtime.Callers(3, pc)
f := runtime.FuncForPC(pc[0]) f := runtime.FuncForPC(pc[0])
...@@ -34,41 +48,30 @@ func log(label string, msgs... interface{}) { ...@@ -34,41 +48,30 @@ func log(label string, msgs... interface{}) {
file = result[0][1] file = result[0][1]
} }
fmt.Printf("%s:%d %s ", file, line, f.Name()) defaultPrefix = fmt.Sprintf("%s:%d %s ", file, line, f.Name())
defaultLogger = log.New(os.Stdout, defaultPrefix, log.Lshortfile)
}
type Logger interface {
Println(v... interface{})
}
for _, e := range msgs { type Log struct{
switch msg := e.(type) { Logger
case error: }
fmt.Printf("[%s] ", label)
fmt.Print(msg) func SetLogger(logger Logger) {
case []error: defaultLogger = logger
fmt.Println(" \\")
for _, err := range msg {
fmt.Printf("\t[%s] ", label)
fmt.Println(err)
}
case string:
fmt.Printf("[%s] ", label)
fmt.Print(msg)
case byte, []byte:
fmt.Printf("[%s] ", label)
fmt.Print("%s", msg)
default:
fmt.Printf("[%s] ", label)
fmt.Print(msg)
}
fmt.Println()
}
} }
func warn(msgs... interface{}) { func (l Log) Info(values... interface{}) {
log("W", msgs) defaultLogger.Println(values...)
} }
func info(msgs... interface{}) { func (l Log) Error(values... interface{}) {
log("I", msgs) l.Info(LOG_C_RED, values, LOG_C_RESET)
} }
func fatal(msgs... interface{}) { func (l Log) Warn(values... interface{}) {
log("F", msgs) l.Info(LOG_C_YELLOW, values, LOG_C_RESET)
} }
package federation
//
// GangGo Diaspora Federation Library
// Copyright (C) 2017 Lukas Matt <lukas@zauberstuhl.de>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
//
import (
"testing"
"bytes"
"log"
"os"
"regexp"
)
type TestLogger struct {
Log string
}
func (t *TestLogger) Println(v... interface{}) {
t.Log = v[0].(string)
}
func TestSetLogger(t *testing.T) {
var buf bytes.Buffer
var expected = "Hello World"
SetLogger(log.New(&buf, defaultPrefix, log.Lshortfile))
logger.Info(expected)
matched, err := regexp.MatchString(expected + "\n", buf.String())
if err != nil || !matched {
t.Errorf("Expected to be %s, got %s (%v)", expected, buf.String(), err)
}
var testLogger TestLogger
SetLogger(&testLogger)
logger.Info(expected)
if expected != testLogger.Log {
t.Errorf("Expected to be %s, got %s", expected, testLogger.Log)
}
// reset otherwise it will break test output
SetLogger(log.New(os.Stdout, defaultPrefix, log.Lshortfile))
}
...@@ -43,9 +43,11 @@ type MagicEnvelopeMarshal struct { ...@@ -43,9 +43,11 @@ type MagicEnvelopeMarshal struct {
} }
func MagicEnvelope(privkey, handle string, plainXml []byte) (payload []byte, err error) { func MagicEnvelope(privkey, handle string, plainXml []byte) (payload []byte, err error) {
info("plain xml", string(plainXml)) logger.Info(
info("privkey length", len(privkey)) "MagicEnvelope with", string(plainXml),
info("handle", handle) "private key length", len(privkey),
"for", handle,
)
data := base64.URLEncoding.EncodeToString(plainXml) data := base64.URLEncoding.EncodeToString(plainXml)
keyId := base64.URLEncoding.EncodeToString([]byte(handle)) keyId := base64.URLEncoding.EncodeToString([]byte(handle))
...@@ -60,28 +62,29 @@ func MagicEnvelope(privkey, handle string, plainXml []byte) (payload []byte, err ...@@ -60,28 +62,29 @@ func MagicEnvelope(privkey, handle string, plainXml []byte) (payload []byte, err
err = xmlBody.Sign(privkey) err = xmlBody.Sign(privkey)
if err != nil { if err != nil {
warn(err) logger.Warn(err)
return return
} }
payload, err = xml.Marshal(xmlBody) payload, err = xml.Marshal(xmlBody)
if err != nil { if err != nil {
warn(err) logger.Warn(err)
return return
} }
info("payload", string(payload))
logger.Info("MagicEnvelope payload", string(payload))
return return
} }
func EncryptedMagicEnvelope(privkey, pubkey, handle string, serializedXml []byte) (payload []byte, err error) { func EncryptedMagicEnvelope(privkey, pubkey, handle string, serializedXml []byte) (payload []byte, err error) {
logger.Info("EncryptedMagicEnvelope with", string(serializedXml),
"private key length", len(privkey),
"and public key length", len(pubkey),
"for", handle,
)
var aesKeySet Aes var aesKeySet Aes
var aesWrapper AesWrapper var aesWrapper AesWrapper
info("serialized xml", string(serializedXml))
info("privkey length", len(privkey))
info("pubkey length", len(pubkey))
info("handle", handle)
data := base64.URLEncoding.EncodeToString(serializedXml) data := base64.URLEncoding.EncodeToString(serializedXml)
keyId := base64.URLEncoding.EncodeToString([]byte(handle)) keyId := base64.URLEncoding.EncodeToString([]byte(handle))
...@@ -96,62 +99,64 @@ func EncryptedMagicEnvelope(privkey, pubkey, handle string, serializedXml []byte ...@@ -96,62 +99,64 @@ func EncryptedMagicEnvelope(privkey, pubkey, handle string, serializedXml []byte
err = envelope.Sign(privkey) err = envelope.Sign(privkey)
if err != nil { if err != nil {
warn(err) logger.Warn(err)
return return
} }
// Generate a new AES key pair // Generate a new AES key pair
err = aesKeySet.Generate() err = aesKeySet.Generate()
if err != nil { if err != nil {
warn(err) logger.Warn(err)
return return
} }
// payload with aes encryption // payload with aes encryption
payload, err = xml.Marshal(envelope) payload, err = xml.Marshal(envelope)
if err != nil { if err != nil {
warn(err) logger.Warn(err)
return return
} }
info("payload, err = xml.Marshal(envelope) ", string(payload)) logger.Info(
"EncryptedMagicEnvelope payload with aes encryption",
string(payload),
)
err = aesKeySet.Encrypt(payload) err = aesKeySet.Encrypt(payload)
if err != nil { if err != nil {
warn(err) logger.Warn(err)
return return
} }
//aesWrapper.MagicEnvelope = base64.StdEncoding.EncodeToString([]byte(aesKeySet.Data))
aesWrapper.MagicEnvelope = aesKeySet.Data aesWrapper.MagicEnvelope = aesKeySet.Data
// aes with rsa encryption // aes with rsa encryption
aesKeySetXml, err := json.Marshal(aesKeySet) aesKeySetXml, err := json.Marshal(aesKeySet)
if err != nil { if err != nil {
warn(err) logger.Warn(err)
return return
} }
pubKey, err := ParseRSAPubKey([]byte(pubkey)) pubKey, err := ParseRSAPubKey([]byte(pubkey))
if err != nil { if err != nil {
warn(err) logger.Warn(err)
return return
} }
info("aesKeySetXml", string(aesKeySetXml)) logger.Info("AES key-set XML", string(aesKeySetXml))
aesKey, err := rsa.EncryptPKCS1v15(rand.Reader, pubKey, aesKeySetXml) aesKey, err := rsa.EncryptPKCS1v15(rand.Reader, pubKey, aesKeySetXml)
if err != nil { if err != nil {
warn(err) logger.Warn(err)
return return
} }
aesWrapper.AesKey = base64.StdEncoding.EncodeToString(aesKey) aesWrapper.AesKey = base64.StdEncoding.EncodeToString(aesKey)
payload, err = json.Marshal(aesWrapper) payload, err = json.Marshal(aesWrapper)
if err != nil { if err != nil {
warn(err) logger.Warn(err)
return return
} }
info("payload", string(payload)) logger.Info("EncryptedMagicEnvelope payload", string(payload))
return return
} }
...@@ -26,30 +26,30 @@ import ( ...@@ -26,30 +26,30 @@ import (
func ParseDecryptedRequest(entityXML []byte) (message Message, err error) { func ParseDecryptedRequest(entityXML []byte) (message Message, err error) {
err = xml.Unmarshal(entityXML, &message) err = xml.Unmarshal(entityXML, &message)
if err != nil { if err != nil {
fatal(err) logger.Error(err)
return return
} }
if !strings.EqualFold(message.Encoding, BASE64_URL) { if !strings.EqualFold(message.Encoding, BASE64_URL) {
fatal(err) logger.Error(err)
return return
} }
if !strings.EqualFold(message.Alg, RSA_SHA256) { if !strings.EqualFold(message.Alg, RSA_SHA256) {
fatal(err) logger.Error(err)
return return
} }
keyId, err := base64.StdEncoding.DecodeString(message.Sig.KeyId) keyId, err := base64.StdEncoding.DecodeString(message.Sig.KeyId)
if err != nil { if err != nil {
fatal(err) logger.Error(err)
return return
} }
message.Sig.KeyId = string(keyId) message.Sig.KeyId = string(keyId)
data, err := base64.URLEncoding.DecodeString(message.Data.Data) data, err := base64.URLEncoding.DecodeString(message.Data.Data)
if err != nil { if err != nil {
fatal(err) logger.Error(err)
return return
} }
...@@ -58,7 +58,7 @@ func ParseDecryptedRequest(entityXML []byte) (message Message, err error) { ...@@ -58,7 +58,7 @@ func ParseDecryptedRequest(entityXML []byte) (message Message, err error) {
} }
err = xml.Unmarshal(data, &entity) err = xml.Unmarshal(data, &entity)
if err != nil { if err != nil {
fatal(err) logger.Error(err)
return return
} }
message.Entity = entity message.Entity = entity
...@@ -68,7 +68,7 @@ func ParseDecryptedRequest(entityXML []byte) (message Message, err error) { ...@@ -68,7 +68,7 @@ func ParseDecryptedRequest(entityXML []byte) (message Message, err error) {
func ParseEncryptedRequest(wrapper AesWrapper, privkey []byte) (message Message, err error) { func ParseEncryptedRequest(wrapper AesWrapper, privkey []byte) (message Message, err error) {
entityXML, err := wrapper.Decrypt(privkey) entityXML, err := wrapper.Decrypt(privkey)
if err != nil { if err != nil {
fatal(err) logger.Error(err)
return return
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment