#!/www/python/bin/python

"""gen_schema

Search a set of directories for Python modules and classes, and parse
the class docstrings to generate an object schema.  Lots of
project-specific information (eg. directories to search, classes to
exclude, ...) can be supplied via a project description file.
"""

# created 2001/08/08, Greg Ward (from the MEMS Exchange-specific gen_schema)

__revision__ = "$Id: gen_schema,v 1.12 2001/08/21 22:21:20 gward Exp $"

import sys, os, string, re
import getopt
import types
import traceback
from fnmatch import fnmatch
from cPickle import dump

from grouch.schema import ObjectSchema, ClassDefinition
from grouch.util import is_class_object, issubclass
from grouch.script_util import writenow, announce, error, warn


# -- Mid-level workers -------------------------------------------------
# (called by generate_schema())

def find_modules (dirs, base_dir=None, prefix=None, exclude=None):
    """find_modules(dirs : [string],
                    base_dir : string = None,
                    prefix : string = None,
                    exclude : [string] = None)
       -> [(modname : string, filename : string)]

    Searches a list of directories for Python module files (*.py).
    Returns a list of (modname, filename) tuples; "__init__.py" files
    are never reported.  Each directory in 'dirs' contributes its path
    components to the module name; if 'base_dir' is supplied, the
    directories are looked up relative to it (this affects the
    filenames only, not the module names).  If 'prefix' is supplied, it
    is prepended (with a dot interpolated) to each module name.  If
    'exclude' is supplied, it must be a list of fully-qualified module
    names; if a module is found in 'exclude' (after adding prefix), it
    will not be included in the returned list of modules.  Omitting
    'exclude' (or passing None) excludes nothing.

    Raises ValueError if a directory still contains a parent-directory
    component after normalization.  Modules are returned in whatever
    order os.listdir() reports them.

    Example: dirs == ["foo"] and directory "foo" contains "bar.py" and
    "baz.py".  If no prefix or base_dir is supplied, returns
      [("foo.bar", "foo/bar.py"),
       ("foo.baz", "foo/baz.py")]

    If base_dir == "d" and prefix == "p", returns
      [("p.foo.bar", "d/foo/bar.py"),
       ("p.foo.baz", "d/foo/baz.py")]
    """
    if exclude is None:                 # default: nothing is excluded
        exclude = []

    modules = []                        # list of (modname, filename) tuples
    for dirname in dirs:
        components = os.path.normpath(dirname).split(os.sep)
        if os.pardir in components:     # disallow "../foo" after normpath()
            raise ValueError(
                "invalid directory '%s': cannot contain '%s'"
                % (dirname, os.pardir))
        if components == [os.curdir]:   # eg. dirname == "." or ""
            components = []
        if prefix:
            components.insert(0, prefix)

        if base_dir:
            real_dir = os.path.normpath(os.path.join(base_dir, dirname))
        else:
            real_dir = os.path.normpath(dirname)

        for basename in os.listdir(real_dir):
            # Skip files that aren't Python modules
            if not fnmatch(basename, "*.py") or basename == "__init__.py":
                continue

            real_filename = os.path.join(real_dir, basename)
            bare_modname = os.path.splitext(basename)[0]
            modname = ".".join(components + [bare_modname])

            if modname not in exclude:
                modules.append((modname, real_filename))

    return modules


class ClassInfo:
    """
    Record of everything gen_schema needs to know about one class
    before it can go into an object schema.  find_classes() builds a
    list of these for a single source file; find_all_classes() collects
    those lists in a dictionary keyed by module name.

    Instance attributes:
      bare_name : string
        the class name exactly as written in the "class" statement
      full_name : string
        the fully-qualified class name (ie. module name included)
      base_classes : [string]
        base class names as written in the class statement (rewritten
        in place by expand_base_classes())
      docstring : string
        the class docstring as one (probably multi-line) string
    """

    def __init__ (self, bare_name, full_name, base_classes, docstring):
        (self.bare_name, self.full_name) = (bare_name, full_name)
        (self.base_classes, self.docstring) = (base_classes, docstring)

    def __str__ (self):
        # A ClassInfo prints as its fully-qualified class name.
        return self.full_name

    def __repr__ (self):
        details = (self.__class__.__name__, id(self), self)
        return "<%s at %08x: %s>" % details

    def expand_base_classes (self, schema):
        """
        Check each name in self.base_classes against 'schema': a name
        that 'schema' has a class definition for is left alone; a name
        that is an alias for a class is replaced (in place) by the
        aliased class name.  Raises ValueError for a name that is
        neither, or that is an alias for something other than a class.
        """
        bases = self.base_classes       # same list object: updated in place
        for pos in range(len(bases)):
            candidate = bases[pos]
            if schema.get_class_definition(candidate):
                continue                # already a known class name

            target = schema.get_alias(candidate)
            if not target:
                raise ValueError(
                    "%s: invalid base class %r (no such class or alias)"
                    % (self.full_name, candidate))
            if not target.is_plain_instance_type():
                raise ValueError(
                    "%s: invalid base class %r (alias to non-class)"
                    % (self.full_name, candidate))
            bases[pos] = target.klass_name


