#!/usr/bin/env python3
"""apis_generator.py - Generate a C++ interface that automates loading OpenCL.

Usage: apis_generator.py <headerPaths...>

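Every non-typedef CL_API_ENTRY declaration found in the given headers becomes
one CL_MACRO invocation, and the result is printed to stdout. The includer
must #define CL_MACRO before including the output; that definition decides
what each entry expands to.

The generated code looks roughly like this:
------------------------------------------------------------------------

// apis.h

CL_MACRO( returnType, funcname, (fargs...), (callArgs...) )

where fargs are the full parameter declarations (type and name) and callArgs
are the bare parameter names.
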
"""

import os.path
import re
import sys

# Banner printed first in the generated output. It names this script (by file
# basename) as the generator. The leading newline is part of the string.
GENERATED_FILE_WARNING = """
/*
 * This file is generated by {}
 * Do not edit this file directly.
 */""".format(os.path.basename(__file__))

# Preprocessor guard printed after the banner: including the generated list
# without first defining CL_MACRO fails with a clear #error instead of an
# obscure expansion error.
MACRO_GUARD = """
#ifndef CL_MACRO
#error You need to define CL_MACRO before including apis
#endif"""


def include_for_header(header):
  """Return the `#include <CL/...>` directive for the given header file name."""
  template = '#include <CL/{}>'
  return template.format(header)


def extract_license_lines(lines):
  """Return the leading lines of `lines` up to the first comment terminator.

  Lines are collected in order up to and including the first one that
  contains '*/'. If no line contains it, the process exits via sys.exit with
  an error message.
  """
  license_block = []
  for text in lines:
    license_block.append(text)
    if '*/' in text:
      break
  else:
    # Ran off the end without ever seeing '*/'.
    sys.exit("License text didn't terminate")
  return license_block


# Import-time self-checks for extract_license_lines: the returned prefix must
# stop at, and include, the first line containing the comment terminator.
# (Like any assert, these are skipped when Python runs with -O.)
for _license_input, _license_expected in (
    (['/* LICENSE */', 'something'], ['/* LICENSE */']),
    (['/* LICENSE', ' * TEXT */', 'something'], ['/* LICENSE', ' * TEXT */']),
    (['/* LICENSE', ' * TEXT', ' */', 'something'],
     ['/* LICENSE', ' * TEXT', ' */']),
):
  assert extract_license_lines(_license_input) == _license_expected
# Keep the module namespace identical to before the loop.
del _license_input, _license_expected


def parse_arg_strs(str):
  """Split a C parameter list on its top-level commas.

  Commas nested inside parentheses (e.g. inside a function-pointer
  parameter's own argument list) do not split. Pieces keep their surrounding
  whitespace, interior empty pieces are kept, and a trailing empty piece is
  dropped.

  Note: the parameter name shadows the builtin `str`; it is kept unchanged
  for call compatibility.
  """
  pieces = []
  buf = []
  depth = 0
  for ch in str:
    if ch == '(':
      depth += 1
    elif ch == ')':
      depth -= 1
    if ch == ',' and not depth:
      # Top-level separator: close the current piece.
      pieces.append(''.join(buf))
      buf = []
      continue
    buf.append(ch)
  if buf:
    pieces.append(''.join(buf))
  return pieces


def process_type(raw):
  """Normalize one C parameter declaration (type plus name).

  Applies two rewrites, then collapses all whitespace runs to single spaces:
    * "[N] name [M]"  -> "name[N]"  (redundant array suffix before the name)
    * "* name [N]"    -> "*name"    (element count that came from a name comment)
  """
  rewrites = (
      (r'(\[[0-9]*\])\s*(\w+)\s*\[[0-9]*\]', r'\2\1'),
      (r'\*\s*(\w+)\s*\[[0-9]+\]', r'*\1'),
  )
  text = raw
  for pattern, replacement in rewrites:
    text = re.sub(pattern, replacement, text)
  return ' '.join(text.split())


def parse_api(api_signature):
  """Parse one OpenCL API declaration into return type, name and arguments.

  Args:
    api_signature: a complete declaration ending in ';' (possibly joined from
      several header lines), with comment markers already removed.

  Returns:
    dict with keys 'return' (str), 'name' (str) and 'args' (list of dicts
    with 'type', the whitespace-normalized full parameter declaration
    including its name, and 'name', the bare parameter name).

  Raises:
    ValueError: if the declaration does not look like `ret name(args);` or
      a parameter name cannot be located.
  """
  # Drop the storage-class keyword and any CL_* annotation macros
  # (CL_API_ENTRY, CL_API_CALL, CL_API_SUFFIX__..., ...). The word boundaries
  # keep 'extern' intact inside identifiers such as
  # cl_external_semaphore_handle_type_khr.
  cleaned = re.sub(r'\bextern\b', '', api_signature)
  cleaned = re.sub(r'CL_\w+', '', cleaned)

  m = re.match(r'\s*(.*)\s+(\w+)\((.*)\)\s*;', cleaned)
  if m is None:
    # Report via the exception (stderr), not stdout, which carries the
    # generated output.
    raise ValueError('Unable to parse API declaration: {}'.format(cleaned))

  params = m.group(3)
  # Only a list that is exactly "void" means "no parameters"; a list such
  # as "void *ptr, ..." has real parameters.
  if re.fullmatch(r'\s*void\s*', params):
    arg_strs = []
  else:
    arg_strs = parse_arg_strs(params)

  args = []
  for arg_str in arg_strs:
    # The name is the first identifier directly followed by ')' (function
    # pointer), '[' (array) or the end of the declaration.
    nm = re.search(r'(\w+)\s*(\)|\[|$)', arg_str)
    if nm is None:
      raise ValueError('Unable to find parameter name in {!r} of: {}'.format(
          arg_str, cleaned))
    args.append({'type': process_type(arg_str), 'name': nm.group(1)})
  return {'return': m.group(1).strip(), 'name': m.group(2), 'args': args}


def extract_apis(lines):
  """Collect parsed API declarations from stripped header lines.

  A declaration starts on a line containing 'CL_API_ENTRY' but not 'typedef'.
  It runs until a line containing ';', and continuation lines are joined with
  single spaces. Before parsing, comment markers ('/*', '*/') and
  'CL_CALLBACK ' are removed. A declaration still open when the input ends is
  dropped.
  """
  def clean_and_parse(signature):
    stripped = (signature.replace('/*', '')
                .replace('*/', '')
                .replace('CL_CALLBACK ', ''))
    return parse_api(stripped)

  apis = []
  pending = None  # signature being accumulated; None while scanning
  for line in lines:
    if pending is None:
      if 'CL_API_ENTRY' not in line or 'typedef' in line:
        continue
      pending = line
    else:
      pending += ' ' + line
    if ';' in line:
      apis.append(clean_and_parse(pending))
      pending = None
  return apis


def generate_apis(apis):
  """Print the generated header to stdout.

  Emits the generated-file banner and the CL_MACRO guard, each followed by a
  blank line. Then, for each API, emits
  `CL_MACRO( ret, name, (param decls...), (param names...) )` followed by a
  blank line.
  """
  for preamble in (GENERATED_FILE_WARNING, MACRO_GUARD):
    print(preamble)
    print()

  for api in apis:
    params = api['args']
    declarations = ', '.join([param['type'] for param in params])
    call_names = ', '.join([param['name'] for param in params])
    print('CL_MACRO( {}, {}, ({}), ({}) )\n'.format(
        api['return'], api['name'], declarations, call_names))


def main():
  """Print the CL_MACRO list for every header path given on the command line.

  Raises:
    SystemExit: with a usage message when no header paths are given. Also
      raised (via extract_license_lines) when no line of the first header
      contains '*/'.
  """
  headers = sys.argv[1:]
  if not headers:
    # Previously this fell through to headers[0] and an opaque IndexError.
    sys.exit('Usage: apis_generator.py <headerPaths...>')

  apis = []
  for index, header_name in enumerate(headers):
    with open(header_name) as header:
      lines = [line.strip() for line in header]
    if index == 0:
      # NOTE(review): the extracted license text is not emitted anywhere.
      # This call only validates that some line of the first header
      # contains '*/' (it exits otherwise). Confirm whether the license was
      # meant to be copied into the generated output.
      extract_license_lines(lines)
    apis.extend(extract_apis(lines))

  generate_apis(apis)


if __name__ == '__main__':
  main()