165 lines
3.5 KiB
Go
165 lines
3.5 KiB
Go
package database
|
|
|
|
import (
|
|
"time"
|
|
|
|
log "github.com/cihub/seelog"
|
|
|
|
"bytes"
|
|
"crypto/rand"
|
|
"errors"
|
|
"regexp"
|
|
|
|
"golang.org/x/crypto/scrypt"
|
|
)
|
|
|
|
var (
	// alphaNumeric reports whether its argument is non-empty and made only of
	// ASCII letters, digits, '_', '-' and '.' (despite the name, those three
	// punctuation characters are accepted too).
	alphaNumeric = regexp.MustCompile(`^[a-zA-Z0-9_\-\.]+$`).MatchString
)
|
|
|
|
type User struct {
|
|
ID int `pg:"type:serial"`
|
|
Username string `pg:"type:varchar(255),unique"`
|
|
Password []byte
|
|
Salt []byte
|
|
Role string `pg:"type:varchar(255)"`
|
|
LastLogin time.Time
|
|
}
|
|
|
|
// AddUser registers a new account with an empty role.
//
// The name must pass validUserName and must not already exist (compared
// case-insensitively). The password is hashed with a freshly generated salt
// before being handed to AddRawUser. Database and hashing failures are logged
// and reported to the caller with a generic message.
//
// NOTE(review): the existence check and the insert are separate statements,
// so two concurrent registrations of e.g. "Bob" and "bob" can both pass the
// check; only the column's unique constraint (presumably case-sensitive —
// confirm) guards the insert.
func (db *pgDB) AddUser(name string, pass string) error {
	if !validUserName(name) {
		return errors.New("Invalid user name. Username needs to have at least 3 characters and can only be letters, numbers, '-', '_' and '.'.")
	}

	existing, err := db.sql.Model(&User{}).Where("lower(username) = lower(?)", name).Count()
	switch {
	case err != nil:
		log.Error("Error on database checking user ", name, ": ", err)
		return errors.New("An error happen on the database")
	case existing != 0:
		return errors.New("User name already exist")
	}

	hashed, salt, hashErr := hashPass(pass)
	if hashErr != nil {
		log.Error("Error hashing password: ", hashErr)
		return errors.New("An error happen storing the password")
	}

	return db.AddRawUser(name, hashed, salt, "")
}
|
|
|
|
func (db *pgDB) AddRawUser(name string, hpass []byte, salt []byte, role string) error {
|
|
u := User{
|
|
Username: name,
|
|
Password: hpass,
|
|
Salt: salt,
|
|
Role: role,
|
|
}
|
|
return db.sql.Insert(&u)
|
|
}
|
|
|
|
func (db *pgDB) GetRole(name string) (string, error) {
|
|
var u User
|
|
err := db.sql.Model(&u).Where("lower(username) = lower(?)", name).Select()
|
|
return u.Role, err
|
|
}
|
|
|
|
// SetRole sets the role of the user whose name matches name
// case-insensitively.
//
// It returns the database error if the update fails, and an error if no user
// with that name exists (previously that case was silently reported as
// success).
func (db *pgDB) SetRole(name, role string) error {
	res, err := db.sql.Model(&User{}).
		Set("role = ?", role).
		Where("lower(username) = lower(?)", name).
		Update()
	if err != nil {
		return err
	}
	// An UPDATE that matches no row is not a database error, but it means the
	// role was not applied to anyone; surface that instead of returning nil.
	if res != nil && res.RowsAffected() == 0 {
		return errors.New("user not found")
	}
	return nil
}
|
|
|
|
// ValidPassword reports whether pass is the correct password for the user
// whose name matches name case-insensitively.
//
// The candidate password is hashed with the user's stored salt and compared
// with the stored hash. On success the user's last_login is set to the
// database's CURRENT_TIMESTAMP; a failure to record that is logged but does
// not change the result.
func (db *pgDB) ValidPassword(name string, pass string) bool {
	var u User
	err := db.sql.Model(&u).Where("lower(username) = lower(?)", name).Select()
	if err != nil {
		// Unknown user or database failure: either way the credentials
		// cannot be verified.
		return false
	}

	hash, err := calculateHash(pass, u.Salt)
	if err != nil {
		return false
	}
	// NOTE(review): bytes.Equal is not constant-time; switching to
	// crypto/subtle.ConstantTimeCompare would remove any timing side channel
	// on this comparison (the bytes import must then be dropped if unused).
	if !bytes.Equal(u.Password, hash) {
		return false
	}

	_, err = db.sql.Model(&User{}).
		Set("last_login = CURRENT_TIMESTAMP").
		Where("id = ?", u.ID).
		Update()
	if err != nil {
		log.Error("Error updating last login for ", u.Username, ": ", err)
	}
	return true
}
|
|
|
|
// SetPassword replaces the password of the user whose name matches name
// case-insensitively. A fresh salt is generated on every call.
//
// It returns the hashing or database error if either step fails, and an
// error if no user with that name exists (previously that case was silently
// reported as success).
func (db *pgDB) SetPassword(name string, pass string) error {
	hash, salt, err := hashPass(pass)
	if err != nil {
		return err
	}

	res, err := db.sql.Model(&User{}).
		Set("password = ?, salt = ?", hash, salt).
		Where("lower(username) = lower(?)", name).
		Update()
	if err != nil {
		return err
	}
	// Zero affected rows means no account matched, so the new password was
	// never stored anywhere.
	if res != nil && res.RowsAffected() == 0 {
		return errors.New("user not found")
	}
	return nil
}
|
|
|
|
func (db *pgDB) ListUsers() ([]User, error) {
|
|
var users []User
|
|
err := db.sql.Model(&users).Select()
|
|
return users, err
|
|
}
|
|
|
|
func (db *pgDB) getUser(name string) (User, error) {
|
|
var user User
|
|
err := db.sql.Model(&user).
|
|
Where("lower(username) = lower(?)", name).
|
|
Select()
|
|
return user, err
|
|
}
|
|
|
|
// validUserName reports whether name may be registered as a new user name.
//
// A valid name is at least 3 characters long, consists only of ASCII letters,
// digits, '_', '-' and '.', and is not one of the reserved names below.
//
// Reserved names are matched case-insensitively. User lookups in this file
// compare lower(username) = lower(?), so without this "Admin" or "ROOT" would
// slip past the reserved list while being treated as "admin"/"root" later.
func validUserName(name string) bool {
	if len(name) < 3 || !alphaNumeric(name) {
		return false
	}

	// name is pure ASCII at this point, so a byte-wise lowercase is exact.
	switch string(bytes.ToLower([]byte(name))) {
	case "admin", "webmaster", "postmaster", "info", "root", "news", "trantor", "librarian", "library", "imperial":
		return false
	}
	return true
}
|
|
|
|
func hashPass(pass string) (hash []byte, salt []byte, err error) {
|
|
salt, err = genSalt()
|
|
if err != nil {
|
|
return
|
|
}
|
|
hash, err = calculateHash(pass, salt)
|
|
return
|
|
}
|
|
|
|
func genSalt() ([]byte, error) {
|
|
const saltLen = 64
|
|
|
|
b := make([]byte, saltLen)
|
|
_, err := rand.Read(b)
|
|
return b, err
|
|
}
|
|
|
|
func calculateHash(pass string, salt []byte) ([]byte, error) {
|
|
const (
|
|
N = 16384
|
|
r = 8
|
|
p = 1
|
|
keyLen = 32
|
|
)
|
|
|
|
bpass := []byte(pass)
|
|
return scrypt.Key(bpass, salt, N, r, p, keyLen)
|
|
}
|