def find_all_classes (modules, schema, exclude_classes=None):
    """find_all_classes(modules : [(modname : string, filename : string)],
                        schema : ObjectSchema,
                        exclude_classes : [string] = None)
       -> { modname:string : [ClassInfo] }

    Run find_classes() over every module in 'modules' and return a
    dictionary mapping module name to the list of ClassInfo objects
    accepted for that module.  A module in which no classes were found
    gets no entry at all.

    For every accepted class that has no alias yet, 'schema' gains an
    empty ClassDefinition (unless one already exists under the full
    name) and an alias from the bare class name to the full name (eg.
    for class foo.bar.FooBar, the alias maps "FooBar" to
    "foo.bar.FooBar").  A class whose bare name is already an alias
    for something else is reported with error() and left out.

    Once all modules have been scanned, the base class names of every
    accepted class are expanded and recorded in its class definition.
    """
    found_by_module = {}                # modname -> accepted ClassInfo list
    total = 0                           # every class seen, accepted or not

    announce("looking for classes...\n")
    for (modname, filename) in modules:
        candidates = find_classes(filename, modname, exclude_classes)
        total = total + len(candidates)

        accepted = []
        for info in candidates:
            short_name = info.bare_name
            long_name = info.full_name
            existing = schema.get_alias(short_name)

            if not existing:
                # Bare name is free: make sure the class is defined in
                # the schema, then claim the alias for it.
                if schema.get_class_definition(long_name) is None:
                    schema.add_class(ClassDefinition(long_name, schema))
                schema.add_alias(short_name, long_name)
                accepted.append(info)

            elif (existing.is_instance_type() and
                  existing.get_class_name() == long_name):
                # The alias already points at this very class: fine.
                accepted.append(info)

            else:
                # The bare name is taken by something else -- barf!
                error("%s: class name conflict: "
                      "%s is already an alias for %s"
                      % (long_name, short_name, existing))

        # Only modules that contained at least one class get an entry.
        if candidates:
            found_by_module[modname] = accepted

        announce("\nmodule %s:\n" % modname, threshold=2)
        for info in candidates:
            announce("  %s\n" % info, threshold=2)

    # Base classes can only be resolved now, after every class in the
    # application has been seen.
    for accepted in found_by_module.values():
        for info in accepted:
            try:
                info.expand_base_classes(schema)
            except ValueError:
                # Should the class be dropped from the schema entirely
                # when this happens?  For now it just ends up with no
                # base classes recorded.
                error(str(sys.exc_info()[1]))
            else:
                definition = schema.get_class_definition(info.full_name)
                definition.set_bases(info.base_classes)

    announce("found %d classes\n" % total)
    return found_by_module


def get_names (nodes):
    """get_names(nodes : [compiler.ast node]) -> [string]

    Turn a list of AST expression nodes into the names they spell.  A
    Name node contributes its name; a Getattr chain contributes the
    dotted name it was written as.  Nodes of any other kind are
    skipped.
    """
    from compiler import ast

    result = []
    for node in nodes:
        if isinstance(node, ast.Name):
            result.append(node.name)
            continue
        if not isinstance(node, ast.Getattr):
            continue                    # not a plain or dotted name

        # "foo.bar.baz" in source is Getattr(Getattr(Name(foo),bar),baz)
        # in the AST, ie. the outermost node holds the last component.
        # Walk inwards, putting each component at the front.
        parts = []
        inner = node
        while isinstance(inner, ast.Getattr):
            parts.insert(0, inner.attrname)
            inner = inner.expr
        assert isinstance(inner, ast.Name), \
               "expected Name at bottom of Getattr stack"
        parts.insert(0, inner.name)
        result.append(".".join(parts))

    return result
         

