This repository has been archived on 2025-03-01. You can view files and clone it, but cannot push or open issues or pull requests.
trantor/lib/database/users.go
2020-11-30 19:03:31 +00:00

202 lines
4.6 KiB
Go

package database
import (
"time"
log "github.com/cihub/seelog"
"bytes"
"crypto/rand"
"errors"
"regexp"
"golang.org/x/crypto/scrypt"
)
// alphaNumeric reports whether a string is non-empty and consists solely of
// ASCII letters, digits, '_', '-' and '.'. Despite its name it therefore
// accepts a little more than strictly alphanumeric input.
var alphaNumeric = regexp.MustCompile(`^[a-zA-Z0-9_\-\.]+$`).MatchString
// User is the database record of an account.
type User struct {
	// ID is the serial primary key.
	ID int `pg:"type:serial"`
	// Username is unique; the lookups in this file compare it
	// case-insensitively via lower(username).
	Username string `pg:"type:varchar(255),unique"`
	// Password holds the hash derived from the password and Salt.
	Password []byte
	// Salt is the random salt that was used to derive Password.
	Salt []byte
	// Role is "admin" for administrators; AddUser stores an empty role.
	Role string `pg:"type:varchar(255)"`
	// LastLogin is refreshed on every successful ValidPassword call.
	LastLogin time.Time
}
// AddUser registers a regular (role-less) account.
//
// The name must pass validUserName and must not match an existing username
// case-insensitively. The password is salted and hashed before it is stored.
// Database and hashing failures are logged and reported to the caller as
// generic errors.
func (db *pgDB) AddUser(name string, pass string) error {
	if !validUserName(name) {
		return errors.New("Invalid user name. Username needs to have at least 3 characters and can only be letters, numbers, '-', '_' and '.'.")
	}

	existing, err := db.sql.Model(&User{}).Where("lower(username) = lower(?)", name).Count()
	switch {
	case err != nil:
		log.Error("Error on database checking user ", name, ": ", err)
		return errors.New("An error happen on the database")
	case existing != 0:
		return errors.New("User name already exist")
	}

	hashed, salt, err := hashPass(pass)
	if err != nil {
		log.Error("Error hashing password: ", err)
		return errors.New("An error happen storing the password")
	}

	return db.AddRawUser(name, hashed, salt, "")
}
// SetAdminUser makes sure an account called name exists with the "admin"
// role and the given password.
//
// If no account matches the name (case-insensitively) a new admin account is
// created. Otherwise the existing account is given the admin role and its
// password is replaced. Only the basic name rules of validAdminUserName
// apply here.
func (db *pgDB) SetAdminUser(name string, pass string) error {
	if !validAdminUserName(name) {
		return errors.New("Invalid admin user name. Username needs to have at least 3 characters and can only be letters, numbers, '-', '_' and '.'.")
	}

	matches, err := db.sql.Model(&User{}).Where("lower(username) = lower(?)", name).Count()
	if err != nil {
		log.Error("Error on database checking user ", name, ": ", err)
		return errors.New("An error happen on the database")
	}

	if matches == 0 {
		// No such account yet: create it directly as an admin.
		hashed, salt, err := hashPass(pass)
		if err != nil {
			log.Error("Error hashing password: ", err)
			return errors.New("An error happen storing the password")
		}
		return db.AddRawUser(name, hashed, salt, "admin")
	}

	// The account already exists: promote it, then reset its password.
	if err := db.SetRole(name, "admin"); err != nil {
		log.Error("Error updating admin user ", name, ": ", err)
		return errors.New("Error updating admin user")
	}
	return db.SetPassword(name, pass)
}
// AddRawUser inserts an account row from an already hashed password and its
// salt. It performs no validation of the name or role and returns the
// insert error unchanged.
func (db *pgDB) AddRawUser(name string, hpass []byte, salt []byte, role string) error {
	record := &User{Username: name, Password: hpass, Salt: salt, Role: role}
	_, err := db.sql.Model(record).Insert()
	return err
}
// GetRole returns the role of the account whose username matches name
// case-insensitively, together with any error from the lookup.
func (db *pgDB) GetRole(name string) (string, error) {
	var account User
	lookup := db.sql.Model(&account).Where("lower(username) = lower(?)", name)
	err := lookup.Select()
	return account.Role, err
}
// SetRole assigns role to the account whose username matches name
// case-insensitively and returns the error of the update, if any.
func (db *pgDB) SetRole(name, role string) error {
	update := db.sql.Model(&User{}).
		Set("role = ?", role).
		Where("lower(username) = lower(?)", name)
	_, err := update.Update()
	return err
}
// ValidPassword reports whether pass is the password of the account whose
// username matches name case-insensitively.
//
// It returns false when the account cannot be loaded, when hashing fails or
// when the hashes differ. On success the account's last_login column is set
// to the current database timestamp; a failure of that update is only
// logged and does not affect the result.
func (db *pgDB) ValidPassword(name string, pass string) bool {
	var account User
	if err := db.sql.Model(&account).Where("lower(username) = lower(?)", name).Select(); err != nil {
		return false
	}

	candidate, err := calculateHash(pass, account.Salt)
	// NOTE(review): bytes.Equal is not a constant-time comparison; consider
	// crypto/subtle.ConstantTimeCompare for comparing password hashes.
	if err != nil || !bytes.Equal(account.Password, candidate) {
		return false
	}

	_, err = db.sql.Model(&User{}).
		Set("last_login = CURRENT_TIMESTAMP").
		Where("id = ?", account.ID).
		Update()
	if err != nil {
		log.Error("Error updating last login for ", account.Username, ": ", err)
	}
	return true
}
// SetPassword stores a freshly salted hash of pass for the account whose
// username matches name case-insensitively. Hashing and update errors are
// returned unchanged.
func (db *pgDB) SetPassword(name string, pass string) error {
	hashed, salt, err := hashPass(pass)
	if err != nil {
		return err
	}

	update := db.sql.Model(&User{}).
		Set("password = ?, salt = ?", hashed, salt).
		Where("lower(username) = lower(?)", name)
	_, err = update.Update()
	return err
}
// ListUsers loads every account row and returns the result together with
// the error of the query, if any.
func (db *pgDB) ListUsers() ([]User, error) {
	var all []User
	if err := db.sql.Model(&all).Select(); err != nil {
		return all, err
	}
	return all, nil
}
// getUser loads the account whose username matches name case-insensitively
// and returns it together with the error of the query, if any.
func (db *pgDB) getUser(name string) (User, error) {
	var found User
	lookup := db.sql.Model(&found).Where("lower(username) = lower(?)", name)
	if err := lookup.Select(); err != nil {
		return found, err
	}
	return found, nil
}
// validAdminUserName reports whether name satisfies the basic username
// rules: at least 3 bytes long and accepted by alphaNumeric.
func validAdminUserName(name string) bool {
	return len(name) >= 3 && alphaNumeric(name)
}
// reservedUserName reports whether a name is one of the account names that
// regular users may not register. The match ignores letter case because the
// database lookups in this file compare usernames with lower(username), so
// "Admin" and "admin" refer to the same account.
var reservedUserName = regexp.MustCompile(`(?i)^(?:admin|webmaster|postmaster|info|root|news|trantor|librarian|library|imperial)$`).MatchString

