#include "config.h"

#include <sys/types.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/uio.h>
#include <netinet/in.h>

#include <ctype.h>
#include <errno.h>
#include <fcntl.h>
#include <glob.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "sftp.h"

/*
 * Split str in place into words, storing a pointer to each word in args.
 * A word ends at the next ' ' (only a space character terminates a word);
 * that space is overwritten with '\0'.  At most maxargs-1 words are stored
 * and args is always terminated with a NULL entry.
 *
 * skipspaces() comes from sftp.h.  NOTE(review): it is presumably a macro
 * that advances str past leading blanks and tolerates str == NULL, since it
 * is applied after the last word leaves str NULL - confirm in sftp.h.
 *
 * Returns the number of words stored.
 */
int
tokenize(char *str, char **args, int maxargs) {
	int count = 0;
	char *next;

	skipspaces(str);
	while (str != NULL && *str != '\0' && count < maxargs - 1) {
		args[count++] = str;
		next = strchr(str, ' ');
		if (next != NULL)
			*next++ = '\0';
		str = next;
		skipspaces(str);
	}
	args[count] = NULL;
	return count;
}

/*
 * Return a pointer to the last path component of str: the text after the
 * final '/', or str itself when it contains no '/'.  No copy is made; the
 * result points into str (and is "" when str ends with '/').
 */
char *
findbasename(char *str) {
	char *slash = strrchr(str, '/');

	return (slash != NULL) ? slash + 1 : str;
}

/*
 * Look up the command name str in the action table, which ends at the
 * first entry whose .str is NULL.
 *
 * An exact name match returns its index immediately.  Otherwise str is
 * accepted as an abbreviation when it is a prefix of exactly one entry
 * (seen before any exact match).  Returns the index found, or -1 when the
 * name is unknown or the abbreviation is ambiguous.
 */
int
find_match(char *str, map *action) {
	size_t len = strlen(str);
	int candidate = -1;
	int hits = 0;
	int k;

	for (k = 0; action[k].str != NULL; k++) {
		if (strcmp(action[k].str, str) == 0)
			return k;
		if (strncmp(action[k].str, str, len) != 0)
			continue;
		if (hits == 0)
			candidate = k;
		hits++;
	}
	return (hits == 1) ? candidate : -1;
}

/*
 * Execute one command line against the action table.
 *
 * A single trailing newline is stripped first.  Blank lines do nothing.
 * A line starting with '!' is handed to system() as a shell command.
 * Otherwise the line is tokenized and the first word is looked up with
 * find_match(); the matching handler receives (args, nargs).
 *
 * Returns the handler's result, 0 for blank/shell lines, or -1 (after
 * printing "<word>: invalid command") when no command matches.
 */
int
do_action(char *str, map *action) {
	int index, nargs;
	size_t len;
	char *args[MAXARGS];

	/* Guard the empty string: str[len-1] would otherwise be str[-1]. */
	len = strlen(str);
	if (len > 0 && str[len-1] == '\n')
		str[len-1] = '\0';
	skipspaces(str);
	if (*str == '\0')
		return 0;
	if (*str == '!') {
		system(str+1);
		return 0;
	}
	nargs = tokenize(str, args, MAXARGS);
	index = find_match(args[0], action);
	if (index < 0) {
		printf ("%s: invalid command\n", args[0]);
		return -1;
	}
	return action[index].func(args, nargs);
}

/*
 * Read exactly count bytes from fd into buf, looping over short reads.
 *
 * Returns the number of bytes read, which is less than count only if
 * end-of-file was reached first.  Returns -1 if read() fails with anything
 * other than EINTR (errno is left as read() set it); interrupted reads are
 * retried.  Callers comparing the result against an unsigned length must
 * test for -1 first.
 */
int
readn(int fd, void *buf, size_t count) {
	char *p = buf;
	size_t total = 0;
	ssize_t n;

	while (total < count) {
		n = read(fd, p + total, count - total);
		if (n < 0) {
			/* Previously -1 was added to total, corrupting the offset. */
			if (errno == EINTR)
				continue;
			return -1;
		}
		if (n == 0)
			break;		/* EOF: report the short count */
		total += (size_t)n;
	}
	return (int)total;
}


/*
 * Write one message to sock.  Wire format: 1 byte channel, 1 byte command,
 * 4-byte payload length in network byte order, then m.len payload bytes
 * taken from m.data (not copied or freed here).
 *
 * writev() may write fewer bytes than requested, so the iovec array is
 * advanced past what was written and the call repeated; EINTR is retried.
 *
 * Returns 0 once the whole message is written, -1 on a write error
 * (errno from writev) or if writev makes no progress.
 */