def find_classes (filename, modname, exclude=None):
    """find_classes(filename : string,
                    modname : string,
                    exclude : [string] = None)
       -> [ClassInfo]

    Parses the Python source file 'filename' looking for classes
    defined at the top level of the module.  Assumes this source file
    contains the module 'modname'.  Returns a list of ClassInfo
    instances, one for each class in the file whose fully-qualified
    name is not listed in 'exclude' (omitting 'exclude' excludes
    nothing).  If the file cannot be parsed, the problem is reported
    with error() and an empty list is returned.

    Eg. if parsing a file foo.py which (according to 'modname') contains
    a module 'foo', this code:
      class Foo (bar.Bar):
        '''
        foo
        bar
        '''
    results in a ClassInfo instance like:
      bare_name = "Foo"
      full_name = "foo.Foo"
      base_classes = ["bar.Bar"]
      docstring = "\\n    foo\\n    bar\\n    "

    No attempt is made to determine what "bar.Bar" really refers to --
    that would require interpreting the module, and in that case we
    might as well just import the damn thing.  Note that the
    ClassInfo.expand_base_classes() method, called by find_all_classes()
    just before returning, attempts to expand all base class names to
    their true names.
    """
    from compiler import parseFile, ast
    from parser import ParserError

    if exclude is None:                 # default: nothing is excluded
        exclude = []

    try:
        module_ast = parseFile(filename)
    except (ParserError, SyntaxError):
        # The parser may report bad source as SyntaxError as well as
        # ParserError.  error() appears to return rather than exit
        # (other callers follow it with more statements), so bail out
        # here instead of falling through with 'module_ast' unbound.
        error("%s: unable to parse module (try importing it for more details)"
              % filename)
        return []

    klasses = []
    for node in module_ast.node.nodes:
        if not isinstance(node, ast.Class):
            continue

        if modname:
            fullname = modname + "." + node.name
        else:
            fullname = node.name

        if fullname in exclude:
            continue

        base_classes = get_names(node.bases)
        klasses.append(ClassInfo(node.name, fullname,
                                 base_classes, node.doc))

    return klasses


def parse_class_docstrings (modules, module_classes, schema):
    """parse_class_docstrings(modules : [(string, string)],
                              module_classes : {string : [ClassInfo]},
                              schema : ObjectSchema)

    Call parse_docstring() on every class recorded in 'module_classes',
    visiting modules in the order they appear in 'modules' (modules
    with no entry in 'module_classes' are skipped).  A ValueError from
    parse_docstring() is reported with error(); the error strings it
    returns are reported with warn().
    """
    announce("parsing class docstrings...\n")
    for (modname, filename) in modules:
        for info in module_classes.get(modname, []):
            try:
                problems = parse_docstring(info, schema)
            except ValueError:
                error(str(sys.exc_info()[1]))
                announce("failed to parse %s docstring\n" % info, threshold=2)
                continue

            for problem in problems:
                warn(problem)
            announce("parsed %s docstring\n" % info, threshold=2)


# -- Parsing code ------------------------------------------------------
# (parse_class_docstrings() calls parse_docstring(), which uses
#  everything else in this section)

# Raw string, so that "\s" reaches the regex engine as written instead
# of being treated as a (meaningless) string escape.
leading_ws_re = re.compile(r'^(\s*)')

def get_indent_level (s):
    """Return the number of whitespace characters that 's' starts with.

    Any character matched by the regex class \\s counts, not just the
    space character: a leading tab adds one to the result.
    """
    m = leading_ws_re.match(s)          # always matches (possibly empty)
    return len(m.group(1))
    

