
I am working on a static multidimensional array contraction framework, and I have encountered a problem that is somewhat difficult to explain, but I will try my best. Suppose we have an N-dimensional array class

template<typename T, int ... dims>
class Array {};

which could be instantiated as

Array<double> scalar;
Array<double,4> vector_of_4s;
Array<float,2,3> matrix_of_2_by_3;
// and so on

Now we have another class called Indices

template<int ... Idx>
struct Indices {};

Now I have a function, contraction, whose signature should look like the following:

template<typename T, int ... Dims, int ... Idx, 
typename std::enable_if<sizeof...(Dims)==sizeof...(Idx),bool>::type=0>
Array<T,apply_to_dims<Dims...,do_contract<Idx...>>> 
contraction(const Indices<Idx...> &idx, const Array<T,Dims...> &a);

I may not have gotten the syntax right here, but essentially I want the dimensions of the returned Array to be determined by the entries of Indices. Here are some examples of what a contraction should do. Note that, in this context, contraction means removing the dimensions whose corresponding parameters in the index list are equal.

auto arr1 = contraction(Indices<0,0>{}, Array<double,3,3>{});
// arr1 is Array<double> as both indices contract, 0==0

auto arr2 = contraction(Indices<0,1>{}, Array<double,3,3>{});
// arr2 is Array<double,3,3> as no contraction happens here, 0!=1

auto arr3 = contraction(Indices<0,1,0>{}, Array<double,3,4,3>{});
// arr3 is Array<double,4> as the 1st and 3rd indices contract, 0==0

auto arr4 = contraction(Indices<0,1,0,7,7,2>{}, Array<double,3,4,3,5,5,6>{});
// arr4 is Array<double,4,6> as the (1st and 3rd, 0==0) and (4th and 5th, 7==7) indices contract

auto arr5 = contraction(Indices<10,10,2,3>{}, Array<double,5,6,4,4>{});
// should not compile, as a contraction between the 1st and 2nd arguments
// is requested but the dimensions don't match, 5!=6

// The actual values in Indices do not matter, as long as
// equal values identify contractions. They are typically enumerators I, J, K, ...

So essentially, given Idx... and Dims..., which must have the same length, find which values in Idx... are equal, get the positions at which they occur, and remove the entries at those positions from the Dims... pack. This is essentially the tensor contraction rule.

Rules of array contraction:

  1. The number of indices must equal the rank of the array, i.e. sizeof...(Idx)==sizeof...(Dims).
  2. There is a one-to-one correspondence between Idx and Dims, i.e. if we have Indices<0,1,2> and Array<double,4,5,6>, 0 maps to 4, 1 maps to 5 and 2 maps to 6.
  3. Equal values in Idx denote a contraction, meaning the corresponding dimensions in Dims vanish. For example, with Indices<0,0,3> and Array<double,4,4,6>, 0==0, so the dimensions these values map to (4 and 4) both vanish and the resulting array is Array<double,6>.
  4. If Idx has equal values but the corresponding Dims don't match, a compile-time error should be triggered. For instance, Indices<0,0,3> with Array<double,4,5,6> is not possible as 4!=5; similarly, Indices<0,1,0> is not possible as 4!=6.
  5. Consequently, no contraction is possible for an array whose dimensions are all different; for instance, Array<double,4,5,6> cannot be contracted in any way.
  6. Multiple pairs, triplets, quadruplets and so on are allowed in Idx as long as the corresponding Dims also match; for instance, Indices<0,0,0,0,1,1,4,3,3,7,7,7> contracts an Array<double,2,2,2,2,3,3,6,2,2,3,3,3> to an Array<double,6>.
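As a plain reference for the rules above, here is a runtime sketch (not the compile-time solution being asked for) that applies rules 1 to 4 to ordinary vectors. The function name contracted_dims is hypothetical, and throwing std::invalid_argument stands in for the compile-time error; a metaprogramming version could be checked against it.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical runtime reference for the contraction rules: returns the
// surviving dimensions, or throws std::invalid_argument if a rule is broken.
std::vector<int> contracted_dims(const std::vector<int>& idx,
                                 const std::vector<int>& dims) {
    if (idx.size() != dims.size())  // rule 1: one index per dimension
        throw std::invalid_argument("rank mismatch");
    std::vector<int> result;
    for (std::size_t i = 0; i < idx.size(); ++i) {
        bool repeated = false;
        for (std::size_t j = 0; j < idx.size(); ++j) {
            if (j == i || idx[j] != idx[i]) continue;
            repeated = true;         // rule 3: idx[i] occurs elsewhere, so it contracts
            if (dims[j] != dims[i])  // rule 4: contracted extents must agree
                throw std::invalid_argument("dimension mismatch");
        }
        if (!repeated) result.push_back(dims[i]);  // unique index: keep its dimension
    }
    return result;
}
```

Each position is kept only when its index occurs nowhere else; otherwise every position sharing that index must carry the same extent. The all-pairs check is quadratic, which does not matter at the ranks involved here.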

My knowledge of metaprogramming does not go far enough to achieve this functionality, but I hope I have made the intent clear enough for someone to guide me in the right direction.

romeric
  • I cannot figure out what your rule is for contraction. Given `Idx...` and `Dims...`, what should be the output dimensions? Can you provide a set of rules, rather than a set of examples? – md5i May 17 '16 at 13:05
  • Essentially, given `Idx...` and `Dims...` which should both be of equal size, check which values in `Idx...` are equal, get the positions at which they occur and remove the corresponding entries in `Dims...`. – romeric May 17 '16 at 13:08
  • @romeric - can you have only pairs of equal values in `Idx`, or also triplets, etc.? In the case of triplets, what's the rule? – max66 May 17 '16 at 14:31
  • Potentially, you can have any number of equal values; for instance, `contraction(Indices<0,0,1,1,2,2,3>, Array)` would give `Array` as (1st and 2nd, 0==0), (3rd and 4th, 1==1) and (5th and 6th, 2==2) would all contract and vanish. – romeric May 17 '16 at 14:36
  • Can I have Indices<0,0,0,0,1,1,1,2,2,3> ? – Shangtong Zhang May 17 '16 at 14:38
  • Yes, that would collapse to `3`, as every other value appears more than once in the index list. – romeric May 17 '16 at 14:39
  • @romeric - and what about `Indices<0, 0, 0, 5>` with `Array` ? I mean, when you have a triplet in `Indices` and only two of the corresponding dimensions in `Array<>` are equal? – max66 May 17 '16 at 14:47
  • That should trigger a compile time error as the corresponding dimensions and the equal indices don't match. – romeric May 17 '16 at 14:49
  • @romeric - do you need a C++11 solution, or is a C++14 solution good enough? – max66 May 17 '16 at 14:59
  • @max66 Ideally, C++11. – romeric May 17 '16 at 15:00

2 Answers


A bunch of constexpr functions that do the actual checking:

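#include <cstddef>      // std::size_t
#include <type_traits>  // std::conditional, std::enable_if
#include <utility>      // std::index_sequence (C++14)
using std::size_t;

// is ind[i] unique in ind?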
template<size_t N>
constexpr bool is_uniq(const int (&ind)[N], size_t i, size_t cur = 0){
    return cur == N ? true : 
           (cur == i || ind[cur] != ind[i]) ? is_uniq(ind, i, cur + 1) : false;
}

// For every i where ind[i] == index, is dim[i] == dimension?
template<size_t N>
constexpr bool check_all_eq(int index, int dimension,
                            const int (&ind)[N], const int (&dim)[N], size_t cur = 0) {
    return cur == N ? true :
           (ind[cur] != index || dim[cur] == dimension) ? 
                check_all_eq(index, dimension, ind, dim, cur + 1) : false;
}

// if position i should be contracted away, return -1, otherwise return dim[i].
// triggers a compile-time error when used in a constant expression on mismatch.
template<size_t N>
constexpr int calc(size_t i, const int (&ind)[N], const int (&dim)[N]){
    return is_uniq(ind, i) ? dim[i] :
           check_all_eq(ind[i], dim[i], ind, dim) ? -1 : throw "dimension mismatch";
}
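The error path in calc relies on a throw-expression in one arm of the conditional operator, since a C++11 constexpr function body may only contain a single return statement. A minimal isolated illustration of that idiom, using a hypothetical require_equal helper:

```cpp
#include <stdexcept>

// Hypothetical helper isolating the idiom used by calc: the C++11 constexpr
// body is one return statement, so the error path is a throw-expression
// in one arm of the conditional operator.
constexpr int require_equal(int a, int b) {
    return a == b ? a : throw std::logic_error("dimension mismatch");
}

// OK: the throw arm is not evaluated, so this is a constant expression.
static_assert(require_equal(3, 3) == 3, "equal extents");

// Uncommenting this fails to compile: evaluating a throw-expression is
// not allowed in a constant expression.
// static_assert(require_equal(4, 5) == 4, "mismatched extents");
```

At run time the same call just throws, so a mismatch only becomes a compile error when the result is forced into a constant expression; the answer does that by evaluating calc inside the static constexpr initializer of result.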

Now we need a way to get rid of the -1s:

template<class Ind, class... Inds>
struct concat { using type = Ind; };
template<int... I1, int... I2, class... Inds>
struct concat<Indices<I1...>, Indices<I2...>, Inds...>
    :  concat<Indices<I1..., I2...>, Inds...> {};

// filter out all instances of I from Is...,
// return the rest as an Indices    
template<int I, int... Is>
struct filter
    :  concat<typename std::conditional<Is == I, Indices<>, Indices<Is>>::type...> {};
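As written, filter with an empty Is... (which a rank-0 array would produce) instantiates concat<>, which the primary concat template above does not accept because it requires at least one argument. A sketch of a variant with an extra base case, assuming the question's int-based Indices; it only makes filter itself total, and the zero-length constexpr arrays in contraction_impl would still need separate handling for that case:

```cpp
#include <type_traits>

template <int... Idx> struct Indices {};  // as in the question

// Variant of concat with an empty base case: concat<> yields Indices<>.
template <class... Inds> struct concat { using type = Indices<>; };
// A single Indices is already the result.
template <class Ind> struct concat<Ind> { using type = Ind; };
// Merge the first two Indices and recurse on the rest.
template <int... I1, int... I2, class... Inds>
struct concat<Indices<I1...>, Indices<I2...>, Inds...>
    : concat<Indices<I1..., I2...>, Inds...> {};

// Each Is becomes Indices<> (dropped) or Indices<Is> (kept);
// concat then joins the pieces in their original order.
template <int I, int... Is>
struct filter
    : concat<typename std::conditional<Is == I, Indices<>, Indices<Is>>::type...> {};

static_assert(std::is_same<filter<-1, -1, 4, -1, 6>::type, Indices<4, 6>>::value, "");
static_assert(std::is_same<filter<-1>::type, Indices<>>::value, "");
```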

Use them:

template<class Ind, class Arr, class Seq>
struct contraction_impl;

template<class T, int... Ind, int... Dim, size_t... Seq>
struct contraction_impl<Indices<Ind...>, Array<T, Dim...>, std::index_sequence<Seq...>>{
    static constexpr int ind[] = { Ind... };
    static constexpr int dim[] = { Dim... };
    static constexpr int result[] = {calc(Seq, ind, dim)...};

    template<int... Dims>
    static auto unpack_helper(Indices<Dims...>) -> Array<T, Dims...>;

    using type = decltype(unpack_helper(typename filter<-1,  result[Seq]...>::type{}));
};


template<class T, int ... Dims, int ... Idx, 
typename std::enable_if<sizeof...(Dims)==sizeof...(Idx),bool>::type=0>
typename contraction_impl<Indices<Idx...>, Array<T,Dims...>, 
                          std::make_index_sequence<sizeof...(Dims)>>::type
contraction(const Indices<Idx...> &idx, const Array<T,Dims...> &a);

Everything except std::make_index_sequence and std::index_sequence is C++11. You can find a gazillion C++11 implementations of them on SO.
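A minimal stand-in for std::index_sequence and std::make_index_sequence under -std=c++11 could look like this. The names index_seq and make_index_seq are hypothetical; this linear-recursion version is fine for small ranks but instantiates one class per element.

```cpp
#include <cstddef>
#include <type_traits>

// Hypothetical C++11 stand-in for std::index_sequence.
template <std::size_t... Is>
struct index_seq {};

// make_index_seq<N>::type is index_seq<0, 1, ..., N-1>. Each step prepends
// N-1 to the accumulated pack, so recursion depth is linear in N.
template <std::size_t N, std::size_t... Is>
struct make_index_seq : make_index_seq<N - 1, N - 1, Is...> {};

// Base case: once N reaches 0, the accumulated pack is the sequence.
template <std::size_t... Is>
struct make_index_seq<0, Is...> {
    using type = index_seq<Is...>;
};

static_assert(std::is_same<make_index_seq<3>::type, index_seq<0, 1, 2>>::value, "");
static_assert(std::is_same<make_index_seq<0>::type, index_seq<>>::value, "");
```

With it, the partial specialization of contraction_impl would match index_seq<Seq...> instead of std::index_sequence<Seq...>, and contraction would pass typename make_index_seq<sizeof...(Dims)>::type instead of std::make_index_sequence<sizeof...(Dims)>. The ::type is required because make_index_seq<N> is not itself an index_seq, which is the same point made about the linked implementation in the comments.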

T.C.
  • I never thought of passing around index sequences as constexpr list-initializers. Neat! – md5i May 17 '16 at 18:24
  • @T.C. Can't get your solution to work under `-std=c++11`. Compiles fine with `-std=c++14`. I am using [this](http://stackoverflow.com/a/32223343/2750396) implementation for `make_index_sequence`. – romeric May 17 '16 at 22:45
  • @romeric That's not a proper implementation, but if you want to use it, you need `typename make_index_sequence::type`. – T.C. May 17 '16 at 22:49

This is a mess, but I think it does what you want it to do. There are almost certainly many simplifications that could be made, but this is my first pass that passes the tests. Please note that this does not implement the contraction itself; it only determines what the result type should be. If that is not what you needed, I apologize in advance.

#include <cstddef>
#include <type_traits>

template <std::size_t...>
struct Indices {};

template <typename, std::size_t...>
struct Array {};

// Count number of 'i' in 'rest...', base case
template <std::size_t i, std::size_t... rest>
struct Count : std::integral_constant<std::size_t, 0>
{};

// Count number of 'i' in 'rest...', inductive case
template <std::size_t i, std::size_t j, std::size_t... rest>
struct Count<i, j, rest...> :
    std::integral_constant<std::size_t,
                           Count<i, rest...>::value + ((i == j) ? 1 : 0)>
{};

// Is 'i' contained in 'rest...'?
template <std::size_t i, std::size_t... rest>
struct Contains :
    std::integral_constant<bool, (Count<i, rest...>::value > 0)>
{};


// Accumulation of counts of indices in all, base case
template <typename All, typename Remainder,
          typename AccIdx, typename AccCount>
struct Counts {
    using indices = AccIdx;
    using counts = AccCount;
};

// Accumulation of counts of indices in all, inductive case
template <std::size_t... all, std::size_t i, std::size_t... rest,
          std::size_t... indices, std::size_t... counts>
struct Counts<Indices<all...>, Indices<i, rest...>,
              Indices<indices...>, Indices<counts...>>
    : std::conditional<Contains<i, indices...>::value,
                       Counts<Indices<all...>, Indices<rest...>,
                              Indices<indices...>,
                              Indices<counts...>>,
                       Counts<Indices<all...>, Indices<rest...>,
                              Indices<indices..., i>,
                              Indices<counts...,
                                      Count<i, all...>::value>>>::type
{};

// Get the value in From at the position of the first element of Idx equal to idx
template <std::size_t idx, typename Idx, typename From>
struct First : std::integral_constant<std::size_t, 0>
{};
template <std::size_t i, std::size_t j, std::size_t k,
          std::size_t... indices, std::size_t... values>
struct First<i, Indices<j, indices...>, Indices<k, values...>>
    : std::conditional<i == j,
                       std::integral_constant<std::size_t, k>,
                       First<i, Indices<indices...>,
                             Indices<values...>>>::type
{};

// Return whether every value in From whose corresponding element of Idx equals idx is tgt
template <std::size_t idx, std::size_t tgt, typename Idx, typename From>
struct AllMatchTarget : std::true_type
{};
template <std::size_t idx, std::size_t tgt,
          std::size_t i, std::size_t j,
          std::size_t... indices, std::size_t... values>
struct AllMatchTarget<idx, tgt,
                      Indices<i, indices...>, Indices<j, values...>>
    : std::conditional<i == idx && j != tgt, std::false_type,
                       AllMatchTarget<idx, tgt, Indices<indices...>,
                                      Indices<values...>>>::type
{};

/* Generate the dimensions, given the counts, indices, and values */
template <typename CountList, typename UniqueIdx,
          typename AllIndices, typename Values, typename Accum>
struct GenDims;

template <typename A, typename V, typename R>
struct GenDims<Indices<>, Indices<>, A, V, R> {
    using type = R;
};
template <typename T, std::size_t i, std::size_t c,
          std::size_t... counts, std::size_t... indices,
          std::size_t... dims, typename AllIndices, typename Values>
struct GenDims<Indices<c, counts...>, Indices<i, indices...>,
               AllIndices, Values, Array<T, dims...>>
{
    static constexpr auto value = First<i, AllIndices, Values>::value;
    static_assert(AllMatchTarget<i, value, AllIndices, Values>::value,
                  "Index doesn't correspond to matching dimensions");
    using type = typename GenDims<
        Indices<counts...>, Indices<indices...>,
        AllIndices, Values,
        typename std::conditional<c == 1,
                                  Array<T, dims..., value>,
                                  Array<T, dims...>>::type>::type;
};

/* Put it all together */
template <typename I, typename A>
struct ContractionType;

template <typename T, std::size_t... indices, std::size_t... values>
struct ContractionType<Indices<indices...>, Array<T, values...>> {
    static_assert(sizeof...(indices) == sizeof...(values),
                   "Number of indices and dimensions do not match");
    using counts = Counts<Indices<indices...>,
                          Indices<indices...>,
                          Indices<>, Indices<>>;
    using type = typename GenDims<typename counts::counts,
                                  typename counts::indices,
                                  Indices<indices...>, Indices<values...>,
                                  Array<T>>::type;
};

static_assert(std::is_same<typename
              ContractionType<Indices<0, 0>, Array<double, 3, 3>>::type,
              Array<double>>::value, "");
static_assert(std::is_same<typename
              ContractionType<Indices<0, 1>, Array<double, 3, 3>>::type,
              Array<double, 3, 3>>::value, "");
static_assert(std::is_same<typename
              ContractionType<Indices<0, 1, 0>, Array<double, 3, 4, 3>>::type,
              Array<double, 4>>::value, "");
static_assert(std::is_same<typename
              ContractionType<Indices<0, 1, 0, 7, 7, 2>,
              Array<double, 3, 4, 3, 5, 5, 6>>::type,
              Array<double, 4, 6>>::value, "");

// Errors appropriately when uncommented
/* static_assert(std::is_same<typename */
/*               ContractionType<Indices<10,10, 2, 3>, */
/*               Array<double, 5,6,4,4>>::type, */
/*               Array<double>>::value, ""); */

Here is an explanation of what is going on:

  • First I generate, using Counts, a list of the unique indices (Counts::indices) and the number of times each index appears in the sequence (Counts::counts).
  • Then I walk over the (index, count) pairs from Counts. For each index whose count is 1, I append its dimension to the accumulated Array and recurse; otherwise I pass the accumulator on unchanged and recurse.

The most irritating part is the static_assert in GenDims, which verifies, for each index, that all the dimensions it maps to are the same.

md5i