1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

use std::convert::TryInto;
use std::io::Read;
use std::path::Path;

use once_cell::sync::Lazy;
use thiserror::Error;

use crate::ir::IRModule;
use crate::python;
use crate::runtime::{map::Map, Function, Module as RtModule, NDArray, String};

/// Errors produced by the compiler entry points in this module.
///
/// Both variants display the message of the wrapped error unchanged and are
/// created automatically by `?` through their `#[from]` conversions.
#[derive(Error, Debug)]
pub enum Error {
    /// An I/O failure, wrapping a `std::io::Error`.
    #[error("{0}")]
    IO(#[from] std::io::Error),
    /// A failure reported by TVM, wrapping a `crate::errors::Error`.
    #[error("{0}")]
    TVM(#[from] crate::errors::Error),
}

/// Lazily resolved handle to the `tvm.relay.build` global function.
///
/// On first access the Python modules `tvm` and `tvm.relay` are imported
/// before the function is looked up.
///
/// # Panics
///
/// Panics on first access if either import fails or the function lookup
/// does not succeed.
static TVM_BUILD: Lazy<Function> = Lazy::new(|| {
    // Import order matters: `tvm` first, then `tvm.relay`.
    for py_module in ["tvm", "tvm.relay"].iter().copied() {
        python::import(py_module).unwrap();
    }
    Function::get("tvm.relay.build").unwrap()
});

/// Invoke the raw `tvm.relay.build` function and convert its result into a
/// runtime module.
///
/// The raw API has the shape
/// `Fn(IRModule, String, String, Map<String, NDArray>, String)`, and the
/// arguments are passed in exactly that order: `module`, `target`,
/// `target_host`, `params`, `module_name`.
///
/// # Errors
///
/// Returns [`Error::TVM`] if the build invocation fails.
///
/// # Panics
///
/// Panics if the value returned by the build function cannot be converted
/// into an `RtModule`.
fn _compile_module(
    module: IRModule,
    target: String,
    target_host: String,
    params: Map<String, NDArray>,
    module_name: String,
) -> Result<RtModule, Error> {
    // Positional arguments for the packed function call; order must match the raw API.
    let build_args = vec![
        (&module).into(),
        (&target).into(),
        (&target_host).into(),
        (&params).into(),
        (&module_name).into(),
    ];
    let built = TVM_BUILD.invoke(build_args)?;
    Ok(built.try_into().unwrap())
}

/// Configuration passed to [`compile_module`] and [`compile_from_disk`].
///
/// All fields are private; a value can be obtained through the `Default`
/// implementation.
#[derive(Debug)]
pub struct CompilerConfig {
    /// Compilation target. When `None`, `compile_module` falls back to `"llvm"`.
    target: Option<String>,
    /// Optional host target.
    /// NOTE(review): presumably corresponds to the `target_host` argument of the build call — confirm.
    target_host: Option<String>,
    /// Parameter map from name to `NDArray`.
    /// NOTE(review): presumably corresponds to the `params` argument of the build call — confirm.
    params: Map<String, NDArray>,
    /// Optional name for the compiled module.
    /// NOTE(review): presumably corresponds to the `module_name` argument of the build call — confirm.
    module_name: Option<String>,
}

impl Default for CompilerConfig {
    /// Build a configuration with every optional field unset and an empty
    /// parameter map.
    fn default() -> Self {
        Self {
            target: None,
            target_host: None,
            params: Map::empty(),
            module_name: None,
        }
    }
}

/// Compile a module from a configuration and IRModule.
///
/// # Arguments
///
/// * `config` - The configuration for the compiler.
/// * `module` - The IRModule to compile.
pub fn compile_module(config: CompilerConfig, module: IRModule) -> Result<RtModule, Error> {
    let target = config.target.unwrap_or("llvm".into());
    _compile_module(
        module,
        target,
        "llvm".into(),
        Map::<String, NDArray>::empty(),
        "default".into(),
    )
}

/// Read a serialized IRModule from disk, compile it, and export the
/// resulting runtime module to disk.
///
/// # Arguments
///
/// * `config` - The configuration for the compiler.
/// * `ir_mod_path` - The path to the serialized IRModule.
/// * `output_rt_mod_path` - The path to the output runtime module.
///
/// # Errors
///
/// Returns an [`Error`] if the input file cannot be opened or read, or if
/// parsing, compiling, or exporting the module fails.
pub fn compile_from_disk<P1, P2>(
    config: CompilerConfig,
    ir_mod_path: P1,
    output_rt_mod_path: P2,
) -> Result<(), Error>
where
    P1: AsRef<Path>,
    P2: AsRef<Path>,
{
    // Load the whole textual module into memory.
    let mut ir_file = std::fs::File::open(ir_mod_path.as_ref())?;
    let mut ir_text = std::string::String::new();
    ir_file.read_to_string(&mut ir_text)?;

    // The text is parsed under the fixed source name "name".
    let parsed = IRModule::parse("name", ir_text)?;
    let compiled = compile_module(config, parsed)?;

    // The output path is passed on in its display (string) form.
    let destination = output_rt_mod_path.as_ref().display().to_string();
    compiled.export_library(destination)?;
    Ok(())
}