def clean_docstring (doc, klass_name):
    """clean_docstring(doc : string, klass_name : string) -> [string]

    Split the docstring 'doc' into lines and strip the common leading
    indent from them.  The indent is taken from the first line if that
    line starts with a space; otherwise from the first non-empty line
    after it (and the first line itself is then returned untouched).
    Every other non-empty line also loses anything from a '#' to the
    end of the line, and its trailing whitespace.

    Raises ValueError (mentioning 'klass_name') if a non-empty line is
    shorter than the indent or has non-whitespace within it.
    """
    lines = doc.split("\n")             # never empty, even for doc == ""

    if lines[0] and lines[0][0] == ' ': # first line indented
        ref_line = 0
        start_line = 0
        leading_indent = get_indent_level(lines[0])
    elif len(lines) > 1:                # look for first "real" line
        start_line = 1
        ref_line = 1
        while ref_line < len(lines) and not lines[ref_line]:
            ref_line += 1               # skip over blank lines
        if ref_line < len(lines):
            leading_indent = get_indent_level(lines[ref_line])
        else:
            # Nothing but empty lines after the first one (eg.
            # "Summary.\n"): there is no indent to strip.
            leading_indent = 0
    else:                               # single-line docstring
        ref_line = 0
        start_line = 0
        leading_indent = 0

    for i in range(start_line, len(lines)):
        line = lines[i]
        if not line:                    # skip blanks
            continue
        if len(line) < leading_indent:
            raise ValueError(
                "class %s: inconsistent indent in docstring "
                "(line %d too short)" % (klass_name, i))
        if line[:leading_indent].lstrip() != "":
            # Report the line the indent was actually measured on.
            raise ValueError(
                "class %s: inconsistent indent in docstring "
                "(line %d dedented relative to line %d)"
                % (klass_name, i, ref_line))

        line = re.sub(r'#.*', '', line)
        lines[i] = line[leading_indent:].rstrip()

    return lines

def find_attrs (lines, klass_name):
    """find_attrs(lines : [string], klass_name : string) -> int | None

    Look for the first line in 'lines' that starts with "Instance
    attributes:".  Returns the index of the line following it, or None
    if that header line ends with "none".  Raises ValueError
    (mentioning 'klass_name') if no line starts with the header.
    """
    next_index = 0
    for line in lines:
        next_index += 1                 # index of the line after 'line'
        if line.startswith("Instance attributes:"):
            if line.endswith("none"):
                return None
            return next_index

    raise ValueError(
        "class %s: no \"Instance attributes:\" line in docstring"
        % klass_name)


# A simple identifier, a dotted sequence of identifiers, an attribute
# line of the form "name : typespec", and a named container element of
# the form "name:type" (no whitespace around the colon).
_name_pat = r'[a-zA-Z_][a-zA-Z0-9_]*'
_dotted_name_pat = r'%s(?:\.%s)*' % (_name_pat, _name_pat)
_attr_line_re = re.compile(r'\s*(%s)\s*:\s*(.*)' % _name_pat)
_element_name_re = re.compile(r'(%s):(%s)' % (_name_pat, _dotted_name_pat))

def massage_typespec (typespec):
    """massage_typespec(typespec : string) -> string

    Rewrite a type specification as found in a docstring so that it
    can be parsed as a ValueType.  Luckily, the docstring type
    specification language is pretty similar to the ValueType type
    specification language; the differences are:
      - container elements can have names as well as types,
        eg.  { key:keytype : value:valuetype }
        or   (val1:type1, val2:type2)
      - plain container types can include the name of the
        container, eg.
          dictionary {string : int}
      - typespecs can be trailed by default value description, eg.
          foo : int = 37
          bar : [string] = ["hello"]

    Note that for named container elements, whitespace matters!
    {key:keytype : value:valuetype} is *not* the same as
    {key : keytype : value : valuetype} -- the latter is illegal.

    An empty typespec (or one that is nothing but a default value)
    comes back as the empty string.
    """
    # Deal with the second exception first, since it's easiest: drop a
    # leading container name.
    words = typespec.split(None, 1)
    remainder = None
    if len(words) > 1:
        remainder = words[1]
        if words[0] in ('list', 'tuple', 'dictionary'):
            typespec = remainder

    # Strip anything that looks like a default value, ie. " = ..."
    typespec = re.sub(r'\s*=.*', '', typespec)

    # Deal with the named-elements thing.  This is a kludge, but it
    # should work given the tight syntax constraints on named elements.
    # (typespec may be empty by now, so don't index it blindly.)
    if ((typespec and typespec[0] in "[({") or
        (remainder and remainder[0] in "[({")):
        typespec = _element_name_re.sub(r'\2', typespec)

    return typespec