// validUserName reports whether name may be used for a regular account: it
// has to satisfy the basic rules of validAdminUserName and must not be a
// reserved name in any letter case.
func validUserName(name string) bool {
	if !validAdminUserName(name) {
		return false
	}
	return !reservedUserName(name)
}
// hashPass generates a fresh salt and derives the hash of pass with it.
// It returns the hash, the salt that was used and the first error
// encountered; when salt generation fails the hash is nil.
func hashPass(pass string) (hash []byte, salt []byte, err error) {
	if salt, err = genSalt(); err != nil {
		return nil, salt, err
	}
	hash, err = calculateHash(pass, salt)
	return hash, salt, err
}
// genSalt returns 64 bytes read from crypto/rand. The buffer is returned
// even when reading fails, alongside the error.
func genSalt() ([]byte, error) {
	const saltLen = 64

	salt := make([]byte, saltLen)
	if _, err := rand.Read(salt); err != nil {
		return salt, err
	}
	return salt, nil
}
// calculateHash derives a 32-byte key from pass and salt with scrypt, using
// the fixed parameters N=16384, r=8 and p=1. Any error from scrypt.Key is
// returned unchanged.
func calculateHash(pass string, salt []byte) ([]byte, error) {
	const (
		costN       = 16384 // scrypt N parameter
		blockSize   = 8     // scrypt r parameter
		parallelism = 1     // scrypt p parameter
		hashLen     = 32    // length of the derived key in bytes
	)
	return scrypt.Key([]byte(pass), salt, costN, blockSize, parallelism, hashLen)
}