Line data Source code
1 : //
2 : // Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
3 : // Use of this source code is governed by a BSD-style license that can be
4 : // found in the LICENSE file.
5 : //
6 : // Implementation of the integer pow expressions HLSL bug workaround.
7 : // See header for more info.
8 :
9 : #include "compiler/translator/ExpandIntegerPowExpressions.h"
10 :
11 : #include <cmath>
12 : #include <cstdlib>
13 :
14 : #include "compiler/translator/IntermNode.h"
15 :
16 : namespace sh
17 : {
18 :
19 : namespace
20 : {
21 :
22 0 : class Traverser : public TIntermTraverser
23 : {
24 : public:
25 : static void Apply(TIntermNode *root, unsigned int *tempIndex);
26 :
27 : private:
28 : Traverser();
29 : bool visitAggregate(Visit visit, TIntermAggregate *node) override;
30 : void nextIteration();
31 :
32 : bool mFound = false;
33 : };
34 :
35 : // static
36 0 : void Traverser::Apply(TIntermNode *root, unsigned int *tempIndex)
37 : {
38 0 : Traverser traverser;
39 0 : traverser.useTemporaryIndex(tempIndex);
40 0 : do
41 : {
42 0 : traverser.nextIteration();
43 0 : root->traverse(&traverser);
44 0 : if (traverser.mFound)
45 : {
46 0 : traverser.updateTree();
47 : }
48 0 : } while (traverser.mFound);
49 0 : }
50 :
51 0 : Traverser::Traverser() : TIntermTraverser(true, false, false)
52 : {
53 0 : }
54 :
55 0 : void Traverser::nextIteration()
56 : {
57 0 : mFound = false;
58 0 : nextTemporaryIndex();
59 0 : }
60 :
61 0 : bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
62 : {
63 0 : if (mFound)
64 : {
65 0 : return false;
66 : }
67 :
68 : // Test 0: skip non-pow operators.
69 0 : if (node->getOp() != EOpPow)
70 : {
71 0 : return true;
72 : }
73 :
74 0 : const TIntermSequence *sequence = node->getSequence();
75 0 : ASSERT(sequence->size() == 2u);
76 0 : const TIntermConstantUnion *constantNode = sequence->at(1)->getAsConstantUnion();
77 :
78 : // Test 1: check for a single constant.
79 0 : if (!constantNode || constantNode->getNominalSize() != 1)
80 : {
81 0 : return true;
82 : }
83 :
84 0 : const TConstantUnion *constant = constantNode->getUnionArrayPointer();
85 :
86 0 : TConstantUnion asFloat;
87 0 : asFloat.cast(EbtFloat, *constant);
88 :
89 0 : float value = asFloat.getFConst();
90 :
91 : // Test 2: value is in the problematic range.
92 0 : if (value < -5.0f || value > 9.0f)
93 : {
94 0 : return true;
95 : }
96 :
97 : // Test 3: value is integer or pretty close to an integer.
98 0 : float absval = std::abs(value);
99 0 : float frac = absval - std::round(absval);
100 0 : if (frac > 0.0001f)
101 : {
102 0 : return true;
103 : }
104 :
105 : // Test 4: skip -1, 0, and 1
106 0 : int exponent = static_cast<int>(value);
107 0 : int n = std::abs(exponent);
108 0 : if (n < 2)
109 : {
110 0 : return true;
111 : }
112 :
113 : // Potential problem case detected, apply workaround.
114 0 : nextTemporaryIndex();
115 :
116 0 : TIntermTyped *lhs = sequence->at(0)->getAsTyped();
117 0 : ASSERT(lhs);
118 :
119 0 : TIntermDeclaration *init = createTempInitDeclaration(lhs);
120 0 : TIntermTyped *current = createTempSymbol(lhs->getType());
121 :
122 0 : insertStatementInParentBlock(init);
123 :
124 : // Create a chain of n-1 multiples.
125 0 : for (int i = 1; i < n; ++i)
126 : {
127 0 : TIntermBinary *mul = new TIntermBinary(EOpMul, current, createTempSymbol(lhs->getType()));
128 0 : mul->setLine(node->getLine());
129 0 : current = mul;
130 : }
131 :
132 : // For negative pow, compute the reciprocal of the positive pow.
133 0 : if (exponent < 0)
134 : {
135 0 : TConstantUnion *oneVal = new TConstantUnion();
136 0 : oneVal->setFConst(1.0f);
137 0 : TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType());
138 0 : TIntermBinary *div = new TIntermBinary(EOpDiv, oneNode, current);
139 0 : current = div;
140 : }
141 :
142 0 : queueReplacement(node, current, OriginalNode::IS_DROPPED);
143 0 : mFound = true;
144 0 : return false;
145 : }
146 :
147 : } // anonymous namespace
148 :
149 0 : void ExpandIntegerPowExpressions(TIntermNode *root, unsigned int *tempIndex)
150 : {
151 0 : Traverser::Apply(root, tempIndex);
152 0 : }
153 :
154 : } // namespace sh
|