int
send_message(int sock, message m) {
	struct iovec v[4];
	struct iovec *iov = v;
	int iovcnt = 4;
	u_int32_t nlen = htonl(m.len);
	ssize_t n;

	v[0].iov_base = (void *)&m.channel;
	v[0].iov_len = 1;

	v[1].iov_base = (void *)&m.command;
	v[1].iov_len = 1;

	v[2].iov_base = (void *) &nlen;
	v[2].iov_len = 4;

	v[3].iov_base = (void *)m.data;
	v[3].iov_len = m.len;

	while (iovcnt > 0) {
		n = writev(sock, iov, iovcnt);
		if (n < 0) {
			if (errno == EINTR)
				continue;
			return -1;
		}
		if (n == 0)
			return -1;	/* no progress on a non-empty request */
		/* Drop fully written vectors (including a zero-length payload)... */
		while (iovcnt > 0 && (size_t)n >= iov->iov_len) {
			n -= (ssize_t)iov->iov_len;
			iov++;
			iovcnt--;
		}
		/* ...and trim the partially written one. */
		if (iovcnt > 0) {
			iov->iov_base = (char *)iov->iov_base + n;
			iov->iov_len -= (size_t)n;
		}
	}
	return 0;
}

/*
 * Read one message (format as written by send_message: channel byte,
 * command byte, 4-byte network-order length, payload) from sock into *m.
 *
 * On success returns 0.  m->data is then NULL for an empty payload, or a
 * malloc()ed buffer owned by the caller.  The buffer holds m->len bytes
 * plus one extra '\0', so peer-supplied text (e.g. FILENAME payloads used
 * as C strings by do_recvfile) is always terminated.
 *
 * On any failure (EOF, read error, short header or payload, allocation
 * failure) returns -1 with *m zeroed and m->data NULL, so callers may
 * safely free(m->data) either way.
 */
int
recv_message(int sock, message *m) {
	unsigned char hdr[6];
	u_int32_t nlen;
	size_t alloc;
	int ret;

	memset(m, 0, sizeof(message));
	m->data = NULL;

	/* readn loops, so a header split across reads is still assembled. */
	ret = readn(sock, hdr, sizeof(hdr));
	if (ret != (int)sizeof(hdr))
		return (-1);
	m->channel = hdr[0];
	m->command = hdr[1];
	memcpy(&nlen, hdr + 2, sizeof(nlen));
	m->len = ntohl(nlen);

	if (m->len == 0)
		return 0;

	alloc = (size_t)m->len + 1;	/* wraps to 0 only on 32-bit size_t */
	if (alloc == 0 || (m->data = malloc(alloc)) == NULL) {
		memset(m, 0, sizeof(message));
		m->data = NULL;
		return (-1);
	}
	ret = readn(sock, m->data, m->len);
	/* Check ret < 0 first: -1 converted to unsigned would pass the length test. */
	if (ret < 0 || (u_int32_t)ret < m->len) {
		free(m->data);
		memset(m, 0, sizeof(message));
		m->data = NULL;
		return (-1);
	}
	((char *)m->data)[m->len] = '\0';
	return 0;
}

/*
 * Non-blocking check for a pending message on sock.
 *
 * Polls with select() and a zero timeout.  If the socket is readable, the
 * message is read with recv_message() and its result is returned.  If
 * nothing is pending, or select() itself fails, m->data is set to NULL and
 * -1 is returned without reading.  Previously a select() error fell
 * through into a blocking recv_message().
 */
int
query_message(int sock, message *m) {
	int ret;
	struct timeval tv;
	fd_set fds;

	memset(&tv, 0, sizeof(tv));	/* zero timeout: poll only */
	FD_ZERO(&fds);
	FD_SET(sock, &fds);
	ret = select (sock+1, &fds, NULL, NULL, &tv);
	if (ret <= 0) {
		m->data = NULL;
		return -1;
	}
	return recv_message(sock, m);
}

/*
 * Build a message on channel 0 with the given command and payload.
 * Shorthand for _data_message(0, command, data, len); data is referenced,
 * not copied.
 */
message
_message(u_int8_t command, void *data, u_int32_t len) {
	const u_int8_t default_channel = 0;
	message built = _data_message(default_channel, command, data, len);

	return built;
}

/*
 * Build a message value from its parts.  The payload pointer is stored
 * as-is: no copy is made and ownership of data stays with the caller.
 */
message
_data_message(u_int8_t channel, u_int8_t command, void *data, u_int32_t len) {
	message msg;

	msg.data = data;
	msg.len = len;
	msg.command = command;
	msg.channel = channel;
	return msg;
}