def parse_docstring (klass, schema):
    """parse_docstring(klass : ClassInfo, schema : ObjectSchema)
       -> errors : [string]

    Read the attribute declarations out of the docstring of 'klass' and
    add them to the class definition that 'schema' already holds for
    it.  Returns the list of error messages that should be shown to
    the user (empty if all went well).  Raises ValueError if the
    docstring is a lost cause: the class was already processed, has no
    docstring, or the docstring cannot be cleaned up or has no
    "Instance attributes:" line.
    """
    klass_name = klass.full_name        # for error messages

    klass_def = schema.get_class_definition(klass_name)
    if klass_def.attrs or klass_def.all_attrs:
        raise ValueError("class %s: already seen" % klass_name)

    doc = klass.docstring
    if not doc:
        raise ValueError("class %s: no docstring" % klass_name)
    lines = clean_docstring(doc, klass_name)

    errors = []

    # 'lines' is now consistently indented.  Attribute declarations
    # start right after the "Instance attributes:" line and run until
    # the text returns to zero indent.
    pos = find_attrs(lines, klass_name)
    if pos is None:                     # class declared to have no
        return errors                   # instance attributes

    num_lines = len(lines)
    while pos < num_lines:
        line = lines[pos]
        if not line:                    # skip blanks
            pos += 1
            continue
        indent = get_indent_level(line)
        if indent == 0:                 # out of the attribute list
            break

        match = _attr_line_re.match(line)
        if not match:
            errors.append("class %s: couldn't parse line %d of docstring: %s" %
                          (klass_name, pos, repr(line)))
        else:
            name = match.group(1)
            typespec = massage_typespec(match.group(2))
            try:
                attr_type = schema.parse_type(typespec)
            except ValueError:
                errors.append("class %s, attribute %s: %s" %
                              (klass_name, name, sys.exc_info()[1]))
            else:
                if attr_type.is_any_type():
                    attr_type.set_allow_any_instance(0)
                klass_def.add_attribute(name, attr_type)

        # Step over the attribute's description, ie. every following
        # line indented more deeply than the declaration itself.
        pos += 1
        while pos < num_lines and get_indent_level(lines[pos]) > indent:
            pos += 1

    # XXX ClassDefinition should expose len() of its attrs list
    if len(klass_def.attrs) == 0:
        errors.append("class %s: no attributes successfully parsed" %
                      klass_def.name)

    return errors

# parse_docstring ()


# -- High-level workers ------------------------------------------------
# (called from main())

def generate_schema (project, base_dir):
    """generate_schema(project : ProjectDescription, base_dir : string)
       -> ObjectSchema

    Build an object schema for 'project': seed it with the atomic
    types, forward-declared classes and type aliases listed in the
    project description, then find the project's modules, find the
    classes in them, parse the class docstrings, run the project's
    post-parse hook (if any) and finish every class definition.
    """
    schema = ObjectSchema()

    # First the things that cannot be discovered by searching for and
    # parsing *.py files.  An atomic type is given either bare or as a
    # 2-tuple of arguments for add_atomic_type().
    for spec in project.atomic_types:
        is_pair = type(spec) is types.TupleType and len(spec) == 2
        if is_pair:
            (first, second) = spec
            schema.add_atomic_type(first, second)
        else:
            schema.add_atomic_type(spec)

    for name in project.forward_classes:
        schema.add_class(ClassDefinition(name, schema))

    for (alias, expansion) in project.type_aliases:
        schema.add_alias(alias, expansion)

    # The meat of the schema: class definitions for classes in whatever
    # *.py files we can find in project.dirs.
    modules = []
    if project.dirs:
        announce("searching for modules...")
        modules = find_modules(project.dirs,
                               base_dir=base_dir,
                               prefix=project.prefix,
                               exclude=project.exclude_modules)
        announce("found %d modules\n" % len(modules))

    project.add_extra_modules(modules, base_dir)

    module_classes = find_all_classes(modules, schema, project.exclude_classes)
    parse_class_docstrings(modules, module_classes, schema)

    hook = project.post_parse_hook
    if hook:
        hook(schema)

    # Finish up all class definitions -- ie. look at the inheritance tree
    # and gather up the list of all attributes that should be in instances
    # of a class, including those inherited from superclasses.
    for klass_name in schema.get_class_names():
        schema.get_class_definition(klass_name).finish_definition()

    return schema


