package database

import (
	log "github.com/cihub/seelog"

	"bytes"
	"crypto/rand"
	"errors"

	"golang.org/x/crypto/scrypt"
)

// user is the stored account record read and written by the pgDB methods in
// this file (AddUser, GetRole, ValidPassword, SetPassword).
type user struct {
	ID       int    // NOTE(review): presumably the database-assigned primary key — confirm
	Username string // matched against the "username" column by every query in this file
	Password []byte // hash produced by hashPass/calculateHash; never the clear-text password
	Salt     []byte // random salt from genSalt that was used to derive Password
	Role     string // set to "" by AddUser; returned as-is by GetRole
}

// AddUser registers a new user with the given name and password.
//
// The name must pass validUserName and must not already be taken. The
// password is never stored in clear text: hashPass derives an scrypt hash
// using a freshly generated random salt, and both are stored on the new
// record, which starts with an empty role.
//
// Database and hashing failures are logged here and reported to the caller
// only as generic errors, so driver details never leak through the returned
// error.
//
// NOTE(review): the existence check (Count) and the insert (Create) are two
// separate queries, so two concurrent calls with the same name can both pass
// the check. Only a unique constraint on the username column rules that out —
// confirm the schema has one.
func (db *pgDB) AddUser(name string, pass string) error {
	if !validUserName(name) {
		return errors.New("Invalid user name")
	}
	num, err := db.sql.Model(&user{}).Where("username = ?", name).Count()
	if err != nil {
		log.Error("Error on database checking user ", name, ": ", err)
		return errors.New("An error happen on the database")
	}
	if num != 0 {
		return errors.New("User name already exist")
	}

	hpass, salt, err := hashPass(pass)
	if err != nil {
		log.Error("Error hashing password: ", err)
		return errors.New("An error happen storing the password")
	}
	u := user{
		Username: name,
		Password: hpass,
		Salt:     salt,
		Role:     "",
	}
	// Treat insert failures like the lookup failure above: log the details,
	// return only the generic message.
	if err = db.sql.Create(&u); err != nil {
		log.Error("Error on database creating user ", name, ": ", err)
		return errors.New("An error happen on the database")
	}
	return nil
}

// GetRole looks up the user called name and returns its Role field.
//
// The lookup error is passed through unchanged. On failure, the returned role
// is whatever the lookup left in the record, presumably "".
func (db *pgDB) GetRole(name string) (string, error) {
	found := new(user)
	err := db.sql.Model(found).
		Where("username = ?", name).
		Select()
	return found.Role, err
}

// ValidPassword reports whether pass is the correct password for the user
// called name.
//
// It loads the stored record, recomputes the scrypt hash of pass with the
// stored salt, and compares the result with the stored hash. Any failure
// yields false. That covers a lookup error (presumably including an unknown
// user) and a hashing error.
//
// NOTE(review): bytes.Equal is not constant-time; consider
// crypto/subtle.ConstantTimeCompare here as defense in depth.
func (db *pgDB) ValidPassword(name string, pass string) bool {
	var stored user
	if err := db.sql.Model(&stored).Where("username = ?", name).Select(); err != nil {
		return false
	}

	candidate, err := calculateHash(pass, stored.Salt)
	if err != nil {
		return false
	}
	return bytes.Equal(stored.Password, candidate)
}

// SetPassword replaces the stored password of the user called name.
//
// hashPass generates a fresh salt on every call, so the new hash and its salt
// are written together in a single UPDATE. Hashing errors and database errors
// are returned unchanged.
//
// NOTE(review): the update result is discarded, so an unknown name is
// presumably not reported as an error — confirm whether callers need that.
func (db *pgDB) SetPassword(name string, pass string) error {
	hash, salt, err := hashPass(pass)
	if err != nil {
		return err
	}
	// The model must be passed as a pointer. The column backing the Password
	// field is "password" under the ORM's snake_case field mapping, the same
	// field AddUser fills. There is no "pass" column.
	_, err = db.sql.Model(&user{}).
		Set("password = ?, salt = ?", hash, salt).
		Where("username = ?", name).
		Update()
	return err
}

// validUserName reports whether name is acceptable as a user name.
// The only requirement today is that it is non-empty.
func validUserName(name string) bool {
	return len(name) > 0
}

// hashPass derives a storable hash for pass using a newly generated random
// salt. It returns the hash, the salt that was used, and any error from salt
// generation or hashing.
func hashPass(pass string) ([]byte, []byte, error) {
	salt, err := genSalt()
	if err != nil {
		// genSalt's buffer is handed back alongside the error, untouched.
		return nil, salt, err
	}
	hash, err := calculateHash(pass, salt)
	return hash, salt, err
}

// genSalt returns 64 bytes read from crypto/rand, for use as a per-password
// salt. The buffer is returned together with any error from the read.
func genSalt() ([]byte, error) {
	salt := make([]byte, 64)
	if _, err := rand.Read(salt); err != nil {
		return salt, err
	}
	return salt, nil
}

// calculateHash derives a 32-byte key from pass and salt with scrypt, using
// the cost parameters N=2^14, r=8, p=1.
//
// ValidPassword recomputes stored hashes with this function. Changing any
// parameter therefore makes every previously stored hash fail to verify.
func calculateHash(pass string, salt []byte) ([]byte, error) {
	const (
		costN       = 1 << 14 // CPU/memory cost; scrypt requires a power of two
		blockSize   = 8
		parallelism = 1
		hashLen     = 32
	)
	return scrypt.Key([]byte(pass), salt, costN, blockSize, parallelism, hashLen)
}