/*
 * Send one regular file over sock on channel 1.
 *
 * Messages sent, in order:
 *   1. SENDFILE, carrying the channel byte.
 *   2. FILENAME, with the NUL-terminated base name.
 *   3. Wait for a one-byte FILEOK reply.
 *   4. If accepted: FILESIZE (32-bit, network order) and FILEMODE
 *      (9 mode letters, not NUL-terminated).
 *   5. DATA chunks of up to BUFSIZE bytes.
 *   6. ENDDATA.
 * If stat/open fails, an ERROR message carrying a text reason is sent
 * instead.
 *
 * Returns 0 on success.  Returns -1 if the file is not a readable regular
 * file, the reply cannot be read, the peer refuses the file, or reading
 * the file fails part-way.  The file descriptor is closed on every path
 * once opened.
 */
static int
do_send1file(int sock, char *file) {
	struct stat statbuf;
	u_int32_t size;
	u_int16_t mode;
	char modestr[9];	/* exactly 9 letters, deliberately no NUL */
	int n, fd, accepted;
	char buf[BUFSIZE];
	char *base;
	message m;
	char *str;
	u_int8_t c = 1;

	if (stat(file, &statbuf) || !S_ISREG(statbuf.st_mode)) {
		str = "File doesn't exist or is not a regular file";
		send_message(sock, _data_message(c, ERROR, str, strlen(str)+1));
		return -1;
	}
	fd = open(file, O_RDONLY);
	if (fd < 0) {
		str = "File could not be opened";
		send_message(sock, _data_message(c, ERROR, str, strlen(str)+1));
		return -1;
	}

	base = findbasename(file);
	mode = statbuf.st_mode & (S_IRWXU|S_IRWXG|S_IRWXO);
	modetostr(mode, modestr);
	/* NOTE(review): the wire size is 32 bits, so files of 4 GiB or more
	 * are announced with a truncated size. */
	size = htonl(statbuf.st_size);

	send_message(sock, _message(SENDFILE, &c, 1));
	send_message(sock, _data_message(c, FILENAME, base, strlen(base)+1));
	if (recv_message(sock, &m) < 0) {
		/* m.data is not a valid allocation here; do not free it. */
		close(fd);
		return -1;
	}
	accepted = (m.command == FILEOK && m.len == 1 && *(char *)m.data != 0);
	free(m.data);
	if (!accepted) {
		close(fd);	/* was leaked when the peer refused the file */
		return -1;
	}
	send_message(sock, _data_message(c, FILESIZE, &size, 4));
	send_message(sock, _data_message(c, FILEMODE, modestr, 9));

	while ((n = read(fd, buf, BUFSIZE)) > 0) {
		send_message(sock, _data_message(c, DATA, buf, n));
	}
	/* ENDDATA is sent even after a read error so the receiver's loop,
	 * which returns on ENDDATA, terminates. */
	send_message(sock, _data_message(c, ENDDATA, NULL, 0));
	close (fd);
	return (n < 0) ? -1 : 0;
}

int
do_sendfile(int sock, char *file) {
	int ret, i;
	glob_t globbuf;

	if (*file == 0) {
		send_message(sock, _message(NOFILEMATCH, NULL, 0));
		return 0;
	}
	ret = glob(file, 0, NULL, &globbuf);
	if (ret != 0)
		return -1;
	for (i = 0; i < globbuf.gl_pathc; i++)
		do_send1file (sock, globbuf.gl_pathv[i]);
	globfree(&globbuf);
	return i;
}

/*
 * Receive one file from sock: the counterpart of do_send1file().
 *
 * Handles the messages it receives as follows:
 *   SENDFILE  sets the reply channel from its 1-byte payload.
 *   FILENAME  opens (creates/truncates) the named file in the current
 *             directory and replies FILEOK with a status byte of 1 or 0.
 *   FILESIZE  sets the expected size, used for the progress dots.
 *   FILEMODE  applies the 9-letter permission string via fchmod().
 *   DATA      is appended to the file.
 *   ENDDATA   ends the transfer.
 * Other commands are ignored.
 *
 * Prints "<name>: " followed by up to one 80-column line's worth of dots
 * as data arrives.
 *
 * Returns the announced size (as int) after ENDDATA.  Returns -1 if:
 *   - the file could not be opened (after replying FILEOK with status 0);
 *   - no file was opened before ENDDATA;
 *   - a write failed;
 *   - the connection failed or a malformed message arrived.
 */