def write_schema (schema, text_filename, pickle_filename):
    """write_schema(schema : ObjectSchema,
                    text_filename : string,
                    pickle_filename : string)

    Write 'schema' to a text file (its aliases followed by each class
    definition) and/or to a pickle file.  Either filename may be false
    (eg. None) to skip that output.
    """
    if text_filename:
        announce("writing object schema to %s..." % text_filename)
        schema_file = open(text_filename, "w")
        try:                            # close the file even on failure
            schema.write_aliases(schema_file)
            schema_file.write("\n\n")

            for klass_name in schema.get_class_names():
                klass_def = schema.get_class_definition(klass_name)
                klass_def.write(schema_file)
                schema_file.write("\n")
        finally:
            schema_file.close()
        announce("\n")

    if pickle_filename:
        announce("pickling object schema to %s..." % pickle_filename)
        # Protocol 1 is a binary pickle format, so the file has to be
        # opened in binary mode.
        schema_file = open(pickle_filename, "wb")
        try:
            dump(schema, schema_file, 1)
        finally:
            schema_file.close()
        announce("\n")


# -- Main program ------------------------------------------------------

class ProjectDescription:
    """
    Instance attributes:
      atomic_types : [(any, string) | any]
      forward_classes : [string]
      type_aliases : [(alias:string, alias_expansion:string)]
      prefix : string
      dirs : [string]
      extra_modules : [string | (string, string)]
      exclude_modules : [string]
      exclude_classes : [string]

      post_parse_hook : function
    """

    # The docstring above is not just documentation: check_types()
    # feeds it to parse_docstring() to build the schema that a project
    # description is checked against.  Edit it with care.

    def __init__ (self):
        # Defaults describe an empty project; read() overrides them.
        self.atomic_types = []
        self.forward_classes = []
        self.type_aliases = []
        self.prefix = None
        self.dirs = []
        self.extra_modules = []
        self.exclude_modules = []
        self.exclude_classes = []

        self.post_parse_hook = None

    def read (self, filename):
        """read(filename : string)

        Execute the project description file 'filename' as Python code
        and copy every name it defines that is also an attribute of
        this object onto this object, then type-check the result with
        check_types().  Raises ValueError if 'filename' is not an
        existing file.  If executing the file fails, the exception is
        reported on stderr and the program exits with status 1.
        """
        if not os.path.isfile(filename):
            raise ValueError("no such file: %s" % filename)

        data = {}
        try:
            execfile(filename, data)
        except (KeyboardInterrupt, SystemExit):
            # An interrupt or a deliberate exit is not an error in the
            # project file: let it through.
            raise
        except:
            (t, v) = sys.exc_info()[:2]
            error("error in %s:" % filename)
            exc = "".join(traceback.format_exception_only(t, v))
            sys.stderr.write(exc)
            sys.exit(1)

        # NOTE(review): hasattr() is also true for method names, so a
        # project file that binds eg. 'read' would replace the method
        # on this instance -- confirm whether that should be prevented.
        for (name, value) in data.items():
            if hasattr(self, name):
                setattr(self, name, value)

        self.check_types(filename)

    def check_types (self, filename):
        """check_types(filename : string)

        Type-check this object's attributes against the declarations in
        the class docstring.  If there are type errors, they are
        written to stderr (with 'filename' named as their source) and
        the program exits with status 1.
        """
        from grouch.context import TypecheckContext

        schema = ObjectSchema()
        schema.add_atomic_type(lambda: None) # add 'function' type
        cdef = ClassDefinition(self.__class__.__name__, schema)
        schema.add_class(cdef)
        klass = ClassInfo("ProjectDescription",
                          "ProjectDescription",
                          [], self.__class__.__doc__)
        errors = parse_docstring(klass, schema)
        assert not errors, "errors in my own docstring!"
        cdef.finish_definition()
        context = TypecheckContext(report_errors=0)
        schema.check_value(self, context)
        if context.num_errors() > 0:
            error("type errors in %s:" % filename)
            context.write_errors(sys.stderr)
            sys.exit(1)

    def add_extra_modules (self, modules, base_dir):
        """add_extra_modules(modules : [(string, string)],
                             base_dir : string)

        Appends a (modname, filename) tuple to 'modules' for each
        module listed in self.extra_modules.  A bit tricky because
        extra_modules might not include filenames, in which case we
        have to go out and find the files.  Raises ValueError if the
        file for a module cannot be found, and TypeError if an entry
        is neither a string nor a 2-tuple.
        """

        # extra_modules is a list of any of the following
        #   string
        #     fully-qualified module name -- we'll have to search
        #     sys.path to find the file
        #   (modname : string, filename : string)
        #     module name with filename; we have to make sure the
        #     file exists, and try it relative to base_dir if not

        # yow! this is hairy and undertested!

        for module in self.extra_modules:
            if type(module) is types.StringType:
                modname = module
                comps = module.split(".")
                tail = os.path.join(*comps) + ".py"
                for path_dir in sys.path:
                    filename = os.path.join(path_dir, tail)
                    if os.path.exists(filename):
                        break
                else:
                    raise ValueError(
                        "no such module %s: %s not found in sys.path"
                        % (module, tail))

            elif type(module) is types.TupleType and len(module) == 2:
                (modname, filename) = module
                if os.path.exists(filename):
                    pass                # found it, good
                elif base_dir:          # we have a second chance
                    filename2 = os.path.join(base_dir, filename)
                    if os.path.exists(filename2):
                        filename = filename2
                    else:
                        raise ValueError(
                            "no such module %s: neither %s nor %s exist"
                            % (modname, filename, filename2))
                else:
                    raise ValueError(
                        "no such module %s: %s does not exist"
                        % (modname, filename))
            else:
                raise TypeError("bad extra_modules")

            assert os.path.exists(filename)
            modules.append((modname, filename))
                        

