#!/usr/bin/env python3

import fcntl
import errno
import posix
import time
import signal
import os
import pwd
import sys
import getopt
import traceback
import datetime
import mimetypes
try:
	from urllib.parse import urlparse
	from urllib.parse import urlunparse
except ImportError:
	from urlparse import urlparse
	from urlparse import urlunparse
import socket
import select
import subprocess

# HTTP server based on recipes 511453 and 511454 from code.activestate.com by Pierre Quentel
# Added support for indexes, access tests, proper handling of the SystemExit exception, fixes for a couple of errors and vulnerabilities, getopt, lockfiles, daemonizing etc. by Jakub Kruszona-Zawadzki

# registry of live connections: maps every connected client socket to the
# ClientHandler (or subclass) instance that serves it
client_handlers = dict()

def emptybuff():
	"""Return an empty buffer of the native socket data type
	(str on Python 2, bytes on Python 3)"""
	return '' if sys.version<'3' else bytes(0)

# type of the raw data chunks exchanged with sockets:
# str on Python 2, bytes on Python 3
buff_type = str if sys.version<'3' else bytes

# =======================================================================
# The server class. Creating an instance starts a server on the specified
# host and port
# =======================================================================
class Server(object):
	"""Non-blocking listening TCP socket bound to (host, port)

	host -- local address to bind to ; the special value 'any' is mapped to
	        the empty string, i.e. all local addresses
	port -- TCP port to listen on

	The bound socket is available as self.socket. Any error raised while
	configuring or binding the socket is propagated to the caller.
	"""
	def __init__(self,host='localhost',port=80):
		if host=='any':
			host=''
		self.host,self.port = host,port
		self.socket = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
		try:
			self.socket.setblocking(0)
			self.socket.setsockopt(socket.SOL_SOCKET,socket.SO_REUSEADDR,1)
			self.socket.bind((host,port))
			self.socket.listen(50)
		except Exception:
			# do not leak the descriptor when the setup fails (e.g. port in use)
			self.socket.close()
			raise

# =====================================================================
# Generic client handler. An instance of this class is created for each
# client connection accepted by the server (see loop() below)
# =====================================================================
class ClientHandler(object):
	"""Base class handling the dialog with one connected client

	The main loop calls handle_read() when the client socket is readable,
	handle_write() when it is writable and self.writable is set, and
	handle_error() on an exceptional condition. Subclasses override
	request_complete() and make_response().
	"""
	blocksize = 2048	# number of bytes read at once from file objects in self.response

	def __init__(self, server, client_socket, client_address):
		self.server = server
		self.client_address = client_address
		self.client_socket = client_socket
		self.client_socket.setblocking(0)
		self.host = socket.getfqdn(client_address[0])
		self.incoming = emptybuff() # receives incoming data
		self.outgoing = emptybuff() # chunk currently being sent
		self.writable = False       # True once a response is ready to be sent
		self.close_when_done = True # close the socket after the response
		self.response = []          # pending buffers / file objects to send

	def handle_error(self):
		"""Called on an exceptional socket condition : drop the connection"""
		self.close()

	def handle_read(self):
		"""Reads the data received"""
		try:
			buff = self.client_socket.recv(1024)
			if not buff:  # the connection is closed
				self.close()
			else:
				# buffer the data in self.incoming
				self.incoming += buff
				self.process_incoming()
		except socket.error:
			self.close()

	def process_incoming(self):
		"""Test if request is complete ; if so, build the response
		and set self.writable to True"""
		if not self.request_complete():
			return
		self.response = self.make_response()
		self.outgoing = emptybuff()
		self.writable = True

	def request_complete(self):
		"""Return True if the request is complete, False otherwise
		Override this method in subclasses"""
		return True

	def make_response(self):
		"""Return the list of buffers (buff_type) or file objects whose
		content will be sent to the client
		Override this method in subclasses"""
		# a bytes literal is of buff_type on both Python 2 and 3 ; a text
		# literal would be mistaken for a file object by handle_write()
		return [b"xxx"]

	def handle_write(self):
		"""Send (a part of) the response on the socket
		Finish the request if the whole response has been sent
		self.response is a list of buffers or file objects
		"""
		if len(self.outgoing)==0 and self.response:
			if isinstance(self.response[0],buff_type):
				self.outgoing = self.response.pop(0)
			else:
				# file object : send it block by block, drop it at EOF
				self.outgoing = self.response[0].read(self.blocksize)
				if not self.outgoing:
					self.response[0].close()
					self.response.pop(0)
		if self.outgoing:
			try:
				sent = self.client_socket.send(self.outgoing)
			except socket.error:
				self.close()
				return
			if sent < len(self.outgoing):
				# partial send : keep the rest for the next call
				self.outgoing = self.outgoing[sent:]
			else:
				self.outgoing = emptybuff()
		if len(self.outgoing)==0 and not self.response:
			if self.close_when_done:
				self.close() # close socket
			else:
				# reset for next request
				self.writable = False
				self.incoming = emptybuff()

	def close(self):
		"""Close pending file objects and the client socket and unregister
		this handler ; calling it again on a closed handler is harmless"""
		while self.response:
			item = self.response.pop(0)
			if not isinstance(item,buff_type):
				item.close()
		# pop() instead of del : the handler may already be unregistered
		client_handlers.pop(self.client_socket,None)
		self.client_socket.close()