int
do_recvfile(int sock, u_int8_t channel) {
	u_int32_t size = 0;	/* from FILESIZE; 0 until announced */
	u_int32_t total = 0;	/* payload bytes received in DATA */
	u_int16_t mode;
	char *file;
	int fd = -1, ok = 0;
	u_int8_t status;
	message m;
	int namelen = 0, width, dots = 0, newdots, i;

	while (1) {
		if (recv_message(sock, &m) < 0) {
			/* m.data is not ours to free after a failed receive. */
			if (fd >= 0)
				close(fd);
			return -1;
		}
		switch (m.command) {
			case SENDFILE:
				/* Payload is only the channel byte; the name comes in
				 * a separate FILENAME message, so do not fall through. */
				if (m.len >= 1)
					channel = *(u_int8_t*)m.data;
				break;
			case FILENAME:
				if (fd >= 0)
					close(fd);	/* don't leak an earlier file */
				fd = -1;
				ok = 0;
				/* The name is peer-controlled: require a terminating
				 * NUL and keep only the last path component, so it
				 * cannot name a file in another directory. */
				if (m.len > 0 && memchr(m.data, '\0', m.len) != NULL) {
					file = findbasename((char *)m.data);
					fputs(file, stdout);
					fputs(": ", stdout);
					fflush(stdout);
					namelen = (strlen(file) + 2) % 80;
					if (*file != '\0')
						fd = open(file, O_WRONLY|O_CREAT|O_TRUNC, 0644);
				}
				status = (fd >= 0);
				send_message(sock,
					     _data_message(channel, FILEOK,
							   &status, 1));
				if (status == 0) {
					free(m.data);
					return -1;
				}
				ok = 1;
				size = 0;
				total = 0;
				dots = 0;
				break;
			case FILESIZE:
				if (fd < 0 || m.len != 4)
					break;
				memcpy(&size, m.data, sizeof(size));
				size = ntohl(size);
				break;
			case FILEMODE:
				if (fd < 0 || m.len != 9)
					break;
				mode = strtomode((char *)m.data);
				fchmod(fd, mode);
				break;
			case DATA:
				if (fd < 0)
					break;
				if (write(fd, m.data, m.len) != (ssize_t)m.len) {
					/* Keep consuming messages until ENDDATA so the
					 * stream stays in sync, but report failure. */
					close(fd);
					fd = -1;
					ok = 0;
					break;
				}
				total += m.len;
				if (size > 0) {	/* avoid dividing by an unknown size */
					width = 80 - namelen;
					newdots = (total >= size) ? width
					    : (int)((double)total / size * width);
					for (i = dots; i < newdots; i++)
						putchar('.');
					fflush(stdout);
					dots = newdots;
				}
				break;
			case ENDDATA:
				free(m.data);
				if (fd >= 0)
					close(fd);
				return ok ? (int)size : -1;
			default:
				break;	/* unknown commands are ignored */
		}
		free(m.data);
	}
}

/*
 * Return tv1 - tv2 as a timeval with tv_usec in [0, 1000000), assuming
 * both inputs are already normalized.  When tv1's microseconds are
 * smaller, one second is borrowed from the seconds field.
 */
struct timeval
timediff(struct timeval tv1, struct timeval tv2) {
	struct timeval out;
	long borrow = (tv1.tv_usec < tv2.tv_usec) ? 1 : 0;

	out.tv_sec = tv1.tv_sec - tv2.tv_sec - borrow;
	out.tv_usec = tv1.tv_usec - tv2.tv_usec + borrow * 1000000;
	return out;
}

/*
 * Allocate a zero-filled sftp_channel.
 *
 * Returns NULL if allocation fails; the caller must check.  The old
 * malloc()+memset() dereferenced NULL in that case.  calloc() provides
 * the same zero fill.
 */
sftp_channel *
new_channel() {
	sftp_channel *channel = calloc(1, sizeof(*channel));

	return channel;
}

/*
 * Parse a 9-character "rwxrwxrwx"-style permission string (user, group,
 * other) into permission bits.  A position contributes its bit only when
 * it holds the expected letter; any other character means "not set".
 * Exactly 9 characters are read; no terminator is required.
 */
u_int16_t
strtomode(char *str) {
	static const u_int16_t bits[9] = {
		S_IRUSR, S_IWUSR, S_IXUSR,
		S_IRGRP, S_IWGRP, S_IXGRP,
		S_IROTH, S_IWOTH, S_IXOTH
	};
	static const char letters[] = "rwxrwxrwx";
	u_int16_t result = 0;
	int k;

	for (k = 0; k < 9; k++) {
		if (str[k] == letters[k])
			result |= bits[k];
	}
	return result;
}

/*
 * Render the user/group/other permission bits of mode as 9 characters in
 * "rwxrwxrwx" form, with '-' for each bit that is clear.
 *
 * Writes exactly 9 bytes to str and does NOT append a NUL terminator:
 * the wire format carries exactly these 9 bytes.
 */
void
modetostr(u_int16_t mode, char *str) {
	static const u_int16_t bits[9] = {
		S_IRUSR, S_IWUSR, S_IXUSR,
		S_IRGRP, S_IWGRP, S_IXGRP,
		S_IROTH, S_IWOTH, S_IXOTH
	};
	static const char letters[] = "rwxrwxrwx";
	int k;

	for (k = 0; k < 9; k++)
		str[k] = (mode & bits[k]) ? letters[k] : '-';
}