def main ():
    """Parse the command line, generate the schema and write it out.

    Exits with a usage message if the command line cannot be parsed
    or has stray positional arguments.
    """
    # NOTE(review): the original comment here was cut off ("global
    # because announce(") -- presumably announce() is meant to consult
    # this setting; confirm, since announce() is imported from
    # grouch.script_util rather than defined in this file.
    global VERBOSITY

    prog = os.path.basename(sys.argv[0])
    args = sys.argv[1:]

    usage = """\
usage: %s [options]
options:
  -v                        verbose: run noisily (repeat for more noise)
  -q                        run quietly
  -p FILE, --project=FILE   read project description info from FILE
  -o FILE, --output=FILE    output file: write pickled schema to FILE
                            [default: schema.pkl]
  -t FILE, --text=FILE      text output: write human-readable schema to FILE
                            [default: none]
  -d DIR, --base-dir=DIR    interpret DIRS in project description file
                            relative to DIR
""" % prog

    try:
        (opts, args) = getopt.getopt(args, "p:o:t:d:vq",
                                     ["project=",
                                      "output=",
                                      "text=",
                                      "base-dir=",])
    except getopt.error:
        sys.exit(usage + str(sys.exc_info()[1]))

    project = ProjectDescription()
    VERBOSITY = 1
    text_filename = None
    pickle_filename = "schema.pkl"
    base_dir = None
    for (opt, val) in opts:
        # getopt reports long options with their leading "--".
        if opt in ("-p", "--project"):
            project.read(val)
        elif opt in ("-o", "--output"):
            pickle_filename = val
        elif opt in ("-t", "--text"):
            text_filename = val
        elif opt in ("-d", "--base-dir"):
            base_dir = val
        elif opt == "-v":
            VERBOSITY += 1
        elif opt == "-q":
            VERBOSITY = 0

    if len(args) != 0:
        sys.exit(usage + "error: too many arguments")

    # Find modules in project.dirs, classes in those modules, and parse
    # the docstring of each class.  The result is an ObjectSchema, a
    # collection of atomic types, type aliases, and class definitions.
    schema = generate_schema(project, base_dir)

    # Write the schema out to either or both of a text and pickle file.
    write_schema(schema, text_filename, pickle_filename)

# main ()


if __name__ == "__main__":
    # Run as a script.  An interrupt from the keyboard ends the program
    # with the message "interrupted" rather than a traceback.
    try:
        main()
    except KeyboardInterrupt:
        raise SystemExit("interrupted")