# ============================================================================
# Main loop, calling the select() function on the sockets to see if new 
# clients are trying to connect, if some clients have sent data and if those
# for which the response is complete are ready to receive it
# For each event, call the appropriate method of the server or of the instance
# of ClientHandler managing the dialog with the client : handle_read() or 
# handle_write()
# ============================================================================
def loop(server,handler,timeout=30):
	"""Run the select() loop forever

	server  -- object whose .socket attribute is the listening socket
	handler -- class instantiated as handler(server,client_socket,
	           client_address) for every accepted connection
	timeout -- select() timeout in seconds

	A handler may close (and unregister) its socket while an event for the
	same socket is still pending in the lists returned by select(), so every
	lookup in client_handlers tolerates a missing entry.
	"""
	while True:
		k = list(client_handlers.keys())
		# w = sockets to which there is something to send
		# we must test if we can send data
		w = [ cl for cl in k if client_handlers[cl].writable ]
		# the heart of the program ! "r" will have the sockets that have sent
		# data, and the server socket if a new client has tried to connect
		r,w,e = select.select(k+[server.socket],w,k,timeout)
		for e_socket in e:
			client = client_handlers.get(e_socket)
			if client is not None:
				client.handle_error()
		for r_socket in r:
			if r_socket is server.socket:
				# server socket readable means a new connection request
				try:
					client_socket,client_address = server.socket.accept()
					client_handlers[client_socket] = handler(server,client_socket,client_address)
				except socket.error:
					pass
			else:
				# the client connected on r_socket has sent something
				client = client_handlers.get(r_socket)
				if client is not None: # not closed by handle_error above
					client.handle_read()
		for w_socket in w:
			client = client_handlers.get(w_socket)
			if client is not None: # skip sockets closed during this round
				client.handle_write()


# =============================================================
# An implementation of the HTTP protocol, supporting persistent
# connections and CGI
# =============================================================

class HTTP(ClientHandler):
	"""HTTP/1.0 and HTTP/1.1 handler serving static files and CGI scripts
	located below HTTP.root

	Supports the GET, POST and HEAD methods, persistent connections when an
	HTTP/1.1 client sends "Connection: keep-alive", and redirects a
	directory to the first readable file listed in index_files.
	Malformed requests are answered with "400" instead of raising an
	exception out of the select loop.
	"""
	# parameters to override if necessary
	root = os.getcwd()				# the directory to serve files from
	index_files = ['index.cgi','index.html']	# index files for directories
	logging = True					# print logging info for each request ?
	blocksize = 2 << 16				# size of blocks to read from files and send

	def __init__(self, server, client_socket, client_address):
		super(HTTP,self).__init__(server, client_socket, client_address)
		self.method = None       # request method, None marks a bad request
		self.protocol = None     # "HTTP/1.0" or "HTTP/1.1"
		self.postbody = None     # body of a POST request
		self.requestline = None  # first line of the request
		self.headers = None      # dict : lower-cased header name -> value
		self.url = None          # request target as sent by the client
		self.file_name = None    # file system path the url maps to
		self.path = None         # path component of the url
		self.rest = None         # (params,query,fragment) of the url
		self.mngt_method = None  # method producing the response for managed urls

	def _bad_request(self):
		"""Mark the current request as malformed (make_response() answers
		with 400) and return True so that the response is built at once"""
		self.method = None
		self.protocol = "HTTP/1.1"
		self.postbody = None
		return True

	def _wire(self,text):
		"""Convert a native string into the type sent on the socket :
		latin-1 encoded bytes on Python 3, unchanged on Python 2"""
		if sys.version>='3':
			return text.encode('latin-1','replace')
		return text

	def _cgi_text(self,raw):
		"""Convert raw CGI output into a native string and prepend a default
		header block when the script did not start with one"""
		if sys.version>='3':
			raw = raw.decode('latin-1')
		if not raw.startswith(('Content-Type:','Status:')):
			raw = "Content-Type: text/plain\r\n\r\n" + raw
		return raw

	def request_complete(self):
		"""In the HTTP protocol, a request is complete if the "end of headers"
		sequence (CRLF CRLF) has been received
		If the request is POST, the request is complete only when the whole
		body announced by Content-Length has been received ; the body is then
		stored in self.postbody
		Return True when a response can be built (also for malformed requests,
		which are flagged by self.method = None), False if more data is needed"""
		term = '\r\n\r\n'
		if sys.version>='3':
			term = term.encode('ascii')
		terminator = self.incoming.find(term)
		if terminator == -1:
			return False
		head = self.incoming[:terminator]
		if sys.version>='3':
			try:
				head = head.decode('ascii')
			except UnicodeDecodeError:
				# non-ASCII request head : reject it, keep a printable request line
				self.requestline = head.decode('ascii','replace').split('\r\n')[0]
				return self._bad_request()
		lines = head.split('\r\n')
		self.requestline = lines[0]
		try:
			self.method,self.url,self.protocol = lines[0].strip().split()
		except ValueError:
			return self._bad_request()
		if self.protocol not in ("HTTP/1.0","HTTP/1.1"):
			return self._bad_request()
		# put request headers in a dictionary
		self.headers = {}
		for line in lines[1:]:
			if ':' not in line:
				return self._bad_request()
			k,v = line.split(':',1)
			self.headers[k.lower().strip()] = v.strip()
		# persistent connection - decided again for every request so that a
		# later request without keep-alive closes the connection
		close_conn = self.headers.get("connection","")
		self.close_when_done = not (self.protocol == "HTTP/1.1" and close_conn.lower() == "keep-alive")
		# parse the url
		try:
			_,_,path,params,query,fragment = urlparse(self.url)
		except ValueError:
			return self._bad_request()
		self.path,self.rest = path,(params,query,fragment)

		if self.method == 'POST':
			# for POST requests, read the request body
			# its length must be specified in the content-length header
			try:
				content_length = int(self.headers.get('content-length',0))
			except ValueError:
				return self._bad_request()
			if content_length<0:
				return self._bad_request()
			body = self.incoming[terminator+4:]
			# request is incomplete if not all message body received
			if len(body)<content_length:
				return False
			self.postbody = body[:content_length]
		else:
			self.postbody = None

		return True

	def make_response(self):
		"""Build the response : a list of buffers or files"""
		try:
			if self.method is None: # bad request
				return self.err_resp(400,'Bad request : %s' %self.requestline)
			if self.method not in ('GET','POST','HEAD'):
				return self.err_resp(501,'Unsupported method (%s)' %self.method)
			file_name = self.file_name = self.translate_path()
			# the resolved path must stay inside the document root
			if not file_name.startswith(HTTP.root+os.path.sep) and not file_name==HTTP.root:
				return self.err_resp(403,'Forbidden')
			if not os.path.exists(file_name):
				return self.err_resp(404,'File not found')
			if self.managed():
				response = self.mngt_method()
			elif not os.access(file_name,os.R_OK):
				return self.err_resp(403,'Forbidden')
			else:
				fstatdata = os.stat(file_name)
				if (fstatdata.st_mode & 0xF000) == 0x4000:	# directory
					for index in self.index_files:
						if os.path.exists(file_name+'/'+index) and os.access(file_name+'/'+index,os.R_OK):
							return self.redirect_resp(index)
				if (fstatdata.st_mode & 0xF000) != 0x8000:	# not a regular file
					return self.err_resp(403,'Forbidden')
				ext = os.path.splitext(file_name)[1]
				c_type = mimetypes.types_map.get(ext,'text/plain')
				size = fstatdata.st_size
				resp_head = "%s 200 Ok\r\n" %self.protocol
				resp_head += "Content-Type: %s\r\n" %c_type
				resp_head += "Content-Length: %s\r\n" %size
				resp_head += '\r\n'
				response = [self._wire(resp_head)]
				if self.method == "HEAD":
					pass # headers only
				elif size > HTTP.blocksize:
					# big file : let handle_write() stream it block by block
					response.append(open(file_name,'rb'))
				else:
					with open(file_name,'rb') as small_file:
						response[0] += small_file.read()
			self.log(200)
			return response
		except Exception:
			return self.err_resp(500,'Internal Server Error')

	def translate_path(self):
		"""Translate URL path into a path in the file system"""
		return os.path.realpath(os.path.join(HTTP.root,*self.path.split('/')))

	def managed(self):
		"""Test if the request can be processed by a specific method
		If so, set self.mngt_method to the method used
		This implementation tests if the url points to a cgi script"""
		if self.is_cgi():
			self.mngt_method = self.run_cgi
			return True
		return False

	def is_cgi(self):
		"""Test if url points to cgi script"""
		return self.path.endswith(".cgi")

	def run_cgi(self):
		"""Run the script self.file_name as a CGI program and return the
		response list
		Output on stderr is sent to the client after the output on stdout.
		The connection is always closed after a CGI response because the
		script may not provide a Content-Length header"""
		if not os.access(self.file_name,os.X_OK):
			return self.err_resp(403,'Forbidden')
		# set CGI environment variables
		e = self.make_cgi_env()
		self.close_when_done = True
		resp_line = self._wire("%s 200 Ok\r\n" %self.protocol)
		if self.method == "HEAD":
			try:
				proc = subprocess.Popen(self.file_name, env=e, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
				cgiout, cgierr = proc.communicate()
				response = self._cgi_text(cgiout + cgierr)
			except Exception:
				response = "Content-Type: text/plain\r\n\r\n" + traceback.format_exc()
			# for HEAD request, don't send message body even if the script
			# returns one (RFC 3875)
			head_lines = []
			for line in response.split('\n'):
				line = line.rstrip('\r') # lines may be CRLF terminated
				if not line:
					break
				head_lines.append(line)
			# terminate the header section with an empty line
			response = '\r\n'.join(head_lines) + '\r\n\r\n'
			return [resp_line + self._wire(response)]
		try:
			if self.postbody is not None:
				# feed the request body to the script and collect its whole output
				proc = subprocess.Popen(self.file_name, env=e, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
				cgiout, cgierr = proc.communicate(self.postbody)
				response = self._cgi_text(cgiout + cgierr)
				return [resp_line + self._wire(response)]
			# no body : check the first output line only and let
			# handle_write() stream the rest straight from the pipes
			proc = subprocess.Popen(self.file_name, env=e, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
			firstline = self._cgi_text(proc.stdout.readline())
			return [resp_line,self._wire(firstline),proc.stdout,proc.stderr]
		except Exception:
			response = "Content-Type: text/plain\r\n\r\n" + traceback.format_exc()
			return [resp_line + self._wire(response)]

	def make_cgi_env(self):
		"""Set CGI environment variables"""
		env = {}
		# os.defpath is used when the server itself runs without PATH
		env['PATH'] = os.environ.get('PATH',os.defpath)
		env['SERVER_SOFTWARE'] = "AsyncServer"
		env['SERVER_NAME'] = "AsyncServer"
		env['GATEWAY_INTERFACE'] = 'CGI/1.1'
		env['DOCUMENT_ROOT'] = HTTP.root
		env['SERVER_PROTOCOL'] = "HTTP/1.1"
		env['SERVER_PORT'] = str(self.server.port)

		env['REQUEST_METHOD'] = self.method
		env['REQUEST_URI'] = self.url
		env['PATH_TRANSLATED'] = self.translate_path()
		env['SCRIPT_NAME'] = self.path
		env['PATH_INFO'] = urlunparse(("","","",self.rest[0],"",""))
		env['QUERY_STRING'] = self.rest[1]
		if not self.host == self.client_address[0]:
			env['REMOTE_HOST'] = self.host
		env['REMOTE_ADDR'] = self.client_address[0]
		env['CONTENT_LENGTH'] = str(self.headers.get('content-length',''))
		# scripts need the media type to decode a request body (RFC 3875)
		if 'content-type' in self.headers:
			env['CONTENT_TYPE'] = str(self.headers['content-type'])
		for k in ['USER_AGENT','COOKIE','ACCEPT','ACCEPT_CHARSET',
			'ACCEPT_ENCODING','ACCEPT_LANGUAGE','CONNECTION']:
			hdr = k.lower().replace("_","-")
			env['HTTP_%s' %k.upper()] = str(self.headers.get(hdr,''))
		return env

	def redirect_resp(self,redirurl):
		"""Return redirect message ; the connection is closed afterwards"""
		# the trailing empty line terminates the header section
		resp_line = "%s 301 Moved Permanently\r\nLocation: %s\r\n\r\n" % (self.protocol,redirurl)
		self.close_when_done = True
		self.log(301)
		return [self._wire(resp_line)]

	def err_resp(self,code,msg):
		"""Return an error message ; the connection is closed afterwards"""
		# the trailing empty line terminates the (empty) header section
		resp_line = "%s %s %s\r\n\r\n" %(self.protocol,code,msg)
		self.close_when_done = True
		self.log(code)
		return [self._wire(resp_line)]

	def log(self,code):
		"""Write a trace of the request on stderr"""
		if HTTP.logging:
			date_str = datetime.datetime.now().strftime('[%d/%b/%Y %H:%M:%S]')
			sys.stderr.write('%s - - %s "%s" %s\n' %(self.host,date_str,self.requestline,code))


def mylock(filename):
	"""Try to take an exclusive, non-blocking flock() on filename

	Return values:
	 0   - lock acquired ; the pid of this process has been written to the
	       file and the descriptor is intentionally left open so that the
	       lock is held until the process exits
	 > 0 - the file is locked by somebody else ; value is the pid read from it
	 -1  - the file can't be opened or locked
	 -2  - the file is locked, but doesn't contain a valid (positive) pid
	"""
	try:
		fd = posix.open(filename,posix.O_RDWR|posix.O_CREAT,438) # 438 = 0o666
	except (IOError,OSError): # posix.open raises OSError on Python 2
		return -1
	try:
		fcntl.flock(fd,fcntl.LOCK_EX|fcntl.LOCK_NB)
	except (IOError,OSError):
		ex = sys.exc_info()[1]
		if ex.errno not in (errno.EAGAIN,errno.EWOULDBLOCK):
			posix.close(fd)
			return -1
		# locked by another process : report the pid stored by the owner
		try:
			pid = int(posix.read(fd,100).strip())
		except ValueError:
			pid = 0
		finally:
			posix.close(fd)
		# zero or a negative number would be mistaken by the caller for
		# "lock acquired" or for an error code
		if pid<=0:
			return -2
		return pid
	posix.ftruncate(fd,0)
	posix.write(fd,("%u" % posix.getpid()).encode('ascii'))
	return 0

def wdlock(fname,runmode,timeout):
	"""Acquire the lockfile fname, or deal with its current owner, according
	to runmode

	runmode -- 1 : start   (fail if the lockfile is already locked)
	           2 : stop    (send SIGTERM to the lock owner and wait for it)
	           3 : test    (only report whether a lock owner exists)
	           other : restart (terminate the lock owner, then take the lock)
	timeout -- number of one-second attempts before giving up

	Return 1 when the lock has been taken and the server should be started,
	0 when the work is done and nothing has to be started, -1 on error.
	"""
	killed = 0
	for i in range(timeout):
		l = mylock(fname)
		if l==0:
			if runmode==2:
				if killed:
					return 0
				else:
					print("can't find process to terminate")
					return -1
			if runmode==3:
				print("mfscgiserv is not running")
				return 0
			print("lockfile created and locked")
			return 1
		elif l<0:
			if l<-1:
				print("lockfile is damaged (can't obtain pid - kill prevoius instance manually)")
			else:
				print("lockfile error")
			return -1
		else:
			if runmode==3:
				print("mfscgiserv pid:%u" % l)
				return 0
			if runmode==1:
				print("can't start: lockfile is already locked by another process")
				return -1
			if killed!=l:
				print("sending SIGTERM to lock owner (pid:%u)" % l)
				try:
					posix.kill(l,signal.SIGTERM)
				except OSError:
					ex = sys.exc_info()[1]
					# ESRCH : the owner exited after its pid was read, the next
					# attempt to take the lock should succeed
					if ex.errno != errno.ESRCH:
						print("can't send SIGTERM to lock owner (pid:%u): %s" % (l,ex.strerror))
						return -1
				killed = l
			if (i%10)==0 and i>0:
				print("about %u seconds passed and lock still exists" % i)
			time.sleep(1)
	print("about %u seconds passed and lockfile is still locked - giving up" % timeout)
	return -1

if __name__=="__main__":
	# Script entry point : parse the options, optionally daemonize, take the
	# lockfile and run the HTTP server loop.
	locktimeout = 60
	daemonize = 1
	verbose = 0
	host = 'any'
	port = 9425
	rootpath="/usr/share/mfscgi"
	datapath="/var/lib/mfs"

	opts,args = getopt.getopt(sys.argv[1:],"hH:P:R:t:fv")
	for opt,val in opts:
		if opt=='-h':
			print("usage: %s [-H bind_host] [-P bind_port] [-R rootpath] [-t locktimeout] [-f [-v]] [start|stop|restart|test]\n" % sys.argv[0])
			print("-H bind_host : local address to listen on (default: any)\n-P bind_port : port to listen on (default: 9425)\n-R rootpath : local path to use as HTTP document root (default: /usr/share/mfscgi)\n-t locktimeout : how long to wait for lockfile (default: 60s)\n-f : run in foreground\n-v : log requests on stderr")
			# os._exit() skips the interpreter cleanup, so flush buffered output first
			sys.stdout.flush()
			os._exit(0)
		elif opt=='-H':
			host = val
		elif opt=='-P':
			port = int(val)
		elif opt=='-R':
			rootpath = val
		elif opt=='-t':
			locktimeout = int(val)
		elif opt=='-f':
			daemonize = 0
		elif opt=='-v':
			verbose = 1

	lockfname = datapath + os.path.sep + '.mfscgiserv.lock'

	# run mode passed to wdlock() : 1 = start, 2 = stop, 3 = test,
	# 0 = restart (also used when no or an unknown action is given)
	try:
		mode = args[0]
		if mode=='start':
			mode = 1
		elif mode=='stop':
			mode = 2
		elif mode=='test':
			mode = 3
		else:
			mode = 0
	except Exception:
		mode = 0

	rootpath = os.path.realpath(rootpath)

	# pipe used by the daemonized process to tell the foreground parent
	# that the startup is finished
	pipefd = posix.pipe()

	def notify_parent():
		# a single byte on the pipe releases the waiting parent
		posix.write(pipefd[1],b'0')

	if (mode==1 or mode==0) and daemonize:
		# daemonize
		try:
			pid = os.fork()
		except OSError:
			e = sys.exc_info()[1]
			raise Exception("fork error: %s [%d]" % (e.strerror, e.errno))
		if pid>0:
			# close our copy of the write end : read() then returns (EOF) also
			# when the child dies without notifying us, instead of blocking forever
			posix.close(pipefd[1])
			posix.read(pipefd[0],1)
			os._exit(0)
		try:
			os.chdir("/")
		except OSError:
			pass
		os.setsid()
		try:
			pid = os.fork()
		except OSError:
			notify_parent()
			e = sys.exc_info()[1]
			raise Exception("fork error: %s [%d]" % (e.strerror, e.errno))
		if pid>0:
			os._exit(0)

	if wdlock(lockfname,mode,locktimeout)==1:

		print("starting simple cgi server (host: %s , port: %u , rootpath: %s)" % (host,port,rootpath))

		if daemonize:
			# messages printed so far may still sit in the stdout buffer ;
			# write them out before the descriptor is replaced by /dev/null
			sys.stdout.flush()
			os.close(0)
			os.close(1)
			os.close(2)
			if os.open("/dev/null",os.O_RDWR)!=0:
				raise Exception("can't open /dev/null as 0 descriptor")
			os.dup2(0,1)
			os.dup2(0,2)

			notify_parent()

		posix.close(pipefd[0])
		posix.close(pipefd[1])

		server = Server(host, port)

		# launch the server on the specified port
		if not daemonize:
			if host!='any':
				print("Asynchronous HTTP server running on %s:%s" % (host,port))
			else:
				print("Asynchronous HTTP server running on port %s" % port)
		if not daemonize and verbose:
			HTTP.logging = True
		else:
			HTTP.logging = False
		HTTP.root = rootpath
		# best effort : switch the effective uid to "mfs", then "nobody"
		try:
			os.seteuid(pwd.getpwnam("mfs").pw_uid)
		except Exception:
			try:
				os.seteuid(pwd.getpwnam("nobody").pw_uid)
			except Exception:
				pass
		# finished CGI children are never waited for
		signal.signal(signal.SIGCHLD, signal.SIG_IGN)
		loop(server,HTTP)

	else:
		notify_parent()
		# os._exit() doesn't flush the messages printed by wdlock()
		sys.stdout.flush()
		os._exit